2
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

カードゲームの強化学習

Posted at

はじめに

よく知られる碁の強化学習は「Alpha Go」があり、将棋やチェスにも対応した強化学習は「AlphaZero」とか「MuZero」がある。
オセロAIもDeepLearning的立ち位置からやられる事がある。(評価関数的AIの場合もある)
ポーカー(テキサスホールデム)の強化学習は「Pluribus」、麻雀の強化学習は「Suphx」がある。

一方、コンピュータゲームではAtari(70,80年代のビデオゲーム機)のゲームをやるモデルに「Agent57」、「MuZero」、「R2D2」等がある。また、マリオの強化学習としてDQNのチュートリアルコードがあり、比較的よくやられている。その他、強化学習がやられるゲームとしては動画からの学習の題材としてMinecraft、starcraft2の「AlphaStar」などがあるらしい。

しかし、TCG(トレーディングカードゲーム)の強化学習についてはあまり情報は無く、簡単なゲーム環境で強化学習ができるかどうかを試す。

不確定か不完全情報か

二人零和有限確定完全情報ゲームのwikiによればサイコロのようにランダム性がない事を「確定」、全ての情報が両方のプレイヤーに公開されている事を「完全情報」という。

ここで一人でフィールドにカード並べるゲームを考える。
ここでカードゲームでは場のカードは全て確認できるがデッキの中のカードの並びは確認できない。そういう意味ではこのゲームは「確定不完全情報」といえる。
一方、デッキからカードを引く時にデッキのカード割合からランダムにカードが選択されると考えるなら、情報は全て確認できるので「不確定完全情報」ゲームであると言える。
デッキの並びを考慮するのは大変なので、ここではデッキからカードを引く際にランダム性が発生するゲームと考える。

ボードゲーム的思考

検討するゲームはトランプのようなカードゲームの一種なのだから同じような手札の組み合わせを考えたら膨大な数になる。例えばカードが6種類しか存在しなくても手札5枚、フィールド5枚、墓地5枚、計15枚の状態を[card0,card1,card2,card3,card4,card5,None]の7通りで示すと$7^{15}=4,747,561,509,943$となって膨大な状態数になる。また、各状態のカードは順不同であっても同じ意味である。このため各箇所のカード状態をインデックスで示すのはあまり良くないように感じる。
ここでカードがプール(EXデッキと兼用)、デッキ、手札、フィールド、墓地を移動していくボードゲーム的な環境を考える。これは単に各種の駒がグリッド状の場を左や右に移動するだけのゲームである。また、強化学習におけるactionの数はカードの種類×フィールドの種類となる。
このカードの配置状態はゲームプレイヤーは把握可能である。
image.png
effectの場所は効果を発動したカードや融合素材を一時的に移動する。
作成環境が過去の行動履歴を記憶しないので効果の対象を選ぶ際にその元の効果を判別するために作った。過去の行動履歴を保持するなら不要。

簡単な環境の作成

6種類のカード、1種類最大3枚、デッキ枚数10枚、初期手札2枚、3ターン終了後にゲーム終了、報酬はバトルフェイズ(BF)にフィールドに存在するモンスターの攻撃力の総和、意味のない行動には報酬-10、0ターン目はカードプールからデッキにデッキが10枚になるまで移す。
card0は魔法カード、card1,card2は下級モンスター、card3は上級モンスター、card4,card5は融合モンスターである。モデルとなったカード(ドラゴンメイド)から一部効果は省略、簡略化している。

・カード効果
card0(魔法):①手札、フィールドのモンスターを素材に融合モンスターを召喚。②このカードが墓地にある時、フィールドのモンスターを手札に戻し、このカードを墓地から手札に戻す。
card1(下級):①このカードが召喚、特殊召喚に成功した時、デッキから墓地にカードを送る。②BF終了時にこのカードが手札にあり、フィールドにcard3がある時、card3を手札に戻し、このカードを特殊召喚する。
card2(下級):①このカードが召喚、特殊召喚に成功した時、デッキから魔法カードを手札に加える。
card3(上級):①BF開始時にこのカードが手札、墓地にあり、フィールドにcard1かcard2がある時、フィールドのカードを手札に戻し、このカードを特殊召喚する。
card4(融合):card1~5の内、任意の二枚を融合素材とした場合。(効果省略)
card5(融合):二体の融合素材がcard1とcard2のみではない場合に可能。①スタンバイフェイズに手札、墓地からcard1~card4を特殊召喚する。

仮に手札がcard1とcard3の2枚の場合の最適な初動は、以下の手順である。
・card1を通常召喚。card1の効果でcard0をデッキから墓地に落とす。墓地のcard0の効果でフィールドのcard1を手札に戻し、墓地のcard0を手札に戻す。手札のcard0の発動。手札のcard1とcard3を素材にcard5を特殊召喚。

上記の場合、行動ミスしなくても最短で9手掛かりランダム行動では最適手には達しにくい。

import gym
import gym.spaces
import numpy as np
import pandas as pd

class MyGame:
    def __init__(self):
        self.kind_card = 6  # 0, 1, 2, 3, 4, 5
        self.kind_field = 5 # 0:pool, 1:deck, 2:hand, 3:field, 4:boti
        self.init_card = 2
        self.not_ex = 4
        self.max_turn_num = 3
        self.max_mons_num = 5
        self.max_deck_num = 10
        self.a = np.arange(30).reshape(5,6)
        self.att = np.array([0, 500, 500, 2700, 3000, 3500])
        self.is_turn1_limit = [0] * 6
        self.is_summon = [0] * 6
        self.total_reward = 0
        self.turn = 0
        self.phase = 0      # 0:struct_deck, 1:stanby, 2:main1, 3:start_BF, 4:end_BF, 5:main2

        self.action = 0
        self.state = np.zeros((6, 6))
        self.select_state = 0
        self.summon_right = False
        self.is_effect = False
        self.is_special_summon = False
        self.done = False
        self.epsilon = 0

    def pool_init(self):
        self.state = np.zeros((6, 6))
        self.state[0, :] = np.ones(6) * 3

    def reset(self):
        self.pool_init()
        self.total_reward = 0
        self.turn = 0
        self.phase = 0
        #print('turn=', self.turn)
        #print('phase=', self.phase)
        self.done = False

    def start_game(self):
        self.deal_init()
        self.turn = 1
        self.phase = 1
        #print('turn=', self.turn)
        #print('phase=', self.phase)
        self.summon_right = True

    def end_game(self):
        self.done = True

    def change_phase(self):
        if self.phase == 3:
           self.battle_phase()
        self.phase += 1
        #print('phase=', self.phase)
        self.is_summon = [0] * 6

    def change_turn(self):
        self.change_phase()
        self.summon_right = True
        self.phase = 1
        self.turn += 1
        #print('turn=', self.turn)
        #print('phase=', self.phase)
        self.is_turn1_limit = [0] * 6
        self.draw_random_card()
        if self.turn == self.max_turn_num+1:
            self.end_game()

    def deal_init(self):
        for i in range(self.init_card):
            self.draw_random_card()
#        if np.random.rand() < self.epsilon:
#            self.discard_random_card()

    def exist_card(self, action):
        state = self.state.flatten()
        if state[action] >= 1:
            return True
        return False

    def exchange_action(self, action):
        m = action % self.kind_card
        pos = action // self.kind_card
        return m, pos

    def discard_random_card(self):
        prob = np.array(self.state[1])
        prob /= np.sum(prob)
        card_index = np.random.choice(a=list(range(6)), size=1, p=prob)
        self.move_card(card_index[0], 1, 4)

    def draw_random_card(self):
        prob = np.array(self.state[1])
        prob /= np.sum(prob)
        card_index = np.random.choice(a=list(range(6)), size=1, p=prob)
        self.move_card(card_index[0], 1, 2)

    def move_card(self, m, pos1, pos2):
        self.state[pos1, m] -= 1
        self.state[pos2, m] += 1
        if pos2 == 3:
            self.is_summon[m] = 1

    def move_cards(self, pos1, pos2):
        for i in range(6):
            if self.state[pos1, i] >= 1 and pos2 == 3:
                self.is_summon[i] = 1
            if self.state[pos1, i] == 1:
                self.move_card(i, pos1, pos2)
            if self.state[pos1, i] == 2:
                self.move_card(i, pos1, pos2)
                self.move_card(i, pos1, pos2)
            if self.state[pos1, i] == 3:
                self.move_card(i, pos1, pos2)
                self.move_card(i, pos1, pos2)
                self.move_card(i, pos1, pos2)

    def is_make_deck(self):
        if self.action in list(self.a[0,:4].flatten()):
            return True
        return False
        
    def make_deck(self):
        m, pos = self.exchange_action(self.action)
        self.move_card(m, pos, 1)

    def is_complete_make_deck(self):
        if int(np.sum(self.state[1])) == self.max_deck_num:
            return True
        return False

    def complete_make_deck(self):
        self.start_game()

    def is_fusion_summon(self):
        if self.select_state == 0:
            if self.action in list(self.a[2,0].flatten()):
                if self.state[2, 0] >= 1:
                    if self.state[0, 4] >= 1:
                        if np.sum(self.state[2:4, 1:]) >= 2:
                            return True
                    if self.state[0, 5] >= 1:
                        if np.sum(self.state[2:4, 1:]) >= 2 and (np.sum(self.state[2:4, 1:]) != np.sum(self.state[2:4, 1:3])):
                            return True
            return False
        if self.select_state == 1:
            if self.state[3, 0] >= 1:
                if self.action in list(self.a[2:4, 1:].flatten()):
                    return True
        if self.select_state == 2:
            if self.state[3, 0] >= 1:
                if self.action in list(self.a[2:4, 1:].flatten()):
                    return True
        if self.select_state == 3:
            if self.action in list(self.a[0, 4].flatten()):
                     return True
            if self.action in list(self.a[0, 5].flatten()):
                if np.sum(self.state[5, 1:3]) != 2:
                     return True
        return False

    def fusion_summon(self):
        m, pos = self.exchange_action(self.action)
        if self.select_state == 0:
            self.move_card(m, pos, 3)
            self.select_state += 1
        elif self.select_state == 1:
            self.move_card(m, pos, 5)
            self.select_state += 1
        elif self.select_state == 2:
            self.move_card(m, pos, 5)
            self.select_state += 1
        elif self.select_state == 3:
            self.move_card(m, pos, 3)
            self.move_card(0, 3, 4)
            self.move_cards(5, 4)
            self.select_state = 0
            #print('fusion_summon', m)

    def is_normal_summon(self):
        if self.select_state == 0:
            if np.sum(self.state[3, 1:]) <= self.max_mons_num:
                if self.summon_right:
                    if self.action in list(self.a[2, 1:3].flatten()):
                        return True
        return False

    def normal_summon(self):
        m, pos = self.exchange_action(self.action)
        self.move_card(m, pos, 3)
        self.summon_right = False
        #print('normal_summon', m)

    def is_special_summon_mon3(self):
        if self.select_state == 0:
            if self.action in list(self.a[4, 3].flatten()) or self.action in list(self.a[2, 3].flatten()):
                if np.sum(self.state[3, 1:3]) >= 1:
                     return True
        if self.select_state == 1:
            if self.is_special_summon:
                if self.state[5, 3] == 1:
                    if self.action in list(self.a[3, 1:3].flatten()):
                        return True
        return False

    def special_summon_mon3(self):
        m, pos = self.exchange_action(self.action)
        if self.select_state == 0:
            self.move_card(m, pos, 5)
            self.is_special_summon = True
            self.select_state = 1
        elif self.select_state == 1:
            self.move_card(m, pos, 2)
            self.move_cards(5, 3)
            self.is_special_summon = False
            self.select_state = 0
            #print('special_summon_mon3')
    
    def is_special_summon_mon1(self):
        if self.select_state == 0:
            if self.action in list(self.a[2, 1].flatten()):
                 if self.state[3, 3] >= 1:
                     return True
        if self.select_state == 1:
            if self.is_special_summon:
                 if self.state[5, 1] == 1:
                     if self.action in list(self.a[3, 3].flatten()):
                         return True
        return False

    def special_summon_mon1(self):
        m, pos = self.exchange_action(self.action)
        if self.select_state == 0:
            self.move_card(m, pos, 5)
            self.is_special_summon = True
            self.select_state = 1
        elif self.select_state == 1:
            self.move_card(m, pos, 2)
            self.move_cards(5, 3)
            self.is_special_summon = False
            self.select_state = 0
            #print('special_summon_mon1')

    def is_special_summon_mon5(self):
        if self.select_state == 0:
            if self.is_turn1_limit[5] == 0:
                 if self.action in list(self.a[3, 5].flatten()):
                     if np.sum(self.state[2, 1:4]) + np.sum(self.state[4, 1:5]) >= 1:
                         return True
        if self.select_state == 1:
            if self.is_special_summon:
                 if self.state[5, 5] == 1:
                     if self.action in list(self.a[2, 1:4].flatten()) or self.action in list(self.a[4, 1:5].flatten()):
                         return True
        return False

    def special_summon_mon5(self):
        m, pos = self.exchange_action(self.action)
        if self.select_state == 0:
            self.move_card(m, pos, 5)
            self.is_special_summon = True
            self.select_state = 1
        elif self.select_state == 1:
            self.move_card(m, pos, 3)
            self.move_cards(5, 3)
            self.is_special_summon = False
            self.is_turn1_limit[5] = 1
            self.select_state = 0
            #print('special_summon_mon5')

    def is_effect_mon1(self):
        m, pos = self.exchange_action(self.action)
        if self.select_state == 0:
            if self.is_turn1_limit[m] == 0:
                if self.is_summon[m] == 1:
                    if self.action in list(self.a[3,1].flatten()):
                        if np.sum(self.state[1,:4]) >= 1:
                             return True
        if self.select_state == 1:
            if self.is_effect:
                if self.state[5,1] == 1:
                    if self.action in list(self.a[1,:4].flatten()):
                        return True
        return False

    def effect_mon1(self):
        m, pos = self.exchange_action(self.action)
        if self.select_state == 0:
            self.move_card(m, pos, 5)
            self.select_state += 1
            self.is_effect = True
        elif self.select_state == 1:
            self.move_card(m, pos, 4)
            self.move_card(1, 5, 3)
            self.is_turn1_limit[1] = 1
            self.select_state = 0
            self.is_effect = False
            #print('effect_mon1')

    def is_effect_mon2(self):
        m, pos = self.exchange_action(self.action)
        if self.select_state == 0:
            if self.is_turn1_limit[m] == 0:
                if self.is_summon[m] == 1:
                    if self.action in list(self.a[3,2].flatten()):
                        if np.sum(self.state[1,0]) >= 1:
                             return True
        if self.select_state == 1:
            if self.is_effect:
                if self.state[5,2] == 1:
                    if self.action in list(self.a[1,0].flatten()):
                        return True
        return False

    def effect_mon2(self):
        m, pos = self.exchange_action(self.action)
        if self.select_state == 0:
            self.move_card(m, pos, 5)
            self.select_state += 1
            self.is_effect = True
        elif self.select_state == 1:
            self.move_card(m, pos, 2)
            self.move_card(2, 5, 3)
            self.is_turn1_limit[2] = 1
            self.select_state = 0
            self.is_effect = False
            #print('effect_mon2')

    def is_effect_mag0(self):
        m, pos = self.exchange_action(self.action)
        if self.select_state == 0:
            if self.is_turn1_limit[m] == 0:
                if self.action in list(self.a[4,0].flatten()):
                    if np.sum(self.state[3,1:]) >= 1:
                        return True
        if self.select_state == 1:
            if self.is_effect:
                if self.state[5,0] == 1:
                    if self.action in list(self.a[3,1:].flatten()):
                        return True
        return False

    def effect_mag0(self):
        m, pos = self.exchange_action(self.action)
        if self.select_state == 0:
            self.move_card(m, pos, 5)
            self.select_state += 1
            self.is_effect = True
        elif self.select_state == 1:
            if m in [1,2,3]:
                self.move_card(m, pos, 2)
            elif m in [4,5]:
                self.move_card(m, pos, 0)
            self.move_card(0, 5, 2)
            self.is_turn1_limit[0] = 1
            self.select_state = 0
            self.is_effect = False
            #print('effect_mag0')

    def battle_phase(self):
        self.total_reward += np.sum(self.state[3] * self.att)
#        if self.state[2,0] >= 1:
#            self.total_reward += 1000
        #print('total_reward=', self.total_reward)
    
    def get_bf_reward(self):
        reward = np.sum(self.state[3] * self.att)
#        if self.state[2,0] >= 1:
#            reward += 1000
        return reward

class MyEnv(gym.Env):
    def __init__(self):
        super().__init__()
        self.game = MyGame()
        self.game.reset()
        self.action_space = gym.spaces.Discrete(30)
        high, low = np.ones(36, dtype='float32') * 3.0, np.ones(36, dtype='float32') * -3.0
        self.observation_space = gym.spaces.Box(low=low, high=high)
        self.reward_range = [-2000, 40000]
        self.done = False
        self.step_num = 0
        self.max_step_num = 5000

    def reset(self):
        self.step_num = 0
        self.game.reset()
        return self.observe()

    def render(self):
        print('turn=', self.game.turn, 'phase=', self.game.phase, 'select_state=', self.game.select_state)
        df = pd.DataFrame(self.game.state.T,
                          index=['card%d' % (i) for i in range(6)],
                          columns=['pool', 'deck', 'hand', 'field', 'boti', 'effect'])
        print(df)
        print()
    def close(self):
        pass

    def step(self, action):
        self.step_num += 1
        if self.step_num > self.max_step_num:
            self.game.done = True
            self.render()
        
        reward = 0.0
        self.game.action = action
        if action in [10, 11, 16, 17] and self.game.select_state==0:
            if self.game.phase==5:
                self.game.change_turn()
            elif self.game.phase!=0:
                self.game.change_phase()
            elif self.game.phase==0:
                reward -= 10.0
            if self.game.phase==4:
                reward += self.game.get_bf_reward()
        elif action in [10, 11, 16, 17] and self.game.select_state!=0:
            reward -= 10.0
        elif self.game.exist_card(action)==False:
            reward -= 10.0
        else:
            if self.game.phase==0:
                if self.game.is_make_deck():
                    self.game.make_deck()
                else:
                    reward -= 10.0
                if self.game.is_complete_make_deck():
                    self.game.complete_make_deck()

            elif self.game.phase==1:
                # 1.stanby phase
                if self.game.is_special_summon_mon5():
                    self.game.special_summon_mon5()
                elif self.game.is_effect_mon1():
                    self.game.effect_mon1()
                elif self.game.is_effect_mon2():
                    self.game.effect_mon2()
                else:
                    reward -= 10.0

            elif self.game.phase==2 or self.game.phase==5:
                # (2:main1 phase) or (5:main2 phase)
                if self.game.is_normal_summon():
                    self.game.normal_summon()
                elif self.game.is_fusion_summon():
                    self.game.fusion_summon()
                elif self.game.is_effect_mag0():
                    self.game.effect_mag0()
                elif self.game.is_effect_mon1():
                    self.game.effect_mon1()
                elif self.game.is_effect_mon2():
                    self.game.effect_mon2()
                else:
                    reward -= 10.0

            elif self.game.phase==3:
                # 3.start BF
                if self.game.is_special_summon_mon3():
                    self.game.special_summon_mon3()
                else:
                    reward -= 10.0

            elif self.game.phase==4:
                # 4.end BF
                if self.game.is_special_summon_mon1():
                    self.game.special_summon_mon1()
                elif self.game.is_effect_mon1():
                    self.game.effect_mon1()
                else:
                    reward -= 10.0

        observation = self.observe()
        return observation, reward, self.game.done, {}

    def get_reward(self):
        return self.game.total_reward
    def is_done(self):
        return self.game.done

    def observe(self):

        observation = np.zeros((31,6,6), dtype='float32')
        for i in range(4):
            observation[i] = np.where(self.game.state == i, 1, 0)
            if self.game.select_state==i:
                observation[i+4] = np.ones((6,6))
            if self.game.turn==i:
                observation[i+8] = np.ones((6,6))
        for i in range(6):
            if self.game.phase==i:
                observation[i+12] = np.ones((6,6))
            if self.game.is_turn1_limit[i]==1:
                observation[i+18] = np.ones((6,6))
            if self.game.is_summon[i]==1:
                observation[i+24] = np.ones((6,6))
        if self.game.summon_right:
            observation[30] = np.ones((6,6))
        return observation

行動は以下の0~29の30通りである。
また、別途next_phaseに移行するactionを定義しないといけないがEXカードのデッキと手札は使わないのでこれの行動[10,11,16,17]をとる時、次のphaseに移行する。
image.png

動作確認のため以下の様なactionを実行した結果は以下である。

env = MyEnv()
observation = env.reset()
env.render()
observation, reward, done, _ = env.step(0)
observation, reward, done, _ = env.step(1)
observation, reward, done, _ = env.step(2)
observation, reward, done, _ = env.step(3)
observation, reward, done, _ = env.step(0)
observation, reward, done, _ = env.step(1)
observation, reward, done, _ = env.step(2)
observation, reward, done, _ = env.step(3)
observation, reward, done, _ = env.step(1)
observation, reward, done, _ = env.step(2)
env.render()
observation, reward, done, _ = env.step(10)
observation, reward, done, _ = env.step(13)
env.render()
observation, reward, done, _ = env.step(19)
observation, reward, done, _ = env.step(6)
env.render()
observation, reward, done, _ = env.step(24)
observation, reward, done, _ = env.step(19)
env.render()
observation, reward, done, _ = env.step(12)
observation, reward, done, _ = env.step(13)
observation, reward, done, _ = env.step(14)
observation, reward, done, _ = env.step(4)
env.render()

q_learn_Figure_3.png

Q学習

最初にネット上のチュートリアルコードを参考にQ学習を試した。

Q学習はモデルフリーでいい反面、全ての状態における全ての行動の方針テーブルを作成する必要がある。例えば取りうる状態が1000通り以下で取りうる行動が10通り以下なら、高々10000個のQテーブルを更新すればよい。三目並べ(〇×ゲーム)とかブラックジャックのように状態及び行動が全探索できる程度の環境なら問題ないが、今の状態は全探索にほど遠いと思われる。
このため学習回数を増やせばメモリが尽きる。

DQN学習

次にスーパーマリオのチュートリアルコードを参考にDQN学習をやった。
入力を(31,6,6)、出力を(30)とする適当なCNNモデルを考え、他はチュートリアルコードのままである。

最初、観測データを出来る限り敷き詰め、入力データを圧縮し、入力を(2,6,6)とした。この時、チャンネル0には盤面の各種カードの数、チャンネル1には各種フラグをonehotにしたものを敷き詰めた。
このDQN学習をした場合、turn=0の場合、poolにあるカードを選択し、turn!=0の場合、[10,11,16,17]の現在のphaseをskipする行動をとった。しかしながら、モンスターの召喚についてはほとんど学習されなかった。

次に入力を(31,6,6)とした時、平均報酬は10000くらいまでは学習出来た。
これは基本的なモンスターの召喚については学習が出来ている。

reward_plot.jpg

再学習

1000回実行したときの結果は
image.png

学習内容を確認すると
・デッキ組み、card1,card2の通常召喚、card3の特殊召喚(BF開始時手札から)、card2の効果(召喚時)、card5による特殊召喚(スタンバイ)、融合召喚は学習出来ている。
・card1の特殊召喚(BF終了時)、card0の効果(墓地)は実行されているものの、見た感じturn3のBF終了後に限られ、特に意味がない。
・card3の特殊召喚(BF開始時墓地から)、card1の効果(召喚時)は有効には学習はされていない。

card0の効果(墓地)が上手く学習されないのはフィールドのカードを手札に戻すので短期的には報酬がマイナスになるのが問題かもしれない。
card3の特殊召喚(BF開始時墓地から)が学習されない理由は融合召喚でcard3が墓地に落ちても、card5による特殊召喚(スタンバイ)によって回収され、墓地にcard3が存在する状況がほとんどない。
card1の効果(召喚時)で墓地に落とそうとしてもcard3の特殊召喚(BF開始時墓地から)自体を学べてないから、落す意味を見出せない。

これの対策として開始時ランダムでデッキから0~1枚墓地に移すようにした。
また、かなり恣意的だがBF時に手札にcard0を所持している時、1000の報酬を与える事にした。
この要素を追加してモデル重みを読み込んで再学習した。

class MyGame:
    def __init__(self):
...
        self.epsilon = 0
...
    def deal_init(self):
        for i in range(self.init_card):
            self.draw_random_card()
        if np.random.rand() < self.epsilon:
            self.discard_random_card()
    def discard_random_card(self):
        prob = np.array(self.state[1])
        prob /= np.sum(prob)
        card_index = np.random.choice(a=list(range(6)), size=1, p=prob)
        self.move_card(card_index[0], 1, 4)
...
    def battle_phase(self):
        self.total_reward += np.sum(self.state[3] * self.att)
        if self.state[2,0] >= 1:
            self.total_reward += 1000
        #print('total_reward=', self.total_reward)
    
    def get_bf_reward(self):
        reward = np.sum(self.state[3] * self.att)
        if self.state[2,0] >= 1:
            reward += 1000
        return reward
...
class Mario(Mario):
    def load(self, load_path):
        loaded_data = torch.load(load_path)
        mario.net.load_state_dict(loaded_data['model'])
        #mario.exploration_rate = loaded_data['exploration_rate']
...
mario.load(load_path)
episodes = 300000 
for e in range(episodes):
    env.game.epsilon = mario.exploration_rate
    state = env.reset()

この追加ルール(ランダム墓地、card0報酬1000)を加えた再学習は特にスコア的に改善は少なかった。追加ルールなしの環境で1000回プレイさせたが、card5の融合召喚の成功率は逆に下がってしまっている。card0報酬1000を加えたことでcard5の3500を召喚するより、card0を手元に残しcard3を特殊召喚した方が1000+2700=3700で短期的報酬が上回ったせいだろうか。

reward_plot.jpg
image.png

さらに学習パラメータを見直してみる。
デフォルトの割引率γ=0.9は長い手数を必要とする報酬を割り引いてしまうかもしれないのでγ=0.99とした。
また、行動の選択肢が多く、ランダム行動に無駄が多いので、取りうる行動に簡単な条件を付けてみた。行動可能かどうかは厳密ではないが判定になるべく時間が掛からないようにした。この変更で再学習を行った。
平均報酬は20000近くまで上がった。
この時のデッキ:card0二枚、card1三枚、card2二枚、card3三枚となった。
card0は普通一枚採用だが、二枚使用している。

reward_plot.jpg
image.png

class MyEnv(gym.Env):
    def legal_actions(self):
        if self.game.select_state!=0:
            return [4,5, 6,7,8,9, 12,13,14,15, 18,19,20,21,22,23, 24,25,26,27,28,29]
        elif self.game.phase==0:
            return [0,1,2,3]
        elif self.game.phase==1:
            return [10, 23]
        elif self.game.phase==2 or self.game.phase==5:
            return [10, 12,13,14, 18,19,20, 24]
        elif self.game.phase==3:
            return [10, 15,27]
        elif self.game.phase==4:
            return [10, 13]
        else:
            return list(range(30))
...
class Mario:
    def act(self, state, env):
        if np.random.rand() < self.exploration_rate:
            #action_idx = np.random.randint(self.action_dim)
            action_list = env.legal_actions()
            action_idx = action_list[np.random.randint(len(action_list))]
...
self.gamma = 0.99

理論上最大報酬

取りうる最大報酬を理論的に考えてみる。
デッキ:card0一枚、card1三枚、card2三枚、card3三枚の計10枚
初期手札2枚として、turn3までの攻撃力合計値を報酬とする。

初期手札が最良であれば
turn1:3500×1
turn2:3500×2
turn3:3500×3
total:21000
は2、3ターン目の引きに依存せず安定して出せる。

カードの引きによっては
turn1:3500×1
turn2:3500×1+2700×2
turn3:3500×3+2700×2
total:28300
が理論上最大報酬だと思われる。

今回の「DQN学習」はAtariのゲームなどの複雑なゲームではR2D2やMuZero等の他の手法に劣るだが、今回レベルの複雑さであればDQN学習でも十分の様である。

image.png

実行確認

デッキ:card0二枚、card1三枚、card2二枚、card3三枚
初期手札:card1とcard2の下級モンスター
image.png
1ターン目:card1を通常召喚、効果でcard3を墓地に、BF開始時にcard3を特殊召喚、BF2700ダメージ、BF終了時card1を特殊召喚、card3を手札に戻す。
image.png
2ターン目:card3をドロー、card2を通常召喚、効果でcard0を手札に、手札のcard3×2枚でcard5を融合召喚、BF開始時にcard3を二体特殊召喚、BF3500+2700×2ダメージ、メインフェイズ2で墓地のcard0を戻し、手札のcard2とcard3でcard5を融合召喚。
image.png
3ターン目:card1をドロー、スタンバイフェイズにcard5の効果で墓地のcard3を特殊召喚、メインフェイズ1で墓地のcard0を戻し、手札のcard1とcard3でcard5を融合召喚、card1を通常召喚、BF開始時にcard3を特殊召喚、BF3500×3+2700×2ダメージ、3ターントータルで27500ダメージ
image.png

turn1:2700×1
turn2:3500×1+2700×2
turn3:3500×3+2700×2
total:27500

初期手札が良くなく、1ターン目にcard5が召喚できていないにも関わらず、最終的にcard5を3体並べられるのは良い。

厳密には3ターン目のドローがcard0以外でcard2の通常召喚後の効果でデッキからcard0を引けばこのターン更にフィールドのcard3と手札のドローしたカードでcard4の融合召喚が可能で3500×3+3000×1+2700×1になる。そのためには2ターン目の融合素材の選択でcard2の代わりにcard1選ぶ必要がある。従って、まだ最適行動ではないが自分も見直している過程でこの手順に気付いたので、人間でも気づきにくいレベルである。

DQN学習モデル

モデルはマリオのチュートリアルコードから多少変更した程度である。
マリオだと入力データが$(4,84,84)$なのでstrideで画像サイズを減らして演算量を減らすメリットがあるが、ボードゲームの場合それがない。stride=1なので中間層はずっと$(6,6)$の解像度のままである。

class MarioNet(nn.Module):
    def __init__(self, input_dim, output_dim):
        super().__init__()
        c, h, w = input_dim

        self.online = nn.Sequential(
            nn.Conv2d(in_channels=c,   out_channels=128, kernel_size=3, padding=1, stride=1),
            nn.ReLU(),
            nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, padding=1, stride=1),
            nn.ReLU(),
            nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, padding=1, stride=1),
            nn.ReLU(),
            nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, padding=1, stride=1),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(4608, 1024),
            nn.ReLU(),
            nn.Linear(1024, output_dim),
        )
...
mario = Mario(state_dim=(31, 6, 6), action_dim=env.action_space.n, save_dir=save_dir)

課題点、改善点

思った改善点を挙げておく。
・今回はcard4の効果の実装は省略した。これは効果が複雑そうだったためである。
①スタンバイフェイズにフィールドにcard2がある時、手札、墓地からcard1を特殊召喚する。またはフィールドにcard1がある時、手札、墓地からcard2を特殊召喚する。②BF開始時にこのカードが墓地にあり、フィールドにcard2がある時、フィールドのcard2を手札に戻し、このカードを特殊召喚する。
・モデルがCNNなので隣接データとの相関は畳み込みで学習されやすいだろう。逆に言えば離れた場所の相関は学習しにくいかもしれない。これはcard0とcard5の離れた相関がcard2とcard3の隣接の相関より学習されにくいかもしれない。また、poolとeffectの離れた相関がhandとfieldの隣接の相関より学習されにくいかもしれない。
入力に半分シフトした入力を追加するとか([1,2,3,4,5,6]と[4,5,6,1,2,3])、二倍の繰り返しデータを作ってstride=2で縮めるとか([1,2,3,4,5,6,0,1,2,3,4,5,6]=>[2,4,6,1,3,5])とかで畳み込みのkernel_sizeを増やさなくても離れた相関を学習出来るかもしれない。
これは場が6×6の場合、それほど大したメリットはないが、例えばカード種類を16種類に増やしたり、場にEXデッキ、EXモンスターゾーン、除外カードエリア、守備カード置き場、リバースカード置き場、魔法罠エリアなど用途に合わせて細分化した時、長距離の相関が重要になるかもしれない。その時、3x3の畳み込み層を増やすより、並びを入れ替えて少ない畳み込み層で相関を得る方が良いかもしれない。

・CNNモデルの改善案として、チャンネルの次元を増やす、CNNレイヤー層数を深くする、ResBlockを使う、BatchNormを使う、SENet(Squeeze-and-Excitation)を使う、depthwise convolutionとpointwise convolutionを使うなどがあるが、モデルが簡単な方が学習は早く、複雑になるほど必要な学習step数が増大すると考えるとDQN学習でCNNを改善するのも躊躇われる。実際、少し複雑にすると却って精度が下がった。ε-greedy方策のランダム行動を起こす確率の減少速度が相関するのかもしれない。
・DreamerV3とかはCNNモデルの前半部分の畳み込み部分を画像入力のVAEのEncoderを使うが、stride不要のボードゲームには不向きかもしれない。
・マリオとかの場合は空中にいる一枚の画像に対して上昇中か下降中かの区別がつかないので時系列の前の入力が必要だが、今回のモデルでは効果対象の選択時、効果元のカードはeffect場に移動するので、直前の選択肢を疑似的に覚えており、今回の範囲内なら潜在層の時系列モデルの採用は特にメリットはないように思う。

・ゲームを都合3ターンプレイしているが、初期手札の枚数を1~5枚のランダムにすれば1ターンの学習で済むかもしれない。また、この1ターンの学習の場合、次ターンでの発展性が学習出来ないが、この場合はフィールドや墓地にランダムでカードを移動して開始すればよい。
・4ターン目以降はフィールドにモンスターが一杯になる恐れがある。現在のルールでは初期手札2枚なのもあって1ターン当たり1~2体しか召喚出来ないので3ターン目までにあふれることは無く、召喚可能かはチェックしていない。
・既知のバグとしてBFエンド時にフィールドのcard3から手札のcard1の②効果で特殊召喚する時、card1の①の効果を発動可能な場合に区別できない。(BFエンド時、action指定が正しくないと①の効果でフィールドの別のcard3を何回でも手札に戻せてしまう。action指定が正しければ①の効果発動には問題なし)。これはそもそも複数の効果を持つカードにおいて抱える問題である。解決策は効果によってeffect1、effect2の場を分けるだが、結果として大した影響はないと考え無視した。

・多数の種類のカードを覚えるのは可能だろうか。言語モデル(Transformer)なら5万種類の単語を認識するが、これは畳み込みを使うCNNではない。
・実際の所、攻撃力の高いモンスターを並べるより妨害数のほうが重要である。
とはいえ一人プレイでは妨害の質に関しては評価不能であるし、妨害数を報酬とする強化学習は困難であろう。(妨害効果を適当なバーンダメージと仮定して実装するべきだろうか)

まとめ

簡単な環境で一人用カードゲームの強化学習が可能か試した。
6種類のカードから成る簡単な環境であればDQN学習でもカードゲームのプレイが出来ているように見える。
しかし、参考元のカード(ドラゴンメイド)の回し方を知らないとこの記事を読んでも多分詳細はいまひとつ分からないだろう。

2
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?