4
7

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

Python、Pygameで作ったゲームに強化学習を実装

Last updated at Posted at 2019-09-20

物体を一ヶ所に集めるゲームで強化学習

強化学習に関して勉強する機会があり、試しに実装してみたくなったので実装してみた。
本当はきれいなコードで記述したかったが、状態価値やらなんやらを考えてると手が動かなくなったので、
とりあえず手を動かそうと思って作ってみたら、かなり雑な仕上がりになってしまった。
そのうちgymみたいな記述方式に直すかもしれないです。
強化学習の概要についてはweb上にいくらでも記事があるのでここでは触れません。

ゲーム内容

  • エージェントがフィールド内の物体を一ヶ所に集めたらクリア
  • エージェントはフィールド内を動き回ることができる
  • 物体のあるマスに移動するとエージェントは物体を持った状態となる
  • 物体は一度に一つしか持つことができない
  • 特定のマスでしか荷物を降ろすことができない
  • 荷物を特定のマスで降ろすことによって報酬を獲得できる

マップとプレイヤーの配置についてはこのサイトを参考にしました。
【Pygame】RPGのマップ作成 | 西住工房@技術雑記

Q学習

Q学習でQ値を以下のような更新式で更新していきます。

Q(s_t,a_t)\leftarrow Q(s_t,a_t)+\eta\bigl(r_{t}+\gamma\max_aQ(s_{t+1},a)-Q(s_t,a_t)\bigr)

詳しい理論に関してはここで下手に説明するよりもいくらでも良記事があるので
気になる方は各自ググってください。

ソースコード

main.py
import pygame
from pygame.locals import *
import sys
import numpy as np
import os
from datetime import datetime
SCR_RECT = Rect(0, 0, 640, 480) # 画面サイズ

# マップのクラス
class Map:
    def __init__(self, map_n=0):
        # マップデータ
        self.map = np.loadtxt("./field/field{}.txt".format(map_n)).astype(np.int32)
        self.row,self.col = len(self.map), len(self.map[0]) # マップの行数,列数を取得
        self.imgs = [None] * 16             # マップチップ
        self.msize = 32                      # 1マスの大きさ[px]
    # マップの描画
    def draw(self, screen):
        for i in range(self.row):
            for j in range(self.col):
                screen.blit(self.imgs[self.map[i][j]], (j*self.msize,i*self.msize))

# 画像の読み込み
def load_img(filename, colorkey=None):
    img = pygame.image.load(filename)
    img = img.convert()
    if colorkey is not None:
        if colorkey == -1:
            colorkey = img.get_at((0,0))
        img.set_colorkey(colorkey, RLEACCEL)
    return img

def main(max_step, episode, map_n, greed_e, r1, r2, r3, def_r, test, fps):
    save_path = "./log/{0}_map{1}".format(datetime.now().strftime("%y%m%d%H"), map_n)
    if test != True:
        os.makedirs(save_path)

    Q = np.zeros([15,20,2,4])
    # learning rate
    lr = 0.8
    # 割引率
    y = 0.95
    r = 0

    for ep in range(episode):
        step = 0
        s_buc = 0
        s_buc_next = 0
        act = 0
        pygame.init()
        screen = pygame.display.set_mode(SCR_RECT.size)
        map = Map(map_n)
        map.imgs[0] =  pygame.image.load("./image/sand.png")
        map.imgs[1] =  pygame.image.load("./image/white.png")
        map.imgs[2] =  pygame.image.load("./image/track.png")
        map.imgs[3] =  pygame.image.load("./image/wheel.png")
        map.imgs[4] =  pygame.image.load("./image/wheel_bucket.png")
        goal_id = np.where(map.map==2)
        FPSCLOCK = pygame.time.Clock()
        rAll = 0
        while step < max_step:
            wheel_id = np.where(map.map>=3)
            wheel = map.map[wheel_id[0], wheel_id[1]]
            s_row, s_col = wheel_id
            #print(wheel_id)
            if wheel == 4:
                s_buc = 1
            elif wheel == 3:
                s_buc = 0
            else:
                print("error!")
                pygame.quit()
                sys.exit()

            step += 1
            sys.stdout.write("\rstep: {}/{} | episode: {}/{} | sum reward: {}".format(step,max_step,ep,episode, rAll))
            sys.stdout.flush()
            if  np.random.uniform() < greed_e:
                act = np.random.randint(0, 4)
            else:
                act = np.argmax(Q[s_row,s_col,s_buc,:] + np.random.randn(1, 4)*(1.0/(step + 1)))
            map.draw(screen)
            # イベント処理
            for event in pygame.event.get():
                # 終了用のイベント処理
                if event.type == QUIT:          # 閉じるボタンが押されたとき
                    pygame.quit()
                    sys.exit()
                if event.type == KEYDOWN:       # キーを押したとき
                    if event.key == K_ESCAPE:   # Escキーが押されたとき
                        pygame.quit()
                        sys.exit()
            if 0 == act:
                if wheel_id[1] > 0:
                    # 貨物状態
                    if wheel == 4:
                        # 物体にぶつかる場合
                        if map.map[wheel_id[0], wheel_id[1]-1] == 0:
                            # print("積み下ろす必要がある")
                            r = r3
                        else:
                            map.map[wheel_id[0], wheel_id[1]] = 1
                            # 積み下ろし
                            if map.map[wheel_id[0], wheel_id[1]-1] == 2:
                                map.map[wheel_id[0], wheel_id[1]-1] = 3
                                r = r2
                                s_buc_next = 0
                            # 非目的地
                            else:
                                map.map[wheel_id[0], wheel_id[1]-1] = 4
                                s_buc_next = 1
                                r = def_r
                    # 空状態
                    if wheel == 3:
                        map.map[wheel_id[0], wheel_id[1]] = 1
                        # 積み上げ
                        if map.map[wheel_id[0], wheel_id[1]-1] == 0:
                            map.map[wheel_id[0], wheel_id[1]-1] = 4
                            r = r1
                            s_buc_next = 1
                        # ゴール
                        elif map.map[wheel_id[0], wheel_id[1]-1] == 2:
                            map.map[wheel_id[0], wheel_id[1]-1] = 3
                            r = r3
                            s_buc_next = 0
                        # 無し
                        else:
                            map.map[wheel_id[0], wheel_id[1]-1] = 3
                            s_buc_next = 0
                            r = def_r
                else:
                    # print("over the workspace")
                    r = r3

            if 1 == act:
                if wheel_id[1] < (map.col-1):
                    if wheel == 4:
                        if map.map[wheel_id[0], wheel_id[1]+1] == 0:
                            # print("積み下ろす必要がある")
                            r = r3
                        else:
                            map.map[wheel_id[0], wheel_id[1]] = 1
                            if map.map[wheel_id[0], wheel_id[1]+1] == 2:
                                map.map[wheel_id[0], wheel_id[1]+1] = 3
                                r = r2
                                s_buc_next = 0
                            else:
                                map.map[wheel_id[0], wheel_id[1]+1] = 4
                                s_buc_next = 1
                                r = def_r
                    if wheel == 3:
                        map.map[wheel_id[0], wheel_id[1]] = 1
                        if map.map[wheel_id[0], wheel_id[1]+1] == 0:
                            map.map[wheel_id[0], wheel_id[1]+1] = 4
                            r = r1
                            s_buc_next = 1
                        elif map.map[wheel_id[0], wheel_id[1]+1] == 2:
                            map.map[wheel_id[0], wheel_id[1]+1] = 3
                            r = r3
                            s_buc_next = 0
                        else:
                            map.map[wheel_id[0], wheel_id[1]+1] = 3
                            s_buc_next = 0
                            r = def_r
                else:
                    # print("over the workspace")
                    r = r3

            if 2 == act:
                if wheel_id[0] > 0:
                    if wheel == 4:
                        if map.map[wheel_id[0]-1, wheel_id[1]] == 0:
                            # print("積み下ろす必要がある")
                            r = r3
                        else:
                            map.map[wheel_id[0], wheel_id[1]] = 1
                            if map.map[wheel_id[0]-1, wheel_id[1]] == 2:
                                map.map[wheel_id[0]-1, wheel_id[1]] = 3
                                r = r2
                                s_buc_next = 0
                            else:
                                map.map[wheel_id[0]-1, wheel_id[1]] = 4
                                s_buc_next = 1
                                r = def_r
                    if wheel == 3:
                        map.map[wheel_id[0], wheel_id[1]] = 1
                        if map.map[wheel_id[0]-1, wheel_id[1]] == 0:
                            map.map[wheel_id[0]-1, wheel_id[1]] = 4
                            r = r1
                            s_buc_next = 1
                        elif map.map[wheel_id[0]-1, wheel_id[1]] == 2:
                            map.map[wheel_id[0]-1, wheel_id[1]] = 3
                            r = r3
                            s_buc_next = 0
                        else:
                            map.map[wheel_id[0]-1, wheel_id[1]] = 3
                            s_buc_next = 0
                            r = def_r
                else:
                    # print("over the workspace")
                    r = r3

            if 3 == act:
                if wheel_id[0] < (map.row-1):
                    if wheel == 4:
                        if map.map[wheel_id[0]+1, wheel_id[1]] == 0:
                            # print("積み下ろす必要がある")
                            r = r3
                        else:
                            map.map[wheel_id[0], wheel_id[1]] = 1
                            if map.map[wheel_id[0]+1, wheel_id[1]] == 2:
                                map.map[wheel_id[0]+1, wheel_id[1]] = 3
                                r = r2
                                s_buc_next = 0
                            else:
                                map.map[wheel_id[0]+1, wheel_id[1]] = 4
                                s_buc_next = 1
                                r = def_r
                    if wheel == 3:
                        map.map[wheel_id[0], wheel_id[1]] = 1
                        if map.map[wheel_id[0]+1, wheel_id[1]] == 0:
                            map.map[wheel_id[0]+1, wheel_id[1]] = 4
                            r += r1
                            s_buc_next = 1
                        elif map.map[wheel_id[0]+1, wheel_id[1]] == 2:
                            map.map[wheel_id[0]+1, wheel_id[1]] = 3
                            r = r3
                            s_buc_next = 0
                        else:
                            map.map[wheel_id[0]+1, wheel_id[1]] = 3
                            s_buc_next = 0
                            r = def_r
                else:
                    # print("over the workspace")
                    r = r3

            s_row_next, s_col_next = np.where(map.map>=3)
            Q[s_row, s_col, s_buc, act] = Q[s_row, s_col, s_buc, act] +\
                lr*(r + y*np.max(Q[s_row_next, s_col_next, s_buc_next, :]) - Q[s_row, s_col, s_buc, act])
            rAll += r
            if map.map[goal_id[0], goal_id[1]] == 1:
                map.map[goal_id[0], goal_id[1]] = 2

            clear = np.where(map.map==0)
            if not bool(len(clear[0])):
                sys.stdout.write("  clear")
                break
            pygame.display.update()
            FPSCLOCK.tick(fps)
        if test != True:
            if ep%(episode//10) == 0:
                np.save(save_path + "/Qtable_ep{0:0=5}.npy".format(ep), Q)
        print()

if __name__ == "__main__":
    main(max_step=5000,
         episode=100,
         map_n=2,
         greed_e=0.1,
         r1=1, # 積み上げ報酬
         r2=2, # 積み下ろし報酬
         r3=-1, # 壁ぶつかり報酬
         def_r=-0.01, # ただ移動しているだけでは減点されていく
         test=True,
         fps=100) 

学習初期

ランダムに移動しているだけで、全然クリアできそうにありません

before_learn.gif

後半

なんとなくそれっぽい感じになってきました

after_learn.gif

ただ、このアルゴリズムは迷路探索の強化学習アルゴリズムをそのまま組み込んだので、学習後半になるにつれてどんどん動きが悪くなっていきます。
これを改善するにはエージェントがフィールドの状態を理解するなど、ゲームにあったアルゴリズムを組み込む必要があります。
今回は強化学習のほんのさわりの部分を実装してみました。

4
7
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
4
7

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?