発端: 自分は囲碁ファンでAIに興味をもつSEだから、アルファ碁が出た時点で非常にその裏をしりたかった。AlphoZeroを出た後に色々な情報が入手できたので、簡単な五目並べのゲームをAlphaZeroのアルゴリズムで実装してみました(他のブログにする予定)。その過程で強化学習ももう少しやってみようかと思って、このプログラムを書いた。
結果: Chromeがオフラインとなると、Dinoとの画面が出る(直接URL欄にchrome://dino/を入力してもOK)。スペースを押すと始まる。そのゲームを機械学習を使って、ある程度の訓練をすれば、普通に人間よりよく跳べるようになた。 ただし、ゲーム設定で簡単化した(スピード固定、鴉なし)
準備:
- Chrome のインストール
- ChromeDriverのダウンロード ここから
- Pythonのインストール (Anaconda推奨 )
- 必要なパッケージをインストール(Keras,selenium,OpenCV,hickle)
conda install -c conda-forge selenium
pip install opencv-python
conda install -c anaconda pillow
pip install hickle
conda install -c conda-forge keras
強化学習とは:
特に一から強化学習に関する説明をするつもりはないが、こちらのサイトを強くお勧めする。ただしSEのため非常にわかりやすく丁寧に説明されている半面数式の理論根拠があまり記述されていないので、気になる方はそのサイトに羅列されている参照資料も読んだ方がいいかもしれない。
アーキテクチャ:
OpenAIのGymならEnvを用意いてくれる、しかも各Stepも数値化されていて本当に楽だが、ここは全部ゼロから自前で用意しないといけない。大きくは以下のようなアーキテクチャとなる。
0. ゲームを進行(最初はランダムに、ある程度データを蓄積してからはDQNの出力に基づく)しながら画面のハードコピーとその時のReward(GameOverかどうかによって計算方法が違う)を取り、訓練用データバッファに保存する。スッテプ毎に無作為にデータバッファーから訓練データを取り出し、モデルの訓練を行う。
1. JS/selenium。 そのゲーム自体はJavaScriptで作成されているので、PythonからEnv(環境)とのやりとりは避けられないので、どうしてもPythonとJavaScriptの間に橋が必要となる。幸いにもseleniumとのツールがあって助かった。以前seleniumを使ってWebのテストをしていたが、ここでこの役目を果たしてくれるとは。ただし、seleniumとChromeの間も橋が必要、それはChromeDriverとなる。
ソースを眺めてみよ。
chrome_options = Options()
chrome_options.add_argument("disable-infobars")
chrome_options.add_argument("--mute-audio")
self._driver = webdriver.Chrome(executable_path=driver, chrome_options=chrome_options) # Chromedriverのパスを指定してChromeを開く
self._driver.set_window_position(x=10, y=10)
self._driver.get("chrome://dino") # ゲームのアドレスへ移動
self._driver.execute_script("Runner.config.ACCELERATION=0") # 訓練しやすいため、速度は一定に保つ
self._driver.execute_script("document.getElementsByClassName('runner-canvas')[0].id = 'runner-canvas'")
2. opencv
画面のハードコピーのサイズが大きいのでそのまま使って訓練するとスピード的に耐えられない。ゲームをしなら訓練するので、遅くなるとFPSが落ちる、訓練の質も大幅に落ちる。また、このゲームはとても簡単で、サイズが大きい割に、特徴(feature)が少ない。そこでopencv を使ってサイズを落して、色も白黒にする。
COL_SIZE = 80
ROW_SIZE = 80
def process_img(image):
image = cv2.resize(image, (COL_SIZE, ROW_SIZE))
image = image[:300, :500] # Crop ROI
image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) # RGB to Grey Scale
return image
3. cnn 画像処理と言うと、CNNが主流となっている、このプログラムも例外ではない。とりわけ画面のハードコピーをそのまま入力として使っている。もっと正確に言うと、四枚連続ハードコピーのセットが入力となっている、そうでないと、予測に肝心なスピードの情報が抜けってしまうため。ちなみに、本来ならLSTMの方がいいかもしれないが、ここは簡単のため、画像のセットを使う。ちなみに、簡単のため、ゲームのスピードも固定にしている。
input_shape = (img_rows, img_cols, 4)
model = Sequential()
model.add(
Conv2D(32, kernel_size=(4, 4),
padding='same',
strides=(2, 2),
activation='relu',
input_shape=input_shape)) # 80*80*4
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Conv2D(64, (3, 3), strides=(2, 2), activation='relu',padding='same'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Conv2D(64, (2, 2), strides=(1, 1), activation='relu', padding='same'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Flatten())
model.add(Dense(512,activation='relu'))
model.add(Dense(ACTIONS))
adam = Adam(lr=LEARNING_RATE)
model.compile(loss='mse', optimizer=adam)
4. 強化学習のアルゴリズムはDeep e-Greedy、Q-Learningを使う(DQN)。もっとも普通な強化学習のアルゴリズムなので、説明は割愛とさせてもらう。
targets[i] = model.predict(state_t) # predicted q values from current step
Q_value = model.predict(state_t1) # predict q values for next step
if gameover:
targets[i, action_t] = reward_t
else:
targets[i, action_t] = reward_t + GAMMA * np.max(Q_value) # with future reward
訓練結果: GPUなしのMACで平均10+fps,2時間の訓練でスコアは500前後、5時間の訓練でスコアは1500,7時間で3000以上に達した。
注意点:
1. 前述の通り、画面コピーのセットを入力データとして使っていて、そこからスピード情報が得られる。なので、FPSが低くと訓練の効果が落ちる。自分のMACとWindowsPCの場合、GPUなしでも10ぐらいで十分でしたが、処理を重くして4以下にすると危ない。同じ理由で訓練のデータと中間結果をディスクに保存するときのみゲームを一時停止さくてから再開させる(言い換えれば、訓練時は一々ゲームのPause/Resumeはしない)
2. 画像セットを保存するために、わかりやすいdequeなどのCollectionを使っていたが、FPSが半分ぐらいまで落ちるので、今の書き方にした
self._image_stack = np.append(ob, self._image_stack[:, :, :, :3], axis=3)
3. CNNの設計にも、Strideを小さくすると、訓練の効果がよくなるように見えるが、FPSが落ちる、そこのバランスは試しながら調整する必要がある。
4. 途中結果の保存にはPickleを使っているが、サイズが大きいせいか、たまにエラーが発生するので、訓練データのBufferを保存する際に、Pickleの代わりにhickelを使うことにした。ただし1Gぐらいのファイルであるだけで、SDDを使っても、3分ぐらいかかる。保存もロードも気長く待つ必要。
5. e-GreedyのRandom Action率、バッチサイズ、Rewardの設計などは調整が可能ですが、今のパラメーターでままいける。
環境作りに色々面倒臭いこともあって、バージョンによって衝突もよくあるから、Docker Imageも作って見た。
ここです
このプログラムのソースはここ
また、アイディアも相当部分のソース(特にJavaScript部分)もここを参考しています。