4
3

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

Q学習で自作の迷路を解く

Posted at

#概要
強化学習の一つであるQ-learningを用いて,自作の簡単な迷路を解いてみました.実装はpythonで行いました.

#強化学習
強化学習とは,教師付き学習,教師なし学習と並ぶ機械学習の手法の一つです.簡単に説明すると与えられた「環境」における「価値」を最大化するように「エージェント」を学習させます.強化学習において一般的に状態 s・行動 a・報酬 rの3要素を用います.以下各要素の説明です.

状態 s : 「環境」がどのような状態になっているかを示す.
行動 a : 与えられた「環境」において「エージェント」がとる行動を示す.
報酬 r : 「エージェント」が行動を行なった結果得られる報酬を示す.

「ある状態 sにおいて行動 aを行なった時の価値」がわかれば,価値を最大化させるようにエージェントに学習させることができます.

#Q学習
Q学習法とは,前節の「ある状態 sにおいて行動 aを行なった時の価値」のことをQ値と呼びこの値を更新していくことで学習させていく方法です.状態 $s_t$において行動 $a_t$を行うときのQ値を$Q(s_t,a_t)$と表します.このQ値をもらうことのできる報酬 rにしたがって更新していくわけですが,実際に報酬をもらうことができる直接の行動だけではなくその一つ前の行動,さらに一つ前の行動・・においても報酬を得ることに近づいているという点にて価値があると言えます.そこでその行動をとることによって最終的に得られる報酬の期待値を組み込み更新することでそれぞれの時点でどの程度価値があるのかを計算します.具体的な式としては以下の通りです.
$$
Q(s_t,a_t)=Q(s_t,a_t)+\alpha(r_{t+1}+\gamma \text{max}Q(s_{t+1},a_{t+1})-Q(s_t,a_t))
$$

$Q(s_t,a_t)$を更新する際に次のt+1でもらえる報酬$r_{t+1}$と状態$s_{t+1}$における最大のQ値$\text{max}Q(s_{t+1},a_{t+1})$を用います.$\alpha$と$\gamma$はハイパーパラメータになります.

#環境
python 3.5.2

#実装(迷路部分)
迷路部分の実装は以下のようになります.迷路の構造としては$5 \times 5 $のマス目状となっています.(詳しくはコード中のコメント参照)

tresure.py
import numpy as np

class Game():
    def __init__(self):
        self.x_size = 5
        self.y_size = 5
        self.init_position = np.array([4,0])
        self.game_board = np.zeros((self.x_size,self.y_size))
        """
        迷路の構造
        0 0 0 0 G
        0 D 0 D D
        0 0 D 0 0
        D 0 0 0 0
        S 0 D 0 0

        G = 1
        D = -1
        """
        self.game_board[0][4] = 1
        self.game_board[1][1] = -1
        self.game_board[3][0] = -1
        self.game_board[2][2] = -1
        self.game_board[1][3] = -1
        self.game_board[1][4] = -1
        self.game_board[4][2] = -1
    
    def judge(self,x,y):
        #ゲーム終了か判定
        if self.game_board[x][y] == 0:
            return 0
        elif self.game_board[x][y] == 1:
            return 1
        else:
            return -1
    
    def check_size(self,x,y):
        #迷路からはみ出していないか判定
        if x < 0 or y < 0 or x >= self.x_size or y >= self.y_size:
            return 0
        else:
            return 1
    
    def move(self,position,direction):
        if direction == 'up':
            position[0] -= 1
        elif direction == 'down':
            position[0] += 1
        elif direction == 'right':
            position[1] -= 1
        elif direction == 'left':
            position[1] += 1
        
        return position

ルールとしては,プレイヤーはSのマスからスタートし,各状態ごとに上下左右のどこかのマスに移動できるとします.Gにたどり着けばゴールであり報酬がもらえ,Dに移動してしまうとゲームオーバーであり罰としてマイナスの報酬が与えられます.

#実装(学習部分)
コードは以下のようです

学習部分コード全体
train.py
import numpy as np
from treasure import Game
import copy
import matplotlib.pyplot as plt

#board proscess
game = Game()
game_board = game.game_board
print(game_board)
game_direction = ['up','down','left','right']

def get_action(Q_table,dire,epsilon,x,y):

    #random choice
    if np.random.rand() < epsilon:
        return np.random.choice(dire)

    else:
        return dire[np.argmax(Q_table[x,y])]

def game_init():
    #init process
    game = Game()
    position= game.init_position
    
    return game,position

def game_reward(game,position):
    #result print
    if game.judge(position[0],position[1]) == 1:
        #print('You got a goal!')
        return 1
    elif game.judge(position[0],position[1]) == -1:
        #print('You died..')
        return -1
    else:
        return 0

def game_step(game,position,Q_table,dire,epsilon):
    while(True):
        pos = copy.deepcopy(position)
        direction = get_action(Q_table,dire,epsilon,pos[0],pos[1])
        index_dire = dire.index(direction)
        move_position = game.move(pos,direction)
        if game.check_size(pos[0],pos[1]):
            break
    reward = game_reward(game,move_position)

    return move_position,index_dire,reward

def Q_koushin(Q,state_x,state_y,state_a,s_next_x,state_next_y,alpha,reward,gamma):
    Q[state_x,state_y,state_a] += alpha*(reward + gamma*np.max(Q[s_next_x,state_next_y]) - Q[state_x,state_y,state_a])
    return Q[state_x,state_y,state_a]


if __name__ == '__main__':
    #hyper parameter
    epsilon = 0.01
    alpha = 0.5
    gamma = 0.8
    Q_table = np.zeros((game_board.shape[0],game_board.shape[1],len(game_direction)))
    episode = 200
    sucess = []

    for i in range(episode):
        game,position = game_init()
        while(not game.judge(position[0],position[1])):
            next_position,dire,reward = game_step(game,position,Q_table,game_direction,epsilon)
            Q_table[position[0],position[1],dire] = Q_koushin(Q_table,position[0],position[1],dire,next_position[0],next_position[1],alpha,reward,gamma)
            position = next_position
        
        if i % 10 == 0:
            count = 0
            heatmap = np.zeros((game_board.shape[0],game_board.shape[1]))
            for j in range(100):
                game,position = game_init()
                while(not game.judge(position[0],position[1])):
                    next_position,dire,reward = game_step(game,position,Q_table,game_direction,epsilon)
                    position = next_position
                    heatmap[next_position[0]][next_position[1]] += 1
                if reward == 1:
                    count += 1
            sucess.append(count)
            print(i)
            print(heatmap)
    print(sucess)

以下部分ごとに説明していきます.

epsilon-greedy法による方策

train.py
def get_action(Q_table,dire,epsilon,x,y):
    #random choice
    if np.random.rand() < epsilon:
        return np.random.choice(dire)

    else:
        return dire[np.argmax(Q_table[x,y])]

$\epsilon$の確率でランダムに移動し,それ以外はQ値が一番高い選択をするようにしました.

報酬関数の設定

train.py
def game_reward(game,position):
    #result print
    if game.judge(position[0],position[1]) == 1:
        #print('You got a goal!')
        return 1
    elif game.judge(position[0],position[1]) == -1:
        #print('You died..')
        return -1
    else:
        return 0

ゴールした時は報酬として1,ゲームオーバーの時は罰として-1,それ以外の時は0としました.

Q値の更新

train.py
def Q_koushin(Q,state_x,state_y,state_a,s_next_x,state_next_y,alpha,reward,gamma):
    Q[state_x,state_y,state_a] += alpha*(reward + gamma*np.max(Q[s_next_x,state_next_y]) - Q[state_x,state_y,state_a])
    return Q[state_x,state_y,state_a]

計算式に従ってQ値を更新します.

プレイヤー移動

train.py
def game_step(game,position,Q_table,dire,epsilon):
    while(True):
        pos = copy.deepcopy(position)
        direction = get_action(Q_table,dire,epsilon,pos[0],pos[1])
        index_dire = dire.index(direction)
        move_position = game.move(pos,direction)
        if game.check_size(pos[0],pos[1]):
            break
    reward = game_reward(game,move_position)

    return move_position,index_dire,reward

ゲーム中の一回の移動を行います.迷路からはみ出していないかの判定を行い,大丈夫であれば移動先の位置,移動した方向,得られる報酬を返します.pythonではリストは参照渡しされるため,元のリストが変更されないように深いコピーをとりました.

学習部分

train.py
if __name__ == '__main__':
    #hyper parameter
    epsilon = 0.01
    alpha = 0.5
    gamma = 0.8
    Q_table = np.zeros((game_board.shape[0],game_board.shape[1],len(game_direction)))
    episode = 200
    sucess = []

    for i in range(episode):
        game,position = game_init()
        while(not game.judge(position[0],position[1])):
            next_position,dire,reward = game_step(game,position,Q_table,game_direction,epsilon)
            Q_table[position[0],position[1],dire] = Q_koushin(Q_table,position[0],position[1],dire,next_position[0],next_position[1],alpha,reward,gamma)
            position = next_position
        
        if (i+1) % 20 == 0:
            count = 0
            heatmap = np.zeros((game_board.shape[0],game_board.shape[1]))
            for j in range(100):
                game,position = game_init()
                while(not game.judge(position[0],position[1])):
                    next_position,dire,reward = game_step(game,position,Q_table,game_direction,epsilon)
                    position = next_position
                    heatmap[next_position[0]][next_position[1]] += 1
                if reward == 1:
                    count += 1
            sucess.append(count)
            print('%d回時点' %(i+1))
            print(heatmap)

Q_tableとしては(迷路のマス目)$\times$(上下左右の移動方向)で$5\times5\times4$の行列としました.ゴールするかゲームオーバーとなるまで一回の移動ごとにQ値を更新しました.また20回ゲームを終えるたびにその時のQ_tableを用いて,ゲームを100回行い,何度ゴールできたかとどこを通ったのかを調査しました.

#結果
結果は以下のようになりました.

#迷路構造図
[[ 0.  0.  0.  0.  G.]
 [ 0. -1.  0. -1. -1.]
 [ 0.  0. -1.  0.  0.]
 [-1.  0.  0.  0.  0.]
 [ S.  0. -1.  0.  0.]]
20回時点
[[4.500e+01 1.800e+01 1.500e+01 6.000e+00 3.000e+00]
 [4.000e+01 1.400e+01 5.000e+00 3.000e+00 5.000e+00]
 [9.000e+00 4.148e+03 8.000e+00 8.270e+02 5.000e+00]
 [6.700e+01 4.172e+03 1.100e+01 8.350e+02 3.000e+00]
 [0.000e+00 4.900e+01 0.000e+00 2.000e+00 0.000e+00]]
40回時点
[[1.600e+01 1.000e+01 5.000e+00 2.000e+00 2.000e+00]
 [1.300e+01 1.000e+01 2.000e+00 3.000e+00 3.000e+00]
 [7.000e+00 4.105e+03 9.000e+00 6.400e+02 3.000e+00]
 [7.300e+01 4.132e+03 7.000e+00 6.450e+02 2.000e+00]
 [0.000e+00 4.900e+01 0.000e+00 2.000e+00 0.000e+00]]
60回時点
[[1.900e+01 8.000e+00 3.000e+00 3.000e+00 3.000e+00]
 [1.700e+01 1.600e+01 0.000e+00 2.000e+00 4.000e+00]
 [6.000e+00 3.754e+03 8.000e+00 2.120e+02 4.000e+00]
 [6.700e+01 3.783e+03 6.000e+00 2.140e+02 2.000e+00]
 [0.000e+00 5.600e+01 0.000e+00 0.000e+00 0.000e+00]]
80回時点
[[1.100e+01 6.000e+00 6.000e+00 6.000e+00 6.000e+00]
 [1.100e+01 1.200e+01 0.000e+00 6.000e+00 4.000e+00]
 [6.000e+00 3.544e+03 1.300e+01 2.069e+03 1.107e+03]
 [5.900e+01 3.571e+03 1.500e+01 2.080e+03 1.109e+03]
 [0.000e+00 5.700e+01 0.000e+00 3.000e+00 2.000e+00]]
100回時点
[[8.000e+00 8.000e+00 8.000e+00 8.000e+00 8.000e+00]
 [8.000e+00 1.600e+01 0.000e+00 5.000e+00 2.000e+00]
 [8.000e+00 4.783e+03 1.500e+01 3.083e+03 1.225e+03]
 [5.400e+01 4.824e+03 2.300e+01 3.100e+03 1.224e+03]
 [0.000e+00 7.100e+01 0.000e+00 1.000e+01 1.000e+00]]
120回時点
[[1.100e+01 1.100e+01 1.100e+01 1.100e+01 1.100e+01]
 [1.100e+01 6.000e+00 0.000e+00 6.000e+00 3.000e+00]
 [1.100e+01 4.215e+03 1.500e+01 2.403e+03 1.660e+03]
 [5.900e+01 4.251e+03 1.900e+01 2.416e+03 1.665e+03]
 [1.000e+00 6.400e+01 0.000e+00 4.000e+00 4.000e+00]]
140回時点
[[ 99. 100.  98.  96.  96.]
 [100.   4.   1.   0.   0.]
 [101. 101.   0.   0.   0.]
 [  0. 102.   0.   0.   0.]
 [  0. 102.   0.   0.   0.]]
160回時点
[[ 96.  95.  96.  96.  95.]
 [ 97.   2.   0.   0.   0.]
 [ 97. 100.   1.   0.   0.]
 [  0.  99.   0.   0.   0.]
 [  0. 100.   2.   0.   0.]]
180回時点
[[101. 100. 100. 100.  99.]
 [101.   0.   0.   1.   0.]
 [100. 101.   0.   0.   0.]
 [  0. 101.   0.   0.   0.]
 [  0. 100.   0.   0.   0.]]
200回時点
[[ 99.  99. 101. 100.  99.]
 [100.   1.   1.   0.   0.]
 [100. 100.   0.   0.   0.]
 [  0. 101.   0.   0.   0.]
 [  0. 101.   0.   0.   0.]]

#ゴール回数(100回中)
[3, 2, 3, 6, 8, 11, 96, 95, 99, 99]

結果を見てみると,学習の進んでいない序盤の方はスタート近くの外れマスでゲームオーバーになることが多かったですが,学習を重ね140回後にはほとんど正解のルートを見つけ出せているように見えます.

#まとめ
Q-learningを用いて,自作の簡単な迷路を解いてみました.実装の練習にもなり良い機会でした.今後は強化学習の別の手法も用いより複雑なゲームを行なったりしてみたいと思います.

#参考文献
Platinum Data Blog by BrainPad 強化学習入門~これから強化学習を学びたい人の基礎知識~ 2017
https://blog.brainpad.co.jp/entry/2017/02/24/121500

4
3
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
4
3

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?