LoginSignup
77
69

More than 5 years have passed since last update.

DQNプニキにホームランを打たせたい

Posted at

Deep Q-Network (DQN)

Deep Learning + 強化学習を使って行動パターンを学習させるDeep Q-Networkは面白いと思い、実装してみました。少しだけ結果が出たので公開します。

ソースコードは以下で公開しています。
https://github.com/dsanno/chainer-dqn

DQNについては以下が詳しいです。
DQNの生い立ち + Deep Q-NetworkをChainerで書いた

学習対象

今回DQNに学習させるゲームはくまのプーさんのホームランダービー!です。(リンク先で音が鳴るので注意)
プニキことプーのアニキに多くのホームランを打たせることを目標にします。

このゲームを選んだ理由は以下の通りです

  • ルールが単純
    ピッチャーが投げる球を丸太で打ち返し、規定数ホームランを打てばクリアとなります
  • 報酬の判定が簡単
    結果が「ホームラン」「ストライク」といった画像で示されるので、画像に対応した報酬を与えてやります

あとは人間にとって非常に難しいという理由もあったのですが、難しくなるところまで到達できませんでした。(参考: ニコニコ大百科)

開発環境

  • Windows 10
  • Chainer 1.5.1
  • 画面キャプチャと操作にPyAutoGUIを使用
  • GeForce GTX970
    CPUだと学習に時間がかかりすぎてうまく動作しません

ニューラルネットワーク構成

  • 入力は150 x 112 x 3chのピクセルデータです。ゲーム画面のサイズは600 x 450pxなのですが、キャプチャした画像を縦横1/4ずつ縮小して入力しています。
  • 出力は行動の評価値のベクトルです。
    ベクトルの長さは行動パターンの数に一致します。
    今回はポインタのY座標は固定でX座標を33段階で変動させるようにしました。
    ボタンはONとOFFの2状態があり計66個の行動パターンとなります。
  • 中間レイヤはConvolutional Neural Networkが3層、LSTMが1層、Fully Connected Layer 1層となっています。

プレイについて

  • タイトル画面、ピッチャー選択画面等では、決められた位置をクリックするようにしました。
  • 対戦中は、100msごとに画面をキャプチャして入力画像として使います。(以下、100ms間隔の単位を「フレーム」と呼びます)評価値の最も高い行動パターンを次のフレームの行動とします。
  • 以下の状態を判定したときに報酬を与えました。それ以外のフレームでの報酬は0です。 ルール上ファールとヒットはストライクと同じ失敗扱いですが、空振りよりは球に当たったほうが良い行動と言えるので、ストライクより少しだけ報酬を高くしています。
    • ホームラン: 100
    • ストライク: -100
    • ファール: -90
    • ヒット: -80
  • 以下の3パターンでランダムな行動をとらせるようにしました。 今回のゲームは1フレームだけランダムな行動をとらせてもあまり意味がないので、10~30フレームの範囲で連続したフレームの間ランダムな行動をとらせました。
    • ポインタの位置だけランダム
    • ボタンの状態だけランダム
    • ポインタの位置もボタンの状態もランダム

学習について

  • 学習はプレイと別スレッドで並行して行いました。
  • LSTMの学習を行うために、以下のように連続したフレームの入力値を与えてパラメータ更新するようにしました。 mは4~32の範囲で徐々に大きくしました。
    • nをランダムに選択する
    • フレームnの入力値を入力
    • フレームn + 1の入力値を入力してフレームn + 1の最大評価値を求める。それを使ってフレームnのパラメータを更新
    • ...
    • フレームn + mの入力値を入力してフレームn + mの最大評価値を求める。それを使ってフレームn + m - 1のパラメータ更新

設定

  • ミニバッチ数: 64
  • DQNのgamma: 0.98
  • optimizer: AdaDelta
    文献[2]でAdaDeltaを使っていたので採用しました。
    パラメータはデフォルトのrho=0.95, eps=1e-06としています。
  • chainer.optimizer.GradientClipping()を使ってL2 normを最大0.1に制限
    勾配を制限しないとQ値が大きくなりすぎ学習が安定しませんでした。

学習結果

10時間ほどステージ1で学習を続けたところ、ステージ1は大体クリアできるようになりました。
プレイ動画を以下にアップロードしました。
動画撮影時にはランダムな行動をとらないようにしています。
https://youtu.be/J4V6ZveYFUM

ほかのステージも含めて学習を行わせたところ、ステージ3をまぐれでクリアするところまでは確認できました。

参考文献

77
69
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
77
69