物体を一ヶ所に集めるゲームで強化学習
強化学習に関して勉強する機会があり、試しに実装してみたくなったので実装してみた。
本当はきれいなコードで記述したかったが、状態価値やらなんやらを考えてると手が動かなくなったので、
とりあえず手を動かそうと思って作ってみたら、かなり雑な仕上がりになってしまった。
そのうちgymみたいな記述方式に直すかもしれないです。
強化学習の概要についてはweb上にいくらでも記事があるのでここでは触れません。
ゲーム内容
- エージェントがフィールド内の物体を一ヶ所に集めたらクリア
- エージェントはフィールド内を動き回ることができる
- 物体のあるマスに移動するとエージェントは物体を持った状態となる
- 物体は一度に一つしか持つことができない
- 特定のマスでしか荷物を降ろすことができない
- 荷物を特定のマスで降ろすことによって報酬を獲得できる
マップとプレイヤーの配置についてはこのサイトを参考にしました。
【Pygame】RPGのマップ作成 | 西住工房@技術雑記
Q学習
Q学習でQ値を以下のような更新式で更新していきます。
Q(s_t,a_t)\leftarrow Q(s_t,a_t)+\eta\bigl(r_{t}+\gamma\max_aQ(s_{t+1},a)-Q(s_t,a_t)\bigr)
詳しい理論に関してはここで下手に説明するよりもいくらでも良記事があるので
気になる方は各自ググってください。
ソースコード
import pygame
from pygame.locals import *
import sys
import numpy as np
import os
from datetime import datetime
SCR_RECT = Rect(0, 0, 640, 480) # 画面サイズ
# マップのクラス
class Map:
def __init__(self, map_n=0):
# マップデータ
self.map = np.loadtxt("./field/field{}.txt".format(map_n)).astype(np.int32)
self.row,self.col = len(self.map), len(self.map[0]) # マップの行数,列数を取得
self.imgs = [None] * 16 # マップチップ
self.msize = 32 # 1マスの大きさ[px]
# マップの描画
def draw(self, screen):
for i in range(self.row):
for j in range(self.col):
screen.blit(self.imgs[self.map[i][j]], (j*self.msize,i*self.msize))
# 画像の読み込み
def load_img(filename, colorkey=None):
img = pygame.image.load(filename)
img = img.convert()
if colorkey is not None:
if colorkey == -1:
colorkey = img.get_at((0,0))
img.set_colorkey(colorkey, RLEACCEL)
return img
def main(max_step, episode, map_n, greed_e, r1, r2, r3, def_r, test, fps):
save_path = "./log/{0}_map{1}".format(datetime.now().strftime("%y%m%d%H"), map_n)
if test != True:
os.makedirs(save_path)
Q = np.zeros([15,20,2,4])
# learning rate
lr = 0.8
# 割引率
y = 0.95
r = 0
for ep in range(episode):
step = 0
s_buc = 0
s_buc_next = 0
act = 0
pygame.init()
screen = pygame.display.set_mode(SCR_RECT.size)
map = Map(map_n)
map.imgs[0] = pygame.image.load("./image/sand.png")
map.imgs[1] = pygame.image.load("./image/white.png")
map.imgs[2] = pygame.image.load("./image/track.png")
map.imgs[3] = pygame.image.load("./image/wheel.png")
map.imgs[4] = pygame.image.load("./image/wheel_bucket.png")
goal_id = np.where(map.map==2)
FPSCLOCK = pygame.time.Clock()
rAll = 0
while step < max_step:
wheel_id = np.where(map.map>=3)
wheel = map.map[wheel_id[0], wheel_id[1]]
s_row, s_col = wheel_id
#print(wheel_id)
if wheel == 4:
s_buc = 1
elif wheel == 3:
s_buc = 0
else:
print("error!")
pygame.quit()
sys.exit()
step += 1
sys.stdout.write("\rstep: {}/{} | episode: {}/{} | sum reward: {}".format(step,max_step,ep,episode, rAll))
sys.stdout.flush()
if np.random.uniform() < greed_e:
act = np.random.randint(0, 4)
else:
act = np.argmax(Q[s_row,s_col,s_buc,:] + np.random.randn(1, 4)*(1.0/(step + 1)))
map.draw(screen)
# イベント処理
for event in pygame.event.get():
# 終了用のイベント処理
if event.type == QUIT: # 閉じるボタンが押されたとき
pygame.quit()
sys.exit()
if event.type == KEYDOWN: # キーを押したとき
if event.key == K_ESCAPE: # Escキーが押されたとき
pygame.quit()
sys.exit()
if 0 == act:
if wheel_id[1] > 0:
# 貨物状態
if wheel == 4:
# 物体にぶつかる場合
if map.map[wheel_id[0], wheel_id[1]-1] == 0:
# print("積み下ろす必要がある")
r = r3
else:
map.map[wheel_id[0], wheel_id[1]] = 1
# 積み下ろし
if map.map[wheel_id[0], wheel_id[1]-1] == 2:
map.map[wheel_id[0], wheel_id[1]-1] = 3
r = r2
s_buc_next = 0
# 非目的地
else:
map.map[wheel_id[0], wheel_id[1]-1] = 4
s_buc_next = 1
r = def_r
# 空状態
if wheel == 3:
map.map[wheel_id[0], wheel_id[1]] = 1
# 積み上げ
if map.map[wheel_id[0], wheel_id[1]-1] == 0:
map.map[wheel_id[0], wheel_id[1]-1] = 4
r = r1
s_buc_next = 1
# ゴール
elif map.map[wheel_id[0], wheel_id[1]-1] == 2:
map.map[wheel_id[0], wheel_id[1]-1] = 3
r = r3
s_buc_next = 0
# 無し
else:
map.map[wheel_id[0], wheel_id[1]-1] = 3
s_buc_next = 0
r = def_r
else:
# print("over the workspace")
r = r3
if 1 == act:
if wheel_id[1] < (map.col-1):
if wheel == 4:
if map.map[wheel_id[0], wheel_id[1]+1] == 0:
# print("積み下ろす必要がある")
r = r3
else:
map.map[wheel_id[0], wheel_id[1]] = 1
if map.map[wheel_id[0], wheel_id[1]+1] == 2:
map.map[wheel_id[0], wheel_id[1]+1] = 3
r = r2
s_buc_next = 0
else:
map.map[wheel_id[0], wheel_id[1]+1] = 4
s_buc_next = 1
r = def_r
if wheel == 3:
map.map[wheel_id[0], wheel_id[1]] = 1
if map.map[wheel_id[0], wheel_id[1]+1] == 0:
map.map[wheel_id[0], wheel_id[1]+1] = 4
r = r1
s_buc_next = 1
elif map.map[wheel_id[0], wheel_id[1]+1] == 2:
map.map[wheel_id[0], wheel_id[1]+1] = 3
r = r3
s_buc_next = 0
else:
map.map[wheel_id[0], wheel_id[1]+1] = 3
s_buc_next = 0
r = def_r
else:
# print("over the workspace")
r = r3
if 2 == act:
if wheel_id[0] > 0:
if wheel == 4:
if map.map[wheel_id[0]-1, wheel_id[1]] == 0:
# print("積み下ろす必要がある")
r = r3
else:
map.map[wheel_id[0], wheel_id[1]] = 1
if map.map[wheel_id[0]-1, wheel_id[1]] == 2:
map.map[wheel_id[0]-1, wheel_id[1]] = 3
r = r2
s_buc_next = 0
else:
map.map[wheel_id[0]-1, wheel_id[1]] = 4
s_buc_next = 1
r = def_r
if wheel == 3:
map.map[wheel_id[0], wheel_id[1]] = 1
if map.map[wheel_id[0]-1, wheel_id[1]] == 0:
map.map[wheel_id[0]-1, wheel_id[1]] = 4
r = r1
s_buc_next = 1
elif map.map[wheel_id[0]-1, wheel_id[1]] == 2:
map.map[wheel_id[0]-1, wheel_id[1]] = 3
r = r3
s_buc_next = 0
else:
map.map[wheel_id[0]-1, wheel_id[1]] = 3
s_buc_next = 0
r = def_r
else:
# print("over the workspace")
r = r3
if 3 == act:
if wheel_id[0] < (map.row-1):
if wheel == 4:
if map.map[wheel_id[0]+1, wheel_id[1]] == 0:
# print("積み下ろす必要がある")
r = r3
else:
map.map[wheel_id[0], wheel_id[1]] = 1
if map.map[wheel_id[0]+1, wheel_id[1]] == 2:
map.map[wheel_id[0]+1, wheel_id[1]] = 3
r = r2
s_buc_next = 0
else:
map.map[wheel_id[0]+1, wheel_id[1]] = 4
s_buc_next = 1
r = def_r
if wheel == 3:
map.map[wheel_id[0], wheel_id[1]] = 1
if map.map[wheel_id[0]+1, wheel_id[1]] == 0:
map.map[wheel_id[0]+1, wheel_id[1]] = 4
r += r1
s_buc_next = 1
elif map.map[wheel_id[0]+1, wheel_id[1]] == 2:
map.map[wheel_id[0]+1, wheel_id[1]] = 3
r = r3
s_buc_next = 0
else:
map.map[wheel_id[0]+1, wheel_id[1]] = 3
s_buc_next = 0
r = def_r
else:
# print("over the workspace")
r = r3
s_row_next, s_col_next = np.where(map.map>=3)
Q[s_row, s_col, s_buc, act] = Q[s_row, s_col, s_buc, act] +\
lr*(r + y*np.max(Q[s_row_next, s_col_next, s_buc_next, :]) - Q[s_row, s_col, s_buc, act])
rAll += r
if map.map[goal_id[0], goal_id[1]] == 1:
map.map[goal_id[0], goal_id[1]] = 2
clear = np.where(map.map==0)
if not bool(len(clear[0])):
sys.stdout.write(" clear")
break
pygame.display.update()
FPSCLOCK.tick(fps)
if test != True:
if ep%(episode//10) == 0:
np.save(save_path + "/Qtable_ep{0:0=5}.npy".format(ep), Q)
print()
if __name__ == "__main__":
main(max_step=5000,
episode=100,
map_n=2,
greed_e=0.1,
r1=1, # 積み上げ報酬
r2=2, # 積み下ろし報酬
r3=-1, # 壁ぶつかり報酬
def_r=-0.01, # ただ移動しているだけでは減点されていく
test=True,
fps=100)
学習初期
ランダムに移動しているだけで、全然クリアできそうにありません
後半
なんとなくそれっぽい感じになってきました
ただ、このアルゴリズムは迷路探索の強化学習アルゴリズムをそのまま組み込んだので、学習後半になるにつれてどんどん動きが悪くなっていきます。
これを改善するにはエージェントがフィールドの状態を理解するなど、ゲームにあったアルゴリズムを組み込む必要があります。
今回は強化学習のほんのさわりの部分を実装してみました。