第一夜は秘書問題を取り上げたが、どうも内容が難しいので、第二夜は自前のゲームモデルを学習するをテーマにしている以下のサイトのウワン的解説をしたいと思う。
※今回もほぼ参考リンク先のモデルのとおりである
【参考】
・OpenAI Gym で自前の環境をつくる
###やったこと
(1)参考サイトのモデルを動かす
(2)モデルを変更して学習させる
###(1)参考サイトのモデルを動かす
自前モデルは以下のとおりである。
※あとで変更できそうな部分を見たいので一つずつ解説する
以下が今回のゲームの一番肝心な部分で勇者が歩くフィールドを定義しています。
※このエリアは適当に変更できそう
MAX_STEPS = 100はふらふらできる最大歩数になります
class MyEnv(gym.Env):
metadata = {'render.modes': ['human', 'ansi']}
FIELD_TYPES = [
'S', # 0: スタート
'G', # 1: ゴール
'~', # 2: 芝生(敵の現れる確率1/10)
'w', # 3: 森(敵の現れる確率1/2)
'=', # 4: 毒沼(1step毎に1のダメージ, 敵の現れる確率1/2)
'A', # 5: 山(歩けない)
'Y', # 6: 勇者
]
MAP = np.array([
[5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5], # "AAAAAAAAAAAA"
[5, 5, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2], # "AA~~~~~~~~~~"
[5, 5, 2, 0, 2, 2, 5, 2, 2, 4, 2, 2], # "AA~S~~A~~=~~"
[5, 2, 2, 2, 2, 2, 5, 5, 4, 4, 2, 2], # "A~~~~~AA==~~"
[2, 2, 3, 3, 3, 3, 5, 5, 2, 2, 3, 3], # "~~wwwwAA~~ww"
[2, 3, 3, 3, 3, 5, 2, 2, 1, 2, 2, 3], # "~wwwwA~~G~~w"
[2, 2, 2, 2, 2, 2, 4, 4, 2, 2, 2, 2], # "~~~~~~==~~~~"
])
MAX_STEPS = 100
この環境が呼び出されるとき最初に設定すべき定数が設定されます。
特に移動は東西南北の4次元
観察領域を定義しています。
報酬の大きさは[-1., 100.]
def __init__(self):
super().__init__()
# action_space, observation_space, reward_range を設定する
self.action_space = gym.spaces.Discrete(4) # 東西南北
self.observation_space = gym.spaces.Box(
low=0,
high=len(self.FIELD_TYPES),
shape=self.MAP.shape
)
self.reward_range = [-1., 100.]
self._reset()
スタート時の初期値設定、スタート位置、ゴール位置、ダメージ0、歩数0
def _reset(self):
# 諸々の変数を初期化する
self.pos = self._find_pos('S')[0]
self.goal = self._find_pos('G')[0]
self.done = False
self.damage = 0
self.steps = 0
return self._observe()
1ステップ進める処理、東西南北へ進む
移動できるエリアかチェック
マップに勇者の位置を重ねて返す
報酬、ダメージをゲットする
def _step(self, action):
# 1ステップ進める処理を記述。戻り値は observation, reward, done(ゲーム終了したか), info(追加の情報の辞書)
if action == 0:
next_pos = self.pos + [0, 1]
elif action == 1:
next_pos = self.pos + [0, -1]
elif action == 2:
next_pos = self.pos + [1, 0]
elif action == 3:
next_pos = self.pos + [-1, 0]
if self._is_movable(next_pos):
self.pos = next_pos
moved = True
else:
moved = False
observation = self._observe()
reward = self._get_reward(self.pos, moved)
self.damage += self._get_damage(self.pos)
self.done = self._is_done()
return observation, reward, self.done, {}
表示する、または外部出力する
def _render(self, mode='human', close=False):
# human の場合はコンソールに出力。ansiの場合は StringIO を返す
outfile = StringIO() if mode == 'ansi' else sys.stdout
outfile.write('\n'.join(' '.join(
self.FIELD_TYPES[elem] for elem in row
) for row in self._observe()
) + '\n'
)
return outfile
今回は、closeとseedは使っていません。
※ちなみに秘書問題では両方とも利用していました
(これは別途記事にしようと思います)
def _close(self):
pass
def _seed(self, seed=None):
pass
以下でゴールまで行くと報酬を計算します。
ただし、一歩ごとにreturn -1されます。
def _get_reward(self, pos, moved):
# 報酬を返す。報酬の与え方が難しいが、ここでは
# - ゴールにたどり着くと 100 ポイント
# - ダメージはゴール時にまとめて計算
# - 1ステップごとに-1ポイント(できるだけ短いステップでゴールにたどり着きたい)
# とした
if moved and (self.goal == pos).all():
return max(100 - self.damage, 0)
else:
return -1
ダメージを計算します。
※ここでゲームの難易度を調整できる
def _get_damage(self, pos):
# ダメージの計算
field_type = self.FIELD_TYPES[self.MAP[tuple(pos)]]
if field_type == 'S':
return 0
elif field_type == 'G':
return 0
elif field_type == '~':
return 10 if np.random.random() < 1/10. else 0
elif field_type == 'w':
return 10 if np.random.random() < 1/2. else 0
elif field_type == '=':
return 11 if np.random.random() < 1/2. else 1
マップの中で立ち入り禁止区域にいない場所かどうかのチェック
def _is_movable(self, pos):
# マップの中にいるか、歩けない場所にいないか
return (
0 <= pos[0] < self.MAP.shape[0]
and 0 <= pos[1] < self.MAP.shape[1]
and self.FIELD_TYPES[self.MAP[tuple(pos)]] != 'A'
)
マップに勇者の位置を重ねて返す
def _observe(self):
# マップに勇者の位置を重ねて返す
observation = self.MAP.copy()
observation[tuple(self.pos)] = self.FIELD_TYPES.index('Y')
return observation
ゴールにたどり着いたか、MAX_STEPを超えたかのチェック
def _is_done(self):
# 今回は最大で self.MAX_STEPS までとした
if (self.pos == self.goal).all():
return True
elif self.steps > self.MAX_STEPS:
return True
else:
return False
たぶん、FIELD_TYPEのマップを返している
def _find_pos(self, field_type):
return np.array(list(zip(*np.where(
self.MAP == self.FIELD_TYPES.index(field_type)
))))
ということで、どんなゲームか分かったと思います。
そこで、このゲームをmyenv/env.pyとして、保存します。
そして、gym に登録するために以下のコードを使います。
from gym.envs.registration import register
register(
id='myenv-v0',
entry_point='myenv.env:MyEnv'
)
このid='myenv-v0'でこのゲーム環境の呼び出しに使います。
※「ここで、id は<環境名>-v<バージョン>という形式である必要があります。」だそうです
このあと参考の解説では少し省略があって、全体としてCartPoleのコードを上記の新しいゲーム環境呼び出しに変更するだけで動きます。
つまり、ENV_NAME = 'CartPole-v0'をENV_NAME = 'myenv-v0'に変更するだけで動きます。
※ここがすごいと思います
import numpy as np
import gym
from keras.models import Sequential
from keras.layers import Dense, Activation, Flatten
from keras.optimizers import Adam
from rl.agents.dqn import DQNAgent
from rl.policy import BoltzmannQPolicy
from rl.memory import SequentialMemory
ENV_NAME = 'CartPole-v0'
ENV_NAME = 'myenv-v0'
# Get the environment and extract the number of actions.
env = gym.make(ENV_NAME)
np.random.seed(123)
env.seed(123)
nb_actions = env.action_space.n
# Next, we build a very simple model.
model = Sequential()
model.add(Flatten(input_shape=(1,) + env.observation_space.shape))
model.add(Dense(16))
model.add(Activation('relu'))
model.add(Dense(16))
model.add(Activation('relu'))
model.add(Dense(16))
model.add(Activation('relu'))
model.add(Dense(nb_actions))
model.add(Activation('linear'))
print(model.summary())
# Finally, we configure and compile our agent. You can use every built-in Keras optimizer and
# even the metrics!
memory = SequentialMemory(limit=50000, window_length=1)
policy = BoltzmannQPolicy()
dqn = DQNAgent(model=model, nb_actions=nb_actions, memory=memory, nb_steps_warmup=10,
target_model_update=1e-2, policy=policy)
dqn.compile(Adam(lr=1e-3), metrics=['mae'])
# Okay, now it's time to learn something! We visualize the training here for show, but this
# slows down training quite a lot. You can always safely abort the training prematurely using
# Ctrl + C.
dqn.fit(env, nb_steps=50000, visualize=True, verbose=2)
# After training is done, we save the final weights.
dqn.save_weights('dqn_{}_weights.h5f'.format(ENV_NAME), overwrite=True)
# Finally, evaluate our algorithm for 5 episodes.
dqn.test(env, nb_episodes=5, visualize=True)
上記のコードでうまく動きました。
###(2)モデルを変更して学習させる
今回は、ゲームのエリアを拡大したり、難易度を変えて動かしてみました。
例として以下の様にエリアをほぼ倍にしてやってみました。
class MyEnv(gym.Env):
metadata = {'render.modes': ['human', 'ansi']}
FIELD_TYPES = [
'S', # 0: スタート
'G', # 1: ゴール
'~', # 2: 芝生(敵の現れる確率1/10)
'w', # 3: 森(敵の現れる確率1/2→3/4)
'=', # 4: 毒沼(1step毎に1のダメージ, 敵の現れる確率1/2)
'A', # 5: 山(歩けない)
'Y', # 6: 勇者
]
MAP = np.array([
[5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5], # "AAAAAAAAAAAA"
[5, 5, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2], # "AA~~~~~~~~~~"
[5, 5, 2, 0, 2, 2, 5, 2, 2, 4, 2, 2], # "AA~S~~A~~=~~"
[5, 2, 2, 2, 2, 2, 5, 5, 4, 4, 2, 2], # "A~~~~~AA==~~"
[2, 2, 3, 3, 3, 3, 5, 5, 2, 2, 3, 3], # "~~wwwwAA~~ww"
[2, 3, 3, 3, 3, 5, 2, 2, 2, 2, 2, 3], # "~wwwwA~~~~~w"
[2, 2, 2, 2, 2, 2, 2, 4, 4, 4, 4, 4], # "~~~~~~~====="
[5, 5, 2, 2, 2, 2, 4, 4, 4, 4, 4, 4], # "AA~~~~======"
[5, 5, 2, 2, 2, 2, 5, 2, 2, 4, 2, 2], # "AA~~~~A~~=~~"
[5, 2, 2, 2, 2, 2, 5, 5, 4, 4, 2, 2], # "A~~~~~AA==~~"
[2, 2, 3, 3, 3, 3, 5, 5, 2, 2, 3, 3], # "~~wwwwAA~~ww"
[2, 3, 3, 3, 3, 5, 2, 2, 1, 2, 2, 3], # "~wwwwA~~G~~w"
[2, 2, 2, 2, 2, 2, 4, 4, 2, 2, 2, 2], # "~~~~~~==~~~~"
])
MAX_STEPS = 100
このゲームだと、当初のモデルでは収束できず、以下の様にhiddenを84層や256層にして収束しました。
# Next, we build a very simple model.
model = Sequential()
model.add(Flatten(input_shape=(1,) + env.observation_space.shape))
model.add(Dense(256)) #16
model.add(Activation('relu'))
model.add(Dense(256))
model.add(Activation('relu'))
model.add(Dense(256))
model.add(Activation('relu'))
model.add(Dense(nb_actions))
model.add(Activation('linear')) #linear
print(model.summary())
また、イテレーションもnb_steps=100000として収束させました。
dqn.fit(env, nb_steps=100000, visualize=False, verbose=2) #50000, True, 2
また、表示がうるさくて終了するころには標準出力から消えてしまうので、visualize=Falseとしました。
ということで、以下のおまけのようなテスト結果が得られます。
###まとめ
・自前ゲームで遊んでみた
・自前ゲームの登録の仕方を理解した
・他のゲームを同じように登録したい
###おまけ
勇者YがSから徐々に南下してGに到達しているのがわかります。
※最短ルートを見つけたようです
done, took 456.127 seconds
Testing for 5 episodes ...
A A A A A A A A A A A A
A A ~ ~ ~ ~ ~ ~ ~ ~ ~ ~
A A ~ S ~ ~ A ~ ~ = ~ ~
A ~ ~ Y ~ ~ A A = = ~ ~
~ ~ w w w w A A ~ ~ w w
~ w w w w A ~ ~ ~ ~ ~ w
~ ~ ~ ~ ~ ~ ~ = = = = =
A A ~ ~ ~ ~ = = = = = =
A A ~ ~ ~ ~ A ~ ~ = ~ ~
A ~ ~ ~ ~ ~ A A = = ~ ~
~ ~ w w w w A A ~ ~ w w
~ w w w w A ~ ~ G ~ ~ w
~ ~ ~ ~ ~ ~ = = ~ ~ ~ ~
A A A A A A A A A A A A
A A ~ ~ ~ ~ ~ ~ ~ ~ ~ ~
A A ~ S ~ ~ A ~ ~ = ~ ~
A ~ ~ ~ ~ ~ A A = = ~ ~
~ ~ w Y w w A A ~ ~ w w
~ w w w w A ~ ~ ~ ~ ~ w
~ ~ ~ ~ ~ ~ ~ = = = = =
A A ~ ~ ~ ~ = = = = = =
A A ~ ~ ~ ~ A ~ ~ = ~ ~
A ~ ~ ~ ~ ~ A A = = ~ ~
~ ~ w w w w A A ~ ~ w w
~ w w w w A ~ ~ G ~ ~ w
~ ~ ~ ~ ~ ~ = = ~ ~ ~ ~
A A A A A A A A A A A A
A A ~ ~ ~ ~ ~ ~ ~ ~ ~ ~
A A ~ S ~ ~ A ~ ~ = ~ ~
A ~ ~ ~ ~ ~ A A = = ~ ~
~ ~ w w w w A A ~ ~ w w
~ w w Y w A ~ ~ ~ ~ ~ w
~ ~ ~ ~ ~ ~ ~ = = = = =
A A ~ ~ ~ ~ = = = = = =
A A ~ ~ ~ ~ A ~ ~ = ~ ~
A ~ ~ ~ ~ ~ A A = = ~ ~
~ ~ w w w w A A ~ ~ w w
~ w w w w A ~ ~ G ~ ~ w
~ ~ ~ ~ ~ ~ = = ~ ~ ~ ~
A A A A A A A A A A A A
A A ~ ~ ~ ~ ~ ~ ~ ~ ~ ~
A A ~ S ~ ~ A ~ ~ = ~ ~
A ~ ~ ~ ~ ~ A A = = ~ ~
~ ~ w w w w A A ~ ~ w w
~ w w w w A ~ ~ ~ ~ ~ w
~ ~ ~ Y ~ ~ ~ = = = = =
A A ~ ~ ~ ~ = = = = = =
A A ~ ~ ~ ~ A ~ ~ = ~ ~
A ~ ~ ~ ~ ~ A A = = ~ ~
~ ~ w w w w A A ~ ~ w w
~ w w w w A ~ ~ G ~ ~ w
~ ~ ~ ~ ~ ~ = = ~ ~ ~ ~
A A A A A A A A A A A A
A A ~ ~ ~ ~ ~ ~ ~ ~ ~ ~
A A ~ S ~ ~ A ~ ~ = ~ ~
A ~ ~ ~ ~ ~ A A = = ~ ~
~ ~ w w w w A A ~ ~ w w
~ w w w w A ~ ~ ~ ~ ~ w
~ ~ ~ ~ ~ ~ ~ = = = = =
A A ~ Y ~ ~ = = = = = =
A A ~ ~ ~ ~ A ~ ~ = ~ ~
A ~ ~ ~ ~ ~ A A = = ~ ~
~ ~ w w w w A A ~ ~ w w
~ w w w w A ~ ~ G ~ ~ w
~ ~ ~ ~ ~ ~ = = ~ ~ ~ ~
A A A A A A A A A A A A
A A ~ ~ ~ ~ ~ ~ ~ ~ ~ ~
A A ~ S ~ ~ A ~ ~ = ~ ~
A ~ ~ ~ ~ ~ A A = = ~ ~
~ ~ w w w w A A ~ ~ w w
~ w w w w A ~ ~ ~ ~ ~ w
~ ~ ~ ~ ~ ~ ~ = = = = =
A A ~ ~ ~ ~ = = = = = =
A A ~ Y ~ ~ A ~ ~ = ~ ~
A ~ ~ ~ ~ ~ A A = = ~ ~
~ ~ w w w w A A ~ ~ w w
~ w w w w A ~ ~ G ~ ~ w
~ ~ ~ ~ ~ ~ = = ~ ~ ~ ~
A A A A A A A A A A A A
A A ~ ~ ~ ~ ~ ~ ~ ~ ~ ~
A A ~ S ~ ~ A ~ ~ = ~ ~
A ~ ~ ~ ~ ~ A A = = ~ ~
~ ~ w w w w A A ~ ~ w w
~ w w w w A ~ ~ ~ ~ ~ w
~ ~ ~ ~ ~ ~ ~ = = = = =
A A ~ ~ ~ ~ = = = = = =
A A ~ ~ ~ ~ A ~ ~ = ~ ~
A ~ ~ Y ~ ~ A A = = ~ ~
~ ~ w w w w A A ~ ~ w w
~ w w w w A ~ ~ G ~ ~ w
~ ~ ~ ~ ~ ~ = = ~ ~ ~ ~
A A A A A A A A A A A A
A A ~ ~ ~ ~ ~ ~ ~ ~ ~ ~
A A ~ S ~ ~ A ~ ~ = ~ ~
A ~ ~ ~ ~ ~ A A = = ~ ~
~ ~ w w w w A A ~ ~ w w
~ w w w w A ~ ~ ~ ~ ~ w
~ ~ ~ ~ ~ ~ ~ = = = = =
A A ~ ~ ~ ~ = = = = = =
A A ~ ~ ~ ~ A ~ ~ = ~ ~
A ~ ~ ~ ~ ~ A A = = ~ ~
~ ~ w Y w w A A ~ ~ w w
~ w w w w A ~ ~ G ~ ~ w
~ ~ ~ ~ ~ ~ = = ~ ~ ~ ~
A A A A A A A A A A A A
A A ~ ~ ~ ~ ~ ~ ~ ~ ~ ~
A A ~ S ~ ~ A ~ ~ = ~ ~
A ~ ~ ~ ~ ~ A A = = ~ ~
~ ~ w w w w A A ~ ~ w w
~ w w w w A ~ ~ ~ ~ ~ w
~ ~ ~ ~ ~ ~ ~ = = = = =
A A ~ ~ ~ ~ = = = = = =
A A ~ ~ ~ ~ A ~ ~ = ~ ~
A ~ ~ ~ ~ ~ A A = = ~ ~
~ ~ w w w w A A ~ ~ w w
~ w w Y w A ~ ~ G ~ ~ w
~ ~ ~ ~ ~ ~ = = ~ ~ ~ ~
A A A A A A A A A A A A
A A ~ ~ ~ ~ ~ ~ ~ ~ ~ ~
A A ~ S ~ ~ A ~ ~ = ~ ~
A ~ ~ ~ ~ ~ A A = = ~ ~
~ ~ w w w w A A ~ ~ w w
~ w w w w A ~ ~ ~ ~ ~ w
~ ~ ~ ~ ~ ~ ~ = = = = =
A A ~ ~ ~ ~ = = = = = =
A A ~ ~ ~ ~ A ~ ~ = ~ ~
A ~ ~ ~ ~ ~ A A = = ~ ~
~ ~ w w w w A A ~ ~ w w
~ w w w w A ~ ~ G ~ ~ w
~ ~ ~ Y ~ ~ = = ~ ~ ~ ~
A A A A A A A A A A A A
A A ~ ~ ~ ~ ~ ~ ~ ~ ~ ~
A A ~ S ~ ~ A ~ ~ = ~ ~
A ~ ~ ~ ~ ~ A A = = ~ ~
~ ~ w w w w A A ~ ~ w w
~ w w w w A ~ ~ ~ ~ ~ w
~ ~ ~ ~ ~ ~ ~ = = = = =
A A ~ ~ ~ ~ = = = = = =
A A ~ ~ ~ ~ A ~ ~ = ~ ~
A ~ ~ ~ ~ ~ A A = = ~ ~
~ ~ w w w w A A ~ ~ w w
~ w w w w A ~ ~ G ~ ~ w
~ ~ ~ ~ Y ~ = = ~ ~ ~ ~
A A A A A A A A A A A A
A A ~ ~ ~ ~ ~ ~ ~ ~ ~ ~
A A ~ S ~ ~ A ~ ~ = ~ ~
A ~ ~ ~ ~ ~ A A = = ~ ~
~ ~ w w w w A A ~ ~ w w
~ w w w w A ~ ~ ~ ~ ~ w
~ ~ ~ ~ ~ ~ ~ = = = = =
A A ~ ~ ~ ~ = = = = = =
A A ~ ~ ~ ~ A ~ ~ = ~ ~
A ~ ~ ~ ~ ~ A A = = ~ ~
~ ~ w w w w A A ~ ~ w w
~ w w w w A ~ ~ G ~ ~ w
~ ~ ~ ~ ~ Y = = ~ ~ ~ ~
A A A A A A A A A A A A
A A ~ ~ ~ ~ ~ ~ ~ ~ ~ ~
A A ~ S ~ ~ A ~ ~ = ~ ~
A ~ ~ ~ ~ ~ A A = = ~ ~
~ ~ w w w w A A ~ ~ w w
~ w w w w A ~ ~ ~ ~ ~ w
~ ~ ~ ~ ~ ~ ~ = = = = =
A A ~ ~ ~ ~ = = = = = =
A A ~ ~ ~ ~ A ~ ~ = ~ ~
A ~ ~ ~ ~ ~ A A = = ~ ~
~ ~ w w w w A A ~ ~ w w
~ w w w w A ~ ~ G ~ ~ w
~ ~ ~ ~ ~ ~ Y = ~ ~ ~ ~
A A A A A A A A A A A A
A A ~ ~ ~ ~ ~ ~ ~ ~ ~ ~
A A ~ S ~ ~ A ~ ~ = ~ ~
A ~ ~ ~ ~ ~ A A = = ~ ~
~ ~ w w w w A A ~ ~ w w
~ w w w w A ~ ~ ~ ~ ~ w
~ ~ ~ ~ ~ ~ ~ = = = = =
A A ~ ~ ~ ~ = = = = = =
A A ~ ~ ~ ~ A ~ ~ = ~ ~
A ~ ~ ~ ~ ~ A A = = ~ ~
~ ~ w w w w A A ~ ~ w w
~ w w w w A Y ~ G ~ ~ w
~ ~ ~ ~ ~ ~ = = ~ ~ ~ ~
A A A A A A A A A A A A
A A ~ ~ ~ ~ ~ ~ ~ ~ ~ ~
A A ~ S ~ ~ A ~ ~ = ~ ~
A ~ ~ ~ ~ ~ A A = = ~ ~
~ ~ w w w w A A ~ ~ w w
~ w w w w A ~ ~ ~ ~ ~ w
~ ~ ~ ~ ~ ~ ~ = = = = =
A A ~ ~ ~ ~ = = = = = =
A A ~ ~ ~ ~ A ~ ~ = ~ ~
A ~ ~ ~ ~ ~ A A = = ~ ~
~ ~ w w w w A A ~ ~ w w
~ w w w w A ~ Y G ~ ~ w
~ ~ ~ ~ ~ ~ = = ~ ~ ~ ~
A A A A A A A A A A A A
A A ~ ~ ~ ~ ~ ~ ~ ~ ~ ~
A A ~ S ~ ~ A ~ ~ = ~ ~
A ~ ~ ~ ~ ~ A A = = ~ ~
~ ~ w w w w A A ~ ~ w w
~ w w w w A ~ ~ ~ ~ ~ w
~ ~ ~ ~ ~ ~ ~ = = = = =
A A ~ ~ ~ ~ = = = = = =
A A ~ ~ ~ ~ A ~ ~ = ~ ~
A ~ ~ ~ ~ ~ A A = = ~ ~
~ ~ w w w w A A ~ ~ w w
~ w w w w A ~ ~ Y ~ ~ w
~ ~ ~ ~ ~ ~ = = ~ ~ ~ ~
Episode 5: reward: 64.000, steps: 16