Help us understand the problem. What is going on with this article?

Keras+DQNでリバーシのAI書く

More than 3 years have passed since last update.

冬休みの自由研究
デモもあるよ

動機

  • この記事を読んで学生時代同じことやってたのが懐かしくなったので
  • この記事を読んだらけっこう簡単にできそうだったので

リバーシのルールを学ばせるとかじゃなく純粋に強いAIを目指します。
リポジトリ

DQN以前の話題

盤面

盤面は6x6としました。
もともとは8x8でやっていたのですが、モデル、学習データともにそれなりの大きさになるので6x6くらいがお試しにはちょうどよかったです。
というか6x6で偶然うまく行ってこれでいっかってなった

AIを用意

学習のための棋譜を大量生成したいので適当なAIを作ります。
今日は完全ランダムとモンテカルロ木探索を作ります。
そしてできたものがこちらになります。
この記事ではMTS(n)でプレイアウト数n回のモンテカルロ木探索のAIを示すものします。

棋譜の生成

こちらで紹介されているNeural Fitted Q Iterationっぽいもので学習するので棋譜を作っておきます。
棋譜には盤面と手番の進行を記録し、とりあえず20万局分ほど用意しました。
DQNだと良質な棋譜を用意しなくていいので非常に楽。完全ランダム同士で数十分回すだけでできます。

DQNの話題

基本的なことはこちらの記事参照。

入力

DQNは入力として状態sと行動aを取ります。
今回は状態を盤面と手番、行動を打ったあとの盤面としました。
DQNでよくあるアタリのゲームでは行動以外の要因(敵の移動など)が次のステートに影響を及ぼすため、行動は単純にキー入力等にしかできません。一方リバーシでは自分の行動で完全に次のステートが定まるのでこのようにしてみました。

報酬と出力

DQNは各行動に報酬rを定めます。
今回は行動の結果終局となり、勝った場合は+1、負けた場合は-1、引き分けや終局以前の場合は0としました。
これにより中間でのスコアは必ず[-1, 1]に収まることになり、出力をtanhにできます。
またゼロサムゲームと想定することで黒のスコア=-白のスコアが成り立ち、後述するR
の推定値計算に用いることができるようになります。
ついでにQ値の伝播も見やすいです。

Rの推定値計算

rを用いて教師信号である総報酬Rを計算します(参照)。

Q_\theta(s, a) = r + \gamma*\max_{a'}Q_\theta(s',a')

教師信号の式はこれですが、リバーシの場合は黒白交互になっているため、sの手番とs'の手番が異なることでこの式を単純に書き写すだけでは欲しい値が出てくれません。
ここで先述のゼロサムの仮定を用いて、以下のように書き換えます。

Q_\theta(s, a) = \left\{
\begin{array}{ll}
 r + \gamma*\max_{a'}Q_\theta(s',a') & (sとs'で手番が同じ)\\
 r - \gamma*\max_{a'}Q_\theta(s',a') & (sとs'で手番が異なる)
\end{array}
\right.

手番が異なっている場合は、ゼロサムの仮定から相手のスコアの符号反転が自身のスコアとなるので、このような式になりました。ミニマックス法でも同じようなことするので自然に思いつきましたがあっているかはわかんない。
さらにリバーシの盤面は斜めの二軸、および180度回転した盤面が同じものと考えられるため、それら四種の中でmaxを取ることによって伝播の速度を早めようというアイディアもぶち込みました。

コードはこのへんになりますが、高速化のためにmax計算のためのpredictを一括でやってしていたり、ε-greedyのコードが混入していたりして恐ろしく読みにくく仕上がっております。
(2017/1/1追記:8路盤を検討している際に気づきましたがNFQではε-greedyは要らないようです。)

モデル

Kerasの長所を活かして以下のようなモデルにしました。

def create_model6():

    board_input = Input(shape=[6, 6])
    action_input = Input(shape=[6, 6])
    color_input = Input(shape=[1])

    model_v = Sequential([
        InputLayer([6, 6]),
        Reshape([6, 6, 1]),
        Conv2D(64, 6, 1), # 1x6
        ELU(),
        Conv2D(64, 1, 1),
        ELU(),
        Flatten()
    ])

    model_h = Sequential([
        InputLayer([6, 6]),
        Reshape([6, 6, 1]),
        Conv2D(64, 1, 6), # 6x1
        ELU(),
        Conv2D(64, 1, 1),
        ELU(),
        Flatten()
    ])

    model_dr = Sequential([
        InputLayer([6, 6]),
        Reshape([6*6, 1]),
        ZeroPadding1D(3),
        Reshape([6, 7, 1]),
        LocallyConnected2D(64, 6, 1),
        ELU(),
        LocallyConnected2D(64, 1, 1),
        ELU(),
        Flatten()
    ])

    model_dl = Sequential([
        InputLayer([6, 6]),
        Reshape([6*6, 1]),
        ZeroPadding1D(2),
        Reshape([8, 5, 1]),
        LocallyConnected2D(64, 8, 1),
        ELU(),
        LocallyConnected2D(64, 1, 1),
        ELU(),
        Flatten()
    ])

    color_model = Sequential([
        InputLayer([1]),
        Dense(256),
        ELU(),
        Dense(1024),
        ELU()
    ])

    merge_layer = merge([
        model_v(board_input),
        model_h(board_input),
        model_dl(board_input),
        model_dr(board_input),
        color_model(color_input),
        model_v(action_input),
        model_h(action_input),
        model_dl(action_input),
        model_dr(action_input),
    ], mode="concat", concat_axis=-1) 

    x = Dense(2048)(merge_layer)
    x = BatchNormalization()(x)
    x = ELU()(x)
    x = Dense(512)(x)
    x = BatchNormalization()(x)
    x = ELU()(x)
    x = Dense(128)(x)
    x = BatchNormalization()(x)
    x = ELU()(x)
    output = Dense(1, activation="tanh")(x)

    model = Model(input=[board_input, color_input, action_input], output=[output])

    adam = Adam(lr=1e-4)
    model.compile(optimizer=adam, loss="mse")

    return model

まず現在と行動後の二つの盤面ですが、これらは重み共有した四種のネットワークに入力されます。
model_v, model_hは見たまんま6x1, 1x6畳み込みの後に1x1で更に畳み込んでいます。
これにより縦一列、横一列のみを見た計算ができます。
残りのmodel_dl,model_drは斜め列の計算です。まずZeroPaddingReshapeをうまく使って斜め列を縦一列に揃え、6x1のLocallyConnected2Dにかけます。例としてmodel_drでは以下の図のように斜め列を揃えています。
スクリーンショット 2016-12-27 15.46.43.png
LocallyConnected2Dを使っているのはこれら縦列の示すものが異なるためです。ほんとは分離したいんですがサボりました。
あとは手番も適当に大きくして、全出力を連結してからの全結合層になります。

Kerasはモデル設計の自由度高いんですが、Model, SequenceLayerが別け隔てなく使えると見せかけて使えなかったりするので結構迷います。最後の全結合層のあたりがその苦しみの産物となっています。
あとこの記事InputLayerを知って捗ったのでもっといいねされるといいと思う。

学習経過確認

経過確認は二通りの方法を使いました。

実際に戦わせる

適当なイテレーション回数ごとにランダムAIと数百戦させて、勝率がどうなっているかを見ました。
実は当初実装がバグっていて、気づいてから修正して続きを始めてみたらその時点で99%超えの勝率になっていたのであんまり参考になっていません。
学習過程で勝率が最も良いモデルを残すのに使いました。

最善に近い対局のデータを与える

MTS(5000)同士で対局させた棋譜を入力にあたえてどのような値が出るか見てみます。
DQNは最善手を指し続けた際のスコアを推定するので、プレイアウト数を大きくとったMTSで最善に近い手を指し続けたスコアを出してみようというアイディアです。
学習終盤での出力は以下のようになっていました。

 0: B: [-0.02196666 -0.02230018 -0.0303785  -0.03168077]
 1: W: [-0.09994178 -0.10080269 -0.09473281 -0.09523395]
 2: B: [-0.05049421 -0.0613905  -0.05865757 -0.0515893 ]
 3: W: [-0.10307723 -0.11034401 -0.1078812  -0.11545543]
 4: B: [-0.02195507 -0.01722838 -0.00687013 -0.00869585]
 5: W: [-0.128447   -0.12213746 -0.14925914 -0.14207964]
 6: B: [-0.22107203 -0.21318831 -0.1865633  -0.18185341]
 7: W: [ 0.06178221  0.08316899  0.05964577  0.03900897]
 8: B: [-0.22808655 -0.20577276 -0.16608836 -0.18835764]
 9: W: [ 0.0713096   0.08927093  0.05861793  0.0417437 ]
10: B: [-0.25330806 -0.2539109  -0.22571792 -0.20664638]
11: W: [ 0.05460038  0.07116394  0.07360842  0.01370821]
12: B: [-0.24553944 -0.22967289 -0.22448763 -0.2255006 ]
13: W: [ 0.08806669  0.11078909  0.11152182  0.0635582 ]
14: B: [-0.45493153 -0.45988095 -0.46441916 -0.41323128]
15: W: [ 0.37723741  0.37333494  0.36738792  0.32914931]
16: B: [-0.46461144 -0.44231483 -0.42101949 -0.38848421]
17: W: [ 0.17253573  0.21936594  0.20173906  0.16213408]
18: B: [-0.50654161 -0.52144158 -0.50978303 -0.51221204]
19: W: [ 0.42853072  0.47864962  0.42829457  0.39107552]
20: B: [-0.68370593 -0.69842124 -0.69973147 -0.70182347]
21: W: [ 0.39964256  0.50497115  0.51084852  0.49810672]
22: B: [-0.71973562 -0.70337516 -0.62852156 -0.67204589]
23: W: [ 0.37252051  0.49400631  0.34360072  0.35528785]
24: B: [-0.81641883 -0.84098679 -0.79452062 -0.82713711]
25: W: [ 0.80729294  0.81642532  0.79480326  0.78571904]
26: B: [-0.916417   -0.9247427  -0.90268892 -0.89786631]
27: W: [ 0.93264407  0.93837488  0.93335259  0.9382773 ]
28: W: [ 0.93243909  0.93243909  0.9183923   0.9183923 ]

上から1手目、2手目、……のスコアで、横方向は同じ状態、行動の斜め軸反転、180度回転です。
終局時のスコアが大きくなっていること、手番の入れ替わりで符号が反転していることなどが読み取れます。横方向の値が大体同じくらいになっているのもそれらの対等性を示しているようで良。
序盤のほうはQ値が伝播していないのかそもそも不定な状態なのか。割引率γはあまり効いていないようですね……。

学習結果

以上のような感じで3日ほど回しました。
メイドインアビスの新刊を買って帰ってきたら出力が全部nanになっており冷や汗。なんよりんなだらけになってほしい。
幸い残っていたモデルを使って勝率を評価してみました。

  • ランダム相手で995/1000
  • MTS(30)相手で95/100
  • MTS(100)相手で6/10
  • MTS(1000)相手で2/10

うーん微妙?ランダム相手は後半戦術だけでも優れていれば勝てるはずなので、後半から伝播するDQNだと勝率上がりやすいかなーとは思っていましたが、後半のほうが強いはずのMTS相手にも勝てたのはちょっと驚きました。中盤までに優勢を取れたのでしょうか。
nanがでなければもっと先まで行けて強くなりそうな感じはします。

デモ

前からGCP使ってみたかったので無料枠使ってみました。
落ちてなければ年末年始の間くらいは動いてると思います。
http://104.155.201.107:8888/
下のセレクトボックスでDQNを選ぶとDQNです。

考察

モデルについて

リバーシなので縦横斜めを考慮できるようなモデルを設計しました。
挟むルールを学べたかは知りませんが開放度くらいは見れてそう?
実際のところ行動の価値は行動後の盤面だけでだいたい分かりそうなので、状態のほうの盤面は入れなくても良かった気がします。どんな状態からでも行動の結果おなじ状態になるのならそれらのスコアは同じはずなので。
適当に作ったモデルでだいたいうまく行ってしまった感じなのでいろいろと検討は足りてないです。でも学習一回に数日かかるのは結構きつい。

スコアについて

終盤のほうのスコアは手番ごとに反転しているので学習が進んでいるようには見えました。
一方でスコアの絶対値はγ^nの形になるのが理想形なはずなのにそれとはだいぶ離れているようでした。
そもそも真のRの値はどうなるのか……後手有利らしいので一手目から黒がマイナス白がプラスになるのでしょうか。

TODO

  • float32でやったのでfloat64にしてnan回避、より先を目指す。
  • もっと簡単なモデルでもやってみる。
  • 8x8でやる。

That's it folks!

t-ae
qoncept
リアルタイム画像認識を専門にした会社です。近年はスポーツにおける認識技術の応用に力を入れています。
https://qoncept.co.jp/
Why not register and get more from Qiita?
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away
Comments
No comments
Sign up for free and join this conversation.
If you already have a Qiita account
Why do not you register as a user and use Qiita more conveniently?
You need to log in to use this function. Qiita can be used more conveniently after logging in.
You seem to be reading articles frequently this month. Qiita can be used more conveniently after logging in.
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away
ユーザーは見つかりませんでした