LoginSignup
117
113

More than 5 years have passed since last update.

ChainerでDQN。強化学習を三目並べでいろいろ試してみた。(Deep Q Network、Q-Learning、モンテカルロ)

Last updated at Posted at 2016-09-28

初めてのQiita投稿です。Ridge-iという会社で機械学習を中心としたコンサル~開発をしてます。

強化学習について教える機会が出てきたので、三目並べをベースに

  • モンテカルロ
  • Q-Learning
  • Deep Q Network (いわゆるDQN)

についてJupyter(ipython) で実装して教材を作りました。

ちなみに強いプレーヤー同士ならば、ドローだけが繰り返されるはずです。(WarGameの有名なやつですね。)

結論としては

  • モンテカルロ 実装簡単。100回試行位でほぼ負けなし(50回くらいだと時々負ける)
  • Q-Learning  更新式の設計に気を遣う。3目並べ程度なら10万対戦で最強
  • Deep Q Network 色々な落とし穴が多数。最適なトポロジーがわからん。Leaky Reluにするまで最弱。教え方ミスると何も学習しない。などなど

Q-Learningまでの実装は1-2時間位しかかからなかったのに、DQNがきちんと学習するようにするだけで6時間くらいかかりました。(泣)

学習の効率化という面ではきちんと盤面の反転と回転をするとよいのですが、今回は分かりやすさ重視で抜きました。
あと、きちんとインターフェース設計と変数の隠匿化をしなかったので時々違反してます。そのあたりは無視してください。

Githubはこちら。
https://github.com/narisan25/TTT-RL

理論的な詳細を知りたい場合は別途連絡ください。また指摘・改良・アドバイス、大Welcomeです。

三目並べの設計

これはざざっとやります。良くするべき点たくさんありますが、今回は分かりやすさと早さ(動けばいい)を重視

盤面

0-9の配列にXなら1、ブランクなら0、○なら-1を入れます。

EMPTY=0
PLAYER_X=1
PLAYER_O=-1
MARKS={PLAYER_X:"X",PLAYER_O:"O",EMPTY:" "}
DRAW=2

class TTTBoard:

    def __init__(self,board=None):
        if board==None:
            self.board = []
            for i in range(9):self.board.append(EMPTY)
        else:
            self.board=board
        self.winner=None

    def get_possible_pos(self):
        pos=[]
        for i in range(9):
            if self.board[i]==EMPTY:
                pos.append(i)
        return pos

    def print_board(self):
        tempboard=[]
        for i in self.board:
            tempboard.append(MARKS[i])
        row = ' {} | {} | {} '
        hr = '\n-----------\n'
        print((row + hr + row + hr + row).format(*tempboard))



    def check_winner(self):
        win_cond = ((1,2,3),(4,5,6),(7,8,9),(1,4,7),(2,5,8),(3,6,9),(1,5,9),(3,5,7))
        for each in win_cond:
            if self.board[each[0]-1] == self.board[each[1]-1]  == self.board[each[2]-1]:
                if self.board[each[0]-1]!=EMPTY:
                    self.winner=self.board[each[0]-1]
                    return self.winner
        return None

    def check_draw(self):
        if len(self.get_possible_pos())==0 and self.winner is None:
            self.winner=DRAW
            return DRAW
        return None

    def move(self,pos,player):
        if self.board[pos]== EMPTY:
            self.board[pos]=player
        else:
            self.winner=-1*player
        self.check_winner()
        self.check_draw()

    def clone(self):
        return TTTBoard(self.board.copy())

    def switch_player(self):
        if self.player_turn == self.player_x:
            self.player_turn=self.player_o
        else:
            self.player_turn=self.player_x

ゲームの進行役

Observer+Mediatorモデルにします。これでPlayerが直接盤面をいじらないようにしつつ、ターン・勝敗・表示の管理などを行います。

class TTT_GameOrganizer:

act_turn=0
winner=None

    def __init__(self,px,po,nplay=1,showBoard=True,showResult=True,stat=100):
        self.player_x=px
        self.player_o=po
        self.nwon={px.myturn:0,po.myturn:0,DRAW:0}
        self.nplay=nplay
        self.players=(self.player_x,self.player_o)
        self.board=None
        self.disp=showBoard
        self.showResult=showResult
        self.player_turn=self.players[random.randrange(2)]
        self.nplayed=0
        self.stat=stat

    def progress(self):
        while self.nplayed<self.nplay:
            self.board=TTTBoard()
            while self.board.winner==None:
                if self.disp:print("Turn is "+self.player_turn.name)
                act=self.player_turn.act(self.board)
                self.board.move(act,self.player_turn.myturn)
                if self.disp:self.board.print_board()

                if self.board.winner != None:
                    # notice every player that game ends
                    for i in self.players:
                        i.getGameResult(self.board) 
                    if self.board.winner == DRAW:
                        if self.showResult:print ("Draw Game")
                    elif self.board.winner == self.player_turn.myturn:
                        out = "Winner : " + self.player_turn.name
                        if self.showResult: print(out)
                    else:
                        print ("Invalid Move!")
                    self.nwon[self.board.winner]+=1
                else:
                    self.switch_player()
                    #Notice other player that the game is going
                    self.player_turn.getGameResult(self.board)

            self.nplayed+=1
            if self.nplayed%self.stat==0 or self.nplayed==self.nplay:
                        print(self.player_x.name+":"+str(self.nwon[self.player_x.myturn])+","+self.player_o.name+":"+str(self.nwon[self.player_o.myturn])
             +",DRAW:"+str(self.nwon[DRAW]))


    def switch_player(self):
        if self.player_turn == self.player_x:
            self.player_turn=self.player_o
        else:
            self.player_turn=self.player_x

色々なプレイヤーを作成

ランダム君

置けるところに置くだけのシンプルなやつです。

import random


class PlayerRandom:
    def __init__(self,turn):
        self.name="Random"
        self.myturn=turn

    def act(self,board):
        acts=board.get_possible_pos()
        i=random.randrange(len(acts))
        return acts[i]


    def getGameResult(self,board):
        pass

人間

class PlayerHuman:
    def __init__(self,turn):
        self.name="Human"
        self.myturn=turn

    def act(self,board):
        valid = False
        while not valid:
            try:
                act = input("Where would you like to place " + str(self.myturn) + " (1-9)? ")
                act = int(act)
                #if act >= 1 and act <= 9 and board.board[act-1]==EMPTY:
                if act >= 1 and act <= 9:
                    valid=True
                    return act-1
                else:
                    print ("That is not a valid move! Please try again.")
            except Exception as e:
                    print (act +  "is not a valid move! Please try again.")
        return act

    def getGameResult(self,board):
        if board.winner is not None and board.winner!=self.myturn and board.winner!=DRAW:
            print("I lost...")

まずは人間とランダムで戦わせてみます。負けたら恥。

def Human_vs_Random():

    p1=PlayerHuman(PLAYER_X)
    p2=PlayerRandom(PLAYER_O)
    game=TTT_GameOrganizer(p1,p2)
    game.progress()

実行結果

Human_vs_Random()
Turn is Random
   |   |   
-----------
   |   | O 
-----------
   |   |   
Turn is Human
Where would you like to place 1 (1-9)? 1
 X |   |   
-----------
   |   | O 
-----------
   |   |   
Turn is Random
 X |   |   
-----------
 O |   | O 
-----------
   |   |   
Turn is Human
Where would you like to place 1 (1-9)? 2
 X | X |   
-----------
 O |   | O 
-----------
   |   |   
Turn is Random
 X | X |   
-----------
 O | O | O 
-----------
   |   |   
I lost...
Winner : Random
Human:0,Random:1,DRAW:0

なんと、ログに残っていたのは気を抜いて負けた例でした。

改良ランダム君 (Alpha Random)

さすがにただのランダムだとつまらないので、次の手で勝てるところがあればきちんと指すようにします。(なければランダム)

class PlayerAlphaRandom:


    def __init__(self,turn,name="AlphaRandom"):
        self.name=name
        self.myturn=turn

    def getGameResult(self,winner):
        pass

    def act(self,board):
        acts=board.get_possible_pos()
        #see only next winnable act
        for act in acts:
            tempboard=board.clone()
            tempboard.move(act,self.myturn)
            # check if win
            if tempboard.winner==self.myturn:
                #print ("Check mate")
                return act
        i=random.randrange(len(acts))
        return acts[i]

モンテカルロ

さて、ここからが強化学習領域に入ってきます。
モンテカルロでは、あるポリシーに沿って一定回数をゲーム終了まで試し、その勝率を見ることでいい手を探索します。

モンテカルロの重要な点は、「一定回数」の数と、「あるポリシー」によって強さが左右されること。

「あるポリシー」には、先ほどの改良ランダム君を使いました。
「一定回数」は試したところ100回を超えれば、まず間違った手は打たないようです。50回くらいだと、乱数次第で時々負けたりします。

ここから実績を動的計画法に持ってったり、モンテカルト木探索(MCTS)に派生して強いポリシーと弱いポリシーを組み合わせる方法もありますが、今回は言及しません。

class PlayerMC:
    def __init__(self,turn,name="MC"):
        self.name=name
        self.myturn=turn

    def getGameResult(self,winner):
        pass

    def win_or_rand(self,board,turn):
        acts=board.get_possible_pos()
        #see only next winnable act
        for act in acts:
            tempboard=board.clone()
            tempboard.move(act,turn)
            # check if win
            if tempboard.winner==turn:
                return act
        i=random.randrange(len(acts))
        return acts[i]

    def trial(self,score,board,act):
        tempboard=board.clone()
        tempboard.move(act,self.myturn)
        tempturn=self.myturn
        while tempboard.winner is None:
            tempturn=tempturn*-1
            tempboard.move(self.win_or_rand(tempboard,tempturn),tempturn)

        if tempboard.winner==self.myturn:
            score[act]+=1
        elif tempboard.winner==DRAW:
            pass
        else:
            score[act]-=1


    def getGameResult(self,board):
        pass


    def act(self,board):
        acts=board.get_possible_pos()
        scores={}
        n=50
        for act in acts:
            scores[act]=0
            for i in range(n):
                #print("Try"+str(i))
                self.trial(scores,board,act)

            #print(scores)
            scores[act]/=n

        max_score=max(scores.values())
        for act, v in scores.items():
            if v == max_score:
                #print(str(act)+"="+str(v))
                return act

モンテカルロ同士で戦わせる

p1=PlayerMC(PLAYER_X,"M1")
p2=PlayerMC(PLAYER_O,"M2")
game=TTT_GameOrganizer(p1,p2,10,False)
game.progress()

Draw Game
Winner : M2
Draw Game
Draw Game
Draw Game
Draw Game
Draw Game
Winner : M2
Draw Game
Draw Game
M1:0,M2:2,DRAW:8

大体ドローですね。気が向いたら試行回数を100回にして試してみてください。全部ドローになると思います。

Q-Learning

さて、強化学習といえば、Q-Learning。
ゲーム途中は報酬 0 、勝ったら+1、負けたら-1、引き分けは0です。
それをベースにQ(s,a)のDictionaryを以下の更新式で更新していきます。

Q(s,a) = Q(s,a) + alpha (reward + gammma* max(Q(s',a')- Q(s,a)) 

重要なのはε-greedyというコンセプトでランダムの打ち手を推奨すること。でないとすぐにlocal optimumに陥ります。
最初のQの値も重要です。でないと、すぐ探す気をなくしちゃいます。わがままな子です。

ゲーム終了でのmax(Q)をどうしようとか、ゲーム途中の更新を忘れたりと、意外とはまりました。

class PlayerQL:
    def __init__(self,turn,name="QL",e=0.2,alpha=0.3):
        self.name=name
        self.myturn=turn
        self.q={} #set of s,a
        self.e=e
        self.alpha=alpha
        self.gamma=0.9
        self.last_move=None
        self.last_board=None
        self.totalgamecount=0


    def policy(self,board):
        self.last_board=board.clone()
        acts=board.get_possible_pos()
        #Explore sometimes
        if random.random() < (self.e/(self.totalgamecount//10000+1)):
                i=random.randrange(len(acts))
                return acts[i]
        qs = [self.getQ(tuple(self.last_board.board),act) for act in acts]
        maxQ= max(qs)

        if qs.count(maxQ) > 1:
            # more than 1 best option; choose among them randomly
            best_options = [i for i in range(len(acts)) if qs[i] == maxQ]
            i = random.choice(best_options)
        else:
            i = qs.index(maxQ)

        self.last_move = acts[i]
        return acts[i]

    def getQ(self, state, act):
        # encourage exploration; "optimistic" 1.0 initial values
        if self.q.get((state, act)) is None:
            self.q[(state, act)] = 1
        return self.q.get((state, act))

    def getGameResult(self,board):
        r=0
        if self.last_move is not None:
            if board.winner is None:
                self.learn(self.last_board,self.last_move, 0, board)
                pass
            else:
                if board.winner == self.myturn:
                    self.learn(self.last_board,self.last_move, 1, board)
                elif board.winner !=DRAW:
                    self.learn(self.last_board,self.last_move, -1, board)
                else:
                    self.learn(self.last_board,self.last_move, 0, board)
                self.totalgamecount+=1
                self.last_move=None
                self.last_board=None

    def learn(self,s,a,r,fs):
        pQ=self.getQ(tuple(s.board),a)
        if fs.winner is not None:
            maxQnew=0
        else:
            maxQnew=max([self.getQ(tuple(fs.board),act) for act in fs.get_possible_pos()])
        self.q[(tuple(s.board),a)]=pQ+self.alpha*((r+self.gamma*maxQnew)-pQ)
        #print (str(s.board)+"with "+str(a)+" is updated from "+str(pQ)+" refs MAXQ="+str(maxQnew)+":"+str(r))
        #print(self.q)


    def act(self,board):
        return self.policy(board)

Q-Learning 同士で10万回 戦わせます。

pQ=PlayerQL(PLAYER_O,"QL1")
p2=PlayerQL(PLAYER_X,"QL2")
game=TTT_GameOrganizer(pQ,p2,100000,False,False,10000)
game.progress()

QL1:4371,QL2:4436,DRAW:1193
QL1:8328,QL2:8456,DRAW:3216
QL1:11903,QL2:11952,DRAW:6145
QL1:14268,QL2:14221,DRAW:11511
QL1:15221,QL2:15099,DRAW:19680
QL1:15730,QL2:15667,DRAW:28603
QL1:16136,QL2:16090,DRAW:37774
QL1:16489,QL2:16439,DRAW:47072
QL1:16832,QL2:16791,DRAW:56377
QL1:17128,QL2:17121,DRAW:65751

大体ドローに収束していますね。εで勝敗が分かれているのがわかります。

モンテカルロ君と戦わせる

εをゼロにして、ランダム要素をなくして最良の打ち手で攻めてみます

pQ.e=0
p2=PlayerMC(PLAYER_X,"M1")
game=TTT_GameOrganizer(pQ,p2,100,False,False,10)
game.progress()
QL1:1,M1:0,DRAW:9
QL1:1,M1:0,DRAW:19
QL1:2,M1:0,DRAW:28
QL1:2,M1:0,DRAW:38
QL1:3,M1:0,DRAW:47
QL1:4,M1:0,DRAW:56
QL1:5,M1:0,DRAW:65
QL1:5,M1:0,DRAW:75
QL1:6,M1:0,DRAW:84
QL1:6,M1:0,DRAW:94

モンテカルロ君も時々負けてしまいました。やはりモンテカルロ君は乱数の気まぐれから逃れられません。
ちなみにこの状態のQ-Learning君と戦うと、かなり痛い目にあいます。

DQN (Deep Q Network)

さて今回の本題。DQNです。
Q-LearningのネックはQ(s,a)のテーブルが膨大になること。三目並べならいいけど、碁ではSもaも莫大なメモリとなってしまいます。
Q-NetworkではQ(s)により各aの期待報酬を出して、argmaxを取ることになります。
Deep Q Networkなら 「勝手に勝つ特徴をみつけちゃうんじゃない!?」、1000回位の対戦でいいんじゃない?とたかをくくってました。

・・・・が、ドはまりしました。

全然強くならない。しかもHyper Parameterが多すぎる。Hidden Layer何個にすればいいの?Relu? 誤差はSMEでいいの?OptimizerはAdamかSGD? 正直Grid Searchがほしかったです。でもそもそも教師がないのでロスの測定もできないんですよね。精度は、相手と戦わせて初めてわかる。

さらに、学習させていても、Variable Linkが切れて全然学習しなかったり。ほんとハマります。
しかもGPUのないマシンでいろいろやってたので遅い。。。GPUないと学習に5-10分くらいかかったりします。

import chainer

from chainer import Function, gradient_check, Variable, optimizers, serializers, utils
import chainer.functions as F
import chainer.links as L
import numpy as np
from chainer import computational_graph as c

# Network definition
class MLP(chainer.Chain):

    def __init__(self, n_in, n_units, n_out):
        super(MLP, self).__init__(
            l1=L.Linear(n_in, n_units),  # first layer
            l2=L.Linear(n_units, n_units),  # second layer
            l3=L.Linear(n_units, n_units),  # Third layer
            l4=L.Linear(n_units, n_out),  # output layer
        )

    def __call__(self, x, t=None, train=False):
        h = F.leaky_relu(self.l1(x))
        h = F.leaky_relu(self.l2(h))
        h = F.leaky_relu(self.l3(h))
        h = self.l4(h)

        if train:
            return F.mean_squared_error(h,t)
        else:
            return h

    def get(self,x):
        # input x as float, output float
        return self.predict(Variable(np.array([x]).astype(np.float32).reshape(1,1))).data[0][0]


class DQNPlayer:
    def __init__(self, turn,name="DQN",e=1,dispPred=False):
        self.name=name
        self.myturn=turn
        self.model = MLP(9, 162,9)
        self.optimizer = optimizers.SGD()
        self.optimizer.setup(self.model)
        self.e=e
        self.gamma=0.95
        self.dispPred=dispPred
        self.last_move=None
        self.last_board=None
        self.last_pred=None
        self.totalgamecount=0
        self.rwin,self.rlose,self.rdraw,self.rmiss=1,-1,0,-1.5


    def act(self,board):

        self.last_board=board.clone()
        x=np.array([board.board],dtype=np.float32).astype(np.float32)

        pred=self.model(x)
        if self.dispPred:print(pred.data)
        self.last_pred=pred.data[0,:]
        act=np.argmax(pred.data,axis=1)
        if self.e > 0.2: #decrement epsilon over time
            self.e -= 1/(20000)
        if random.random() < self.e:
            acts=board.get_possible_pos()
            i=random.randrange(len(acts))
            act=acts[i]
        i=0
        while board.board[act]!=EMPTY:
            #print("Wrong Act "+str(board.board)+" with "+str(act))
            self.learn(self.last_board,act, -1, self.last_board)
            x=np.array([board.board],dtype=np.float32).astype(np.float32)
            pred=self.model(x)
            #print(pred.data)
            act=np.argmax(pred.data,axis=1)
            i+=1
            if i>10:
                print("Exceed Pos Find"+str(board.board)+" with "+str(act))
                acts=self.last_board.get_possible_pos()
                act=acts[random.randrange(len(acts))]

        self.last_move=act
        #self.last_pred=pred.data[0,:]
        return act

    def getGameResult(self,board):
        r=0
        if self.last_move is not None:
            if board.winner is None:
                self.learn(self.last_board,self.last_move, 0, board)
                pass
            else:
                if board.board== self.last_board.board:            
                    self.learn(self.last_board,self.last_move, self.rmiss, board)
                elif board.winner == self.myturn:
                    self.learn(self.last_board,self.last_move, self.rwin, board)
                elif board.winner !=DRAW:
                    self.learn(self.last_board,self.last_move, self.rlose, board)
                else:                    #DRAW
                    self.learn(self.last_board,self.last_move, self.rdraw, board)
                self.totalgamecount+=1
                self.last_move=None
                self.last_board=None
                self.last_pred=None

    def learn(self,s,a,r,fs):
        if fs.winner is not None:
            maxQnew=0
        else:
            x=np.array([fs.board],dtype=np.float32).astype(np.float32)
            maxQnew=np.max(self.model(x).data[0])
        update=r+self.gamma*maxQnew
        #print(('Prev Board:{} ,ACT:{}, Next Board:{}, Get Reward {}, Update {}').format(s.board,a,fs.board,r,update))
        #print(('PREV:{}').format(self.last_pred))
        self.last_pred[a]=update

        x=np.array([s.board],dtype=np.float32).astype(np.float32)
        t=np.array([self.last_pred],dtype=np.float32).astype(np.float32)
        self.model.zerograds()
        loss=self.model(x,t,train=True)
        loss.backward()
        self.optimizer.update()
        #print(('Updated:{}').format(self.model(x).data))
        #print (str(s.board)+"with "+str(a)+" is updated from "+str(pQ)+" refs MAXQ="+str(maxQnew)+":"+str(r))
        #print(self.q)

ソース長いですねぇ。。。たぶん試行錯誤の形跡が残っている気がしますが、きれいにしていません。

改良ランダムと対戦

これでまずはざっくり打ち方と勝ち方を覚えさせます。

pDQ=DQNPlayer(PLAYER_X)
p2=PlayerAlphaRandom(PLAYER_O)
game=TTT_GameOrganizer(pDQ,p2,20000,False,False,1000)
game.progress()
DQN:206,AlphaRandom:727,DRAW:67
DQN:468,AlphaRandom:1406,DRAW:126
DQN:861,AlphaRandom:1959,DRAW:180
DQN:1458,AlphaRandom:2315,DRAW:227
DQN:2185,AlphaRandom:2560,DRAW:255
DQN:3022,AlphaRandom:2704,DRAW:274
DQN:3832,AlphaRandom:2856,DRAW:312
DQN:4632,AlphaRandom:3023,DRAW:345
DQN:5481,AlphaRandom:3153,DRAW:366
DQN:6326,AlphaRandom:3280,DRAW:394
DQN:7181,AlphaRandom:3400,DRAW:419
DQN:8032,AlphaRandom:3522,DRAW:446
DQN:8902,AlphaRandom:3618,DRAW:480
DQN:9791,AlphaRandom:3705,DRAW:504
DQN:10673,AlphaRandom:3793,DRAW:534
DQN:11545,AlphaRandom:3893,DRAW:562
DQN:12420,AlphaRandom:3986,DRAW:594
DQN:13300,AlphaRandom:4074,DRAW:626
DQN:14183,AlphaRandom:4158,DRAW:659
DQN:15058,AlphaRandom:4246,DRAW:696

だいぶ強くなりましたね。さすがに改良ランダム君には勝てるようで。

本命 Q-Learningと対決

pDQ.e=1
game=TTT_GameOrganizer(pDQ,pQ,30000,False,False,1000)
game.progress()

DQN:4,QL1:436,DRAW:560
DQN:6,QL1:790,DRAW:1204
DQN:6,QL1:1135,DRAW:1859
DQN:6,QL1:1472,DRAW:2522
DQN:6,QL1:1801,DRAW:3193
DQN:6,QL1:2123,DRAW:3871
DQN:6,QL1:2439,DRAW:4555
DQN:6,QL1:2777,DRAW:5217
DQN:7,QL1:3128,DRAW:5865
DQN:9,QL1:3648,DRAW:6343
DQN:13,QL1:4132,DRAW:6855
DQN:13,QL1:4606,DRAW:7381
DQN:13,QL1:5087,DRAW:7900
DQN:14,QL1:5536,DRAW:8450
DQN:14,QL1:6011,DRAW:8975
DQN:14,QL1:6496,DRAW:9490
DQN:14,QL1:7394,DRAW:9592
DQN:14,QL1:7925,DRAW:10061
DQN:16,QL1:8357,DRAW:10627
DQN:16,QL1:8777,DRAW:11207

ようやく強くなったのかな。。。なんだか微妙な強さですが。。
このログではまだ負けが多いですが、10万回対戦位すると大体ドローに落ち着きました。

最後に人間と対戦。εはゼロ

pDQ.e=0
p2=PlayerHuman(PLAYER_O)
game=TTT_GameOrganizer(pDQ,p2,2)
game.progress()
Turn is Human
Where would you like to place -1 (1-9)? 1
//anaconda/lib/python3.5/site-packages/ipykernel/__main__.py:69: VisibleDeprecationWarning: converting an array with ndim > 0 to an index will result in an error in the future

 O |   |   
-----------
   |   |   
-----------
   |   |   
Turn is DQN
 O |   |   
-----------
   | X |   
-----------
   |   |   
Turn is Human
Where would you like to place -1 (1-9)? 2
 O | O |   
-----------
   | X |   
-----------
   |   |   
Turn is DQN
 O | O | X 
-----------
   | X |   
-----------
   |   |   
Turn is Human
Where would you like to place -1 (1-9)? 7
 O | O | X 
-----------
   | X |   
-----------
 O |   |   
Turn is DQN
 O | O | X 
-----------
 X | X |   
-----------
 O |   |   
Turn is Human
Where would you like to place -1 (1-9)? 6
 O | O | X 
-----------
 X | X | O 
-----------
 O |   |   
Turn is DQN
 O | O | X 
-----------
 X | X | O 
-----------
 O |   | X 
Turn is Human
Where would you like to place -1 (1-9)? 7
 O | O | X 
-----------
 X | X | O 
-----------
 O |   | X 
I lost...
Invalid Move!
Turn is Human
Where would you like to place -1 (1-9)? 2
   | O |   
-----------
   |   |   
-----------
   |   |   
Turn is DQN
   | O |   
-----------
   | X |   
-----------
   |   |   
Turn is Human
Where would you like to place -1 (1-9)? 1
 O | O |   
-----------
   | X |   
-----------
   |   |   
Turn is DQN
 O | O | X 
-----------
   | X |   
-----------
   |   |   
Turn is Human
Where would you like to place -1 (1-9)? 7
 O | O | X 
-----------
   | X |   
-----------
 O |   |   
Turn is DQN
 O | O | X 
-----------
 X | X |   
-----------
 O |   |   
Turn is Human
Where would you like to place -1 (1-9)? 6
 O | O | X 
-----------
 X | X | O 
-----------
 O |   |   
Turn is DQN
 O | O | X 
-----------
 X | X | O 
-----------
 O |   | X 
Turn is Human
Where would you like to place -1 (1-9)? 8
 O | O | X 
-----------
 X | X | O 
-----------
 O | O | X 
Draw Game
DQN:1,Human:0,DRAW:1

まぁそれらしい打ち手にはなっていました。

結論

Q-Learningは実装はかなり楽でしたが、Deep Q Networkでは試行錯誤しないと決められないことが非常に多かったです。
最初Reluで実装しても全然強くならず、Leaky Reluにしてよくなりました。たぶん盤面に-1が入るのでReluではうまく表現できなかったのかな。
あと、なぜかAdamではだめでSGDではうまくいきました。不思議です。Adam最強と勝手に思っていました。
さらにAlphaRandomと対戦させないで、最初からQ-Learning(強いやつ)と戦わさせると、すぐにやる気をなくして、負けっぱなしになったりします。最初は弱いやつと戦わせて、やる気を出させる必要があるようです。

また、何より苦労したのが

  • どうやって、DQNに打ってはいけない手を教えるか

という点です。

今回は既にコマのあるところに打ったら、強制的にマイナス報酬を与えて10回学習という、アプローチをいれてうまくいきました。打ってよい場所を入力に入れればいいのかな?

疑問や改善したいところがたくさんあるのでもしわかる人いたら教えてください。
アドバイスや指摘、どんどんウェルカムです。

それにしても、三目並べなら2-3時間くらいあればできるだろう、とたかをくくっていたのに、思いのほか時間がかかりました。
知っている、実装できる(コード書ける)、できる(きちんと動く) にはとんでもないギャップがあることを痛感。ぜひ皆さんも今回のソースはあまり見ずに、ご自身で実装されることをお勧めします。

さて、時間のあるときにオセロバージョンを作るかなぁ。。。
それかきちんとAlphaGoの3NN構成を試してみようか、次の課題に悩み中です。

117
113
10

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
117
113