10
7

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

三目並べ(と四目並べ)を強化学習で攻略してみた

Last updated at Posted at 2021-08-24

こちらから学習済みのエージェントと対戦できます!
https://colab.research.google.com/drive/1AfgMy6YQmnakq0RQlCttw4pk0v-awD2f?usp=sharing

今回紹介するプログラムの全ては GitHub にアップロードしています。
https://github.com/yousukeayada/TicTacToe-RL

TL;DR

  • Q 学習で三目並べを攻略するプログラムをフルスクラッチで実装した。
  • 学習した結果、ランダムに行動する CPU との勝率が、先攻の場合 89 % 、後攻の場合 69 % となった。
  • 盤面の大きさが 4 の場合でも学習させてみた。

環境

  • macOS Catalina 10.15.6
  • Python 3.8.1

三目並べの定式化

エージェント(Agent)

  • Q 学習を行うエージェントは「◯」を使うものとします。
  • 相手のランダムエージェントは「×」を使うものとします。

状態(state)

  • 3x3 の各マスに対して「何も置かれていない」「自分のコマが置かれている」「相手のコマが置かれている」の 3 状態があるので、 $3^9 = 19683$ 通り
  • それぞれの状態に番号を割り当てます。例えば、以下の図だと 6730 になります。
    • 盤面の各マスの状態とその番号の変換は 3 進数を考えるとわかりやすいです。(「何も置かれていない」→0、「自分のコマが置かれている」→1、「相手のコマが置かれている」→2、とすると、図の状態は 100020021 と表せます。よって、 $3^8 \times 1 + 3^7 \times 0 +...+3^0 \times 1 = 6730$ となります。)

state.png

行動(action)

  • どのマス目に置くかを行動とします。
  • プログラム内では以下の図のように、各マス目の番号とその座標を相互に変換しながら使います。

action.png

報酬(reward)

  • 勝ったとき +1 、負けたとき -1 とします。

Q table

  • テーブルの大きさは(状態数)x(行動数)、つまり $19683 \times 9$ となります。
  • 実際には「この状態ではこのマスには置けない」ということがあるので、あり得る Q 値の数はもっと少なくなります。
state / action 0 1 2 3 4 5 6 7 8
state_0.png $Q(0,0)$ $Q(0,1)$ ...
state_1.png $Q(1,0)$ $Q(1,1)$ ...
...

Q table 更新

  • 1 回対戦する毎に Q table を以下の式に従って更新します。($\alpha$ は学習率、 $\gamma$ は割引率)

$$Q(s_t,a_t) \leftarrow Q(s_t,a_t) + \alpha (R_{t+1} + \gamma \max_a Q(s_{t+1},a) - Q(s_t,a_t))$$

  • 自分と相手が交互に行動するため、更新の際にどの状態・行動・報酬を使うのか注意する必要があります。
    • 今回は勝ったときと負けたときで分けます。

勝ったとき

  • 例えば以下のように状態が遷移したとすると、状態 7823 のときに行動 1 を選択して勝ったので、 $Q(7823,1)$ を更新します。(報酬は +1)
7823 10010
state_7823.png state_10010.png

負けたとき

  • 例えば以下のように状態が遷移したとすると、負けた原因は状態 5659 のときに行動 4 を選ばず行動 0 を選んでしまったためだと考えられます。
  • つまり、更新するのは $Q(5659,0)$ となります。(報酬は -1)
5659 12220 12382
state_5659.png state_12220.png state_12382.png

実装

では実際にプログラムを書いていきます。おおまかな全体図は以下のようになります。
class_fig.png

まずは Board クラスです。
なお、プログラム内では「◯」= BLACK 、「×」= WHITE としています。

Board.py
import logging
from enum import IntEnum, auto

import numpy as np
import matplotlib.pyplot as plt


logger = logging.getLogger(__name__)

class Piece(IntEnum):
    EMPTY = auto()
    BLACK = auto()
    WHITE = auto()

class Winner(IntEnum):
    DRAW  = auto()
    BLACK = auto()
    WHITE = auto()


class Board:
    def __init__(self, size=3):
        self.size = size
        self.reset_stage()

    def put_piece(self, x: int, y: int, piece: Piece) -> Winner:
        logger.debug(f"Put {piece.name} on ({x}, {y})")
        if piece == Piece.EMPTY:
            raise Exception("Invalid Piece")
        if self.stage[y][x] != Piece.EMPTY:
            raise Exception("Already exists Piece on that position")

        self.stage[y][x] = piece
        self.empties.remove((x, y))

        winner = self.judge(x, y, piece)
        return winner
    
    def can_put(self, x: int, y: int) -> bool:
        return (x, y) in self.empties
    
    def judge(self, x: int, y: int, piece: Piece) -> Winner:
        winner = None
            
        # 行チェック
        cnt = 0
        for i in range(self.size):
            if self.stage[i][x] == piece:
                cnt += 1
        if cnt == self.size:
            if piece == Piece.BLACK:
                winner = Winner.BLACK
            elif piece == Piece.WHITE:
                winner = Winner.WHITE
            return winner
        
        # 列チェック
        cnt = 0
        for j in range(self.size):
            if self.stage[y][j] == piece:
                cnt += 1
        if cnt == self.size:
            if piece == Piece.BLACK:
                winner = Winner.BLACK
            elif piece == Piece.WHITE:
                winner = Winner.WHITE
            return winner

        # 対角線チェック
        cnt = 0
        for i in range(self.size):
            if self.stage[i][i] == piece:
                cnt += 1
        if cnt == self.size:
            if piece == Piece.BLACK:
                winner = Winner.BLACK
            elif piece == Piece.WHITE:
                winner = Winner.WHITE
            return winner

        cnt = 0
        for i in range(self.size):
            if self.stage[i][self.size-i-1] == piece:
                cnt += 1
        if cnt == self.size:
            if piece == Piece.BLACK:
                winner = Winner.BLACK
            elif piece == Piece.WHITE:
                winner = Winner.WHITE
            return winner

        # 揃ってないかつ置けるところがない
        if len(self.empties) == 0:
            winner = Winner.DRAW

        return winner

    def reset_stage(self) -> None:
        self.stage = [[Piece.EMPTY for i in range(self.size)] for j in range(self.size)]
        self.empties = [(i, j) for i in range(self.size) for j in range(self.size)]

    def show_stage(self) -> None:
        x1 = []
        y1 = []
        x2 = []
        y2 = []
        for i in range(self.size):
            for j in range(self.size):
                if self.stage[i][j] == Piece.BLACK:
                    x1.append(j+0.5)
                    y1.append(i+0.5)
                elif self.stage[i][j] == Piece.WHITE:
                    x2.append(j+0.5)
                    y2.append(i+0.5)
        fig = plt.figure()
        ax = fig.add_subplot(aspect="equal")
        ax.grid()
        plt.xlim([0, self.size])
        plt.ylim([0, self.size])
        plt.xticks(np.arange(0, self.size+1, step=1))
        plt.yticks(np.arange(0, self.size+1, step=1))
        ax.scatter(x1, y1, s=1000, marker="o", facecolor="None", edgecolors="blue")
        ax.scatter(x2, y2, s=1000, marker="x", c="red")
        plt.show()

    def test(self):
        self.put_piece(1,1,Piece.BLACK)
        self.put_piece(2,1,Piece.WHITE)
        self.show_stage()
        self.put_piece(1,0,Piece.BLACK)
        self.put_piece(0,1,Piece.WHITE)
        self.show_stage()
        self.put_piece(1,2,Piece.BLACK)
        self.show_stage()
        self.reset_stage()
        self.show_stage()


このクラスのテストをしたい場合は test() を呼んでください。

test_board.py
from Board import *

board = Board()
board.test()

次に、エージェントクラスを実装します。

Agent.py
from abc import *


class Agent(metaclass=ABCMeta):
    @abstractmethod
    def decide_action(self):
        raise NotImplementedError()

QLAgent.py
import random
import logging

import numpy as np

from Agent import *


logger = logging.getLogger(__name__)

class QLAgent(Agent):
    def __init__(self, num_states, actions, alpha=0.1, gamma=0.9):
        self.num_states  = num_states
        self.actions     = actions
        self.num_actions = len(actions)

        self.rng = np.random.default_rng()
        # self.q_table = self.rng.uniform(-1, 1, size=(self.num_states, self.num_actions))
        self.q_table = np.zeros((self.num_states, self.num_actions))

        self.epsilon = 0.1

        self.alpha = alpha
        self.gamma = gamma

    def decide_action(self, state):
        if self.rng.uniform() < self.epsilon:
            return self.decide_random_action()
        else:
            return self.decide_optimal_action(state)

    def decide_random_action(self):
        return random.choice(self.actions)

    def decide_optimal_action(self, state):
        return np.nanargmax(self.q_table[state])

    def update_q_table(self, exp):
        state, action, next_state, reward = exp
        q_s_a    = self.q_table[state][action]
        max_q_ns = np.nanmax(self.q_table[next_state])
        self.q_table[state][action] += self.alpha * (reward + self.gamma * max_q_ns - q_s_a)
        logger.debug(f"Q({state},{action}): {q_s_a} -> {self.q_table[state][action]}")

    def set_q_value(self, state, action, value):
        self.q_table[state][action] = value

    def save_q_table(self, path):
        np.save(path, self.q_table)

    def load_q_table(self, path):
        self.q_table = np.load(path)

    def test(self):
        state = 500
        print("Policy action\tRandom action")
        for i in range(10):
            print(f"{i+1} 回目: {self.decide_action(state)}\t{self.decide_random_action()}")
        print(f"Optimal action: {self.decide_optimal_action(state)}")

        print("Training...")
        action, next_state, reward = 1, 1000, 1
        exp = (state, action, next_state, reward)
        for i in range(10):
            self.update_q_table(exp)
        print(f"Optimal action: {self.decide_optimal_action(state)}")

このクラスのテストをしたい場合も Board 同様 test() を呼んでください。

test_agent.py
from QLAgent import *

num_states = 2000
actions = [0, 1, 2, 3, 4, 5, 6, 7, 8]
agent = QLAgent(num_states, actions)
agent.test()

最後に、三目並べゲーム部分を実装します。

TicTacToe.py
from itertools import product
from enum import IntEnum, auto

import numpy as np

from Board import *


class Turn(IntEnum):
    FIRST  = 0
    SECOND = 1


class TicTacToe:
    def __init__(self):
        self.size        = 3
        self.num_squares = self.size * self.size

        self.board   = Board(size=self.size)

    def reset(self):
        self.board.reset_stage()
        state = 0
        return state

    def step(self, action, piece):
        x, y = action % self.size, int(action / self.size)
        try:
            winner = self.board.put_piece(x, y, piece)

            next_state = self.convert_to_state(self.board.stage)

            if winner:
                done = True
                if winner == Winner.DRAW:
                    reward = 0
                else:
                    reward = 1
            else:
                reward, done = 0, False
            return next_state, reward, done, winner
        except Exception as e:
            logger.info(e)
            return None, np.nan, False, None

    def check(self, action):
        x, y = action % self.size, int(action / self.size)
        return self.board.can_put(x, y)
        
    def convert_to_state(self, stage):
        s = [stage[i][j] for i in range(self.size) for j in range(self.size)]
        index = 0
        for i in range(self.num_squares):
            index += (s[i]-1) * (len(Piece) ** (self.num_squares-i-1))
        return index

学習させてみる

Q 学習を行うエージェントとランダムな行動をとるエージェントを戦わせて、どれぐらい勝率が上がるのか確かめます。

ランダム vs ランダム

学習の前に、ランダム同士で戦わせたときの勝率を調べます。
10 万回対戦させた結果が以下になります。グラフ中の First が先攻、 Second が後攻を指します。

勝率 / エピソード 10 万回
先攻 58.8 %
後攻 28.9 %
合計 43.9 %

wp_3_random.png

QL vs ランダム

プログラムは以下のようになります。
先攻、後攻はランダムで決めます。

train.py
import argparse
import sys
import logging
import random

import numpy as np
import matplotlib.pyplot as plt

from TicTacToe import *
from QLAgent import *


logging.basicConfig(level=logging.INFO, format="%(name)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)

parser = argparse.ArgumentParser(description="RL parameter")
parser.add_argument('--alpha', type=float, default=0.1)
parser.add_argument('--gamma', type=float, default=0.9)
args = parser.parse_args()

alpha = args.alpha
gamma = args.gamma


env = TicTacToe()
# env.board.show_stage()

num_states = len(Piece) ** env.num_squares
actions    = [i for i in range(env.num_squares)]
agent1 = QLAgent(num_states, actions, alpha=alpha, gamma=gamma) # training
agent2 = QLAgent(num_states, actions) # random

# プロット用
win_cnt      = [dict.fromkeys([w.name for w in Winner], 0) for i in Turn]
win_rate     = [[] for i in Turn]
win_rate_all = []

episode_interval = 1000
for episode in range(10_0000):
    logger.debug(f"--- Start episode {episode+1} ---")
    
    state = env.reset()
    TURN = random.choice(list(Turn))
    turn_count = TURN.value
    logger.debug(f"QLAgent is {TURN.name}")

    while True:
        if turn_count % 2 == Turn.FIRST:
            logger.debug("Agent1's turn")
            while True:
                action = agent1.decide_action(state)
                # action = agent1.decide_random_action()
                logger.debug(f"state: {state}, action: {action}")
                if not env.check(action):
                    logger.debug("Invalid action")
                    agent1.set_q_value(state, action, np.nan)
                    continue
                
                next_state, reward, done, winner = env.step(action, Piece.BLACK)
                logger.debug(f"next state: {next_state}, reward: {reward}, done: {done}")
                if np.isnan(reward):
                    agent1.set_q_value(state, action, reward)
                    continue
                else: # 正しく置けたとき
                    break
            if done:
                logger.debug(f"state: {prev_state} -> {state} -> {next_state}, action: {prev_action} -> {action}, reward: {prev_reward} -> {reward}")
                logger.debug(f"Updates Agent1 (reward: {reward})")
                exp = (state, action, next_state, reward)
                agent1.update_q_table(exp)
                break
        else:
            logger.debug("Agent2's turn")
            while True:
                action = agent2.decide_random_action()
                logger.debug(f"state: {state}, action: {action}")
                if not env.check(action):
                    logger.debug("Invalid action")
                    agent2.q_table[state][action] = np.nan
                    continue
                
                next_state, reward, done, winner = env.step(action, Piece.WHITE)
                logger.debug(f"next state: {next_state}, reward: {reward}, done: {done}")
                if np.isnan(reward):
                    agent2.q_table[state][action] = reward
                    continue
                else: # 正しく置けたとき
                    if reward == 1:
                        reward = -1
                    break
            if done:
                logger.debug(f"state: {prev_state} -> {state} -> {next_state}, action: {prev_action} -> {action}, reward: {prev_reward} -> {reward}")
                logger.debug(f"Updates Agent1 (reward: {reward})")
                exp = (prev_state, prev_action, next_state, reward)
                agent1.update_q_table(exp)
                break

        prev_state, prev_action, prev_reward = state, action, reward
        state = next_state
        turn_count += 1
        # env.board.show_stage()


    logger.debug(f"Result: {winner.name}")
    win_cnt[TURN.value][winner.name] += 1
    win_rate[TURN.value].append(win_cnt[TURN.value][Winner.BLACK.name] / sum(win_cnt[TURN.value].values()))
    win_rate_all.append((win_cnt[0][Winner.BLACK.name]+win_cnt[1][Winner.BLACK.name]) / (episode+1))
    logger.debug(f"--- Finish episode {episode+1} ---")
    if episode % episode_interval == 0:
        logger.info(f"--- Finish episode {episode+1} ---")
    # env.board.show_stage()


logger.info(win_cnt)
logger.info(f"WP: {win_rate[0][-1]} {win_rate[1][-1]} {win_rate_all[-1]}")


plt.plot(win_rate_all, label="Total")
plt.plot(win_rate[0], label="First")
plt.plot(win_rate[1], label="Second")
plt.ylim([0,1])
plt.xlabel("Episode")
plt.ylabel("Winning percentage")
plt.legend()
plt.savefig("wp.png")

agent1.save_q_table("q_table")

エピソード(対戦回数)を変えて学習した結果を比較してみます。

勝率 / エピソード 1000 回 10000 回 100000 回 1000000 回
先攻 78.5 % 85.3 % 88.7 % 89.8 %
後攻 47.5 % 61.3 % 67.0 % 69.2 %
合計 62.4 % 73.3 % 77.9 % 79.5 %

無事上手く学習できました!
それぞれプロットしたグラフは以下です。

  • 1000 回

wp_3_ql_1k.png

  • 10000 回

wp_3_ql_10k.png

  • 100000 回

wp_3_ql_100k.png

  • 1000000 回

wp_3_ql_1000000.png

実際に学習後のエージェントと戦ってみると、最適な手を打ってくることがわかります。
https://colab.research.google.com/drive/1AfgMy6YQmnakq0RQlCttw4pk0v-awD2f?usp=sharing

盤面を大きくする

  • 盤面の大きさを 4x4 にして学習させてみます。
  • ただ、このとき状態数が $3^{16} = 43046721$ と膨大になるので、かなりの学習が必要となることが予想されます。

ランダム vs ランダム

勝率 / エピソード 10 万回
先攻 31.8 %
後攻 26.9 %
合計 29.4 %

wp_4_random_100k.png

QL vs ランダム

勝率 / エピソード 10 万回 100 万回 1000 万回 1 億回
先攻 70.5 % 73.4 % 76.4 % 78.1 %
後攻 52.6 % 55.8 % 58.8 % 60.7 %
合計 61.5 % 64.6 % 67.6 % 69.4 %
  • 10 万回

wp_4_ql_100k.png

  • 100 万回

wp_4_ql_1m.png

  • 1000 万回

wp_4_ql_10m.png

  • 1 億回

wp_4_ql_100m.png

3x3 のときほどではないですが、こちらもうまく学習できたようです。

まとめ

  • 強化学習によって、三目並べで最適な手を打てるようにエージェントを学習させた。
  • 3x3 だけでなく、4x4 でもうまく学習できた。
    • 5x5 のときは状態数が $3^{25} = 847288609443$ となり、Q 学習では流石に無理があると思います。状態を連続値として扱える DQN ならできるかもしれません。
10
7
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
10
7

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?