1
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

【オセロAIを作る①】SLポリシーネットワーク

Posted at

はじめに

creversiというPythonライブラリ[1]を使って、オセロAIを作ります。

強いオセロAIも作りますが、最終的には
「いい感じに強さを調整して引き分けになるようにしてくれる、接待オセロAI
を作ることが目標です。

プロの「次の1手」を予測するSLポリシーネットワーク

まずは、ある局面が与えられたときに次の一手を予測するモデルを学習させます。

image.png

そのためには

  • 入力:現在の局面(8チャンネル画像)
  • 出力:次の一手(64次元ベクトル)

なるCNNを教師あり学習で構築すれば良いです。

訓練データは、オセロ世界大会の棋譜[2]を使用します。

入力画像のチャンネルについて

入力画像は次の8チャンネルとしました。

  • 黒石の位置
  • 白石の位置
  • 空白の位置
  • 合法手の位置
  • そこに打った場合、何個石を返せるか
  • 「隅」「C」「X」を1で埋める(画像参照)
  • すべて1で埋める
  • すべて0で埋める

image.png
※画像は[3]から引用

実装

ライブラリ

creversiライブラリはpipでインストールすることができます。

pip install creversi
# リバーシ用ライブラリ
from creversi import Board,move_to_str,move_from_str,move_rotate90,move_rotate180,move_rotate270
import creversi
# 基礎ライブラリ
import random
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm.notebook import tqdm
from copy import copy
# 学習用ライブラリ
import torch
import torch.nn as nn
from sklearn.model_selection import train_test_split

棋譜データの読み込み

データベース[2]からダウンロードした.wtbファイルを変換用サイト[4]で.csvに変換しましょう。

変換できたら、次のコードでデータを読み込みます。

def parse(x):
    move_arr = np.zeros(60,dtype=int)
    move = [move_from_str(x[i:i+2]) for i in range(0, len(x), 2)]
    move_arr[:len(move)] = move
    move_arr[len(move):] = -1
    return move_arr

def read_data(year):
    """1年分のデータを読み込む関数"""
    df = pd.read_csv(f"./reversi-datasets/wthor_{year}.csv")
    df = df["transcript"].apply(parse).apply(pd.Series)
    return df.values

# 47年分のデータを読み込む
for y in tqdm(range(1977,2023)):
    d = read_data(y)
    if y==1977:
        data = d
    else:
        data = np.concatenate([data, d])

ネットワークの定義

ネットワークは、次のような13層のCNNとしました。

  • 第1層:5x5の80種類のフィルター+ReLU関数
  • 第2-12層:3x3の80種類のフィルター+ReLU関数
  • 第13層:1x1の1種類のフィルター
class PolicyNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        n_filters = 80
        self.input_layer = nn.Sequential(
            nn.Conv2d(8,n_filters,kernel_size=5,padding=2),
            nn.ReLU()
        )
        self.hidden_layer = nn.Sequential(
            nn.Conv2d(n_filters,n_filters,kernel_size=3,padding=1),
            nn.ReLU(),
            nn.Conv2d(n_filters,n_filters,kernel_size=3,padding=1),
            nn.ReLU(),
            nn.Conv2d(n_filters,n_filters,kernel_size=3,padding=1),
            nn.ReLU(),
            nn.Conv2d(n_filters,n_filters,kernel_size=3,padding=1),
            nn.ReLU(),
            nn.Conv2d(n_filters,n_filters,kernel_size=3,padding=1),
            nn.ReLU(),
            nn.Conv2d(n_filters,n_filters,kernel_size=3,padding=1),
            nn.ReLU(),
            nn.Conv2d(n_filters,n_filters,kernel_size=3,padding=1),
            nn.ReLU(),
            nn.Conv2d(n_filters,n_filters,kernel_size=3,padding=1),
            nn.ReLU(),
            nn.Conv2d(n_filters,n_filters,kernel_size=3,padding=1),
            nn.ReLU(),
            nn.Conv2d(n_filters,n_filters,kernel_size=3,padding=1),
            nn.ReLU(),
            nn.Conv2d(n_filters,n_filters,kernel_size=3,padding=1),
            nn.ReLU()
        )
        self.output_layer = nn.Sequential(
            nn.Conv2d(n_filters,1,kernel_size=1),
            nn.Flatten()
        )
        
    def forward(self,x):
        out = self.input_layer(x)
        out = self.hidden_layer(out)
        out = self.output_layer(out)
        return out

便利な関数の定義

ここで、関数を定義します。

  • load_sub_data(data, n, n_split):前節でロードしたndarray形式のdataをn_split分割して、n番目のデータを取り出す関数。GPUのメモリ節約のためです。
  • board_to_array(board):creversiの盤面オブジェクトcreversi.Boardを、モデル入力用の画像形式(ndarray型)に変換する関数。
def load_sub_data(data, n, n_split=10):
    """dataをn_split(=10)分割したうちのn番目(n=0,1,...)のブロック(sub-data)をロード"""
    N = data.shape[0]
    assert n_split <= N
    N_batch = N // n_split
    X = []
    y = []

    for i in tqdm(range(N_batch*n, N_batch*(n+1))):
        board = Board()
        for j,move in enumerate(data[i]):
            if move == -1:
                break
            if list(board.legal_moves) != [64]: # パスの局面ではない場合
                X.append(board_to_array(board))
                y.append(move)
                board.move(move)
            else:  # パスの局面の場合
                board.move_pass()
    X = np.array(X).astype(np.float32)
    y = np.array(y).astype(np.int64)
    print(f'X:{X.shape}, y:{y.shape}')
    return X, y
def board_to_array(board):
    """boardオブジェクトから入力画像(8チャンネル,ndarray)に変換する関数。"""
    b = np.zeros((8,8,8), dtype=np.float32)
    board.piece_planes(b)
    if not board.turn:
        b = b[[1,0,2,3,4,5,6,7],:,:]
    b[2] = np.where(b[0]+b[1]==1, 0, 1)
    legal_moves = list(board.legal_moves)
    if legal_moves != [64]:
        n_returns = []
        for move in legal_moves:
            board_ = copy(board)
            n_before = board_.opponent_piece_num()
            board_.move(move)
            n_after = board_.piece_num()
            n_returns.append(n_before-n_after)
        tmp = np.zeros(64)
        tmp[legal_moves] = n_returns
        tmp = tmp.reshape(8,8)
        b[3] = np.where(tmp > 0,1,0)
        b[4] = tmp
    b[5] = np.array([1., 1., 0., 0., 0., 0., 1., 1., 1., 1., 0., 0., 0., 0., 1., 1., 
                     0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 
                     0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
                     1., 1., 0., 0., 0., 0., 1., 1., 1., 1., 0., 0., 0., 0., 1., 1.]).reshape(8,8)
    b[6] = 1
    return b

学習

学習コードは以下のとおりです。

  • n_epoch:エポック数
  • n_batch:バッチサイズ
  • n_interval:loss出力の頻度
  • n_split:データの分割数(load_sub_data()の第3引数)
  • lr:学習率
  • 学習率はサブデータごとに×0.8ずつ指数関数的に減衰させる
  • 最適化手法:Adam
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)

n_epoch = 5
n_batch = 256
n_interval = 1
n_split = 8
lr = 0.001

model = PolicyNetwork().to(device)
optim = torch.optim.Adam(model.parameters(),lr=lr)
scheduler = torch.optim.lr_scheduler.StepLR(optim,step_size=1,gamma=0.8)
criterion = nn.CrossEntropyLoss()
train_loss_list = []
valid_loss_list = []
acc_list = []

for n in range(n_split):
    X_tr,X_va,y_tr,y_va = train_test_split(*load_sub_data(data,n,n_split), test_size=0.1, random_state=123)
    X_tr,X_va,y_tr,y_va = torch.from_numpy(X_tr),torch.from_numpy(X_va),torch.from_numpy(y_tr),torch.from_numpy(y_va)
    n_train_data = len(X_tr)
    n_valid_data = len(X_va)
    print(f'----subData{n}(lr:{scheduler.get_last_lr()})----')
    print(f'train:{n_train_data}, valid:{n_valid_data}')
    
    for epoch in range(n_epoch):
        train_loss = 0.
        random_idx = np.random.permutation(n_train_data)
        for i in range(n_train_data//n_batch):
            X_batch = X_tr[random_idx[n_batch*i:n_batch*(i+1)]].to(device)
            y_batch = y_tr[random_idx[n_batch*i:n_batch*(i+1)]].to(device)

            optim.zero_grad()
            output = model(X_batch)
            loss = criterion(output, y_batch)
            loss.backward()
            optim.step()
            train_loss += loss.item()
        train_loss_list.append(train_loss / (n_train_data//n_batch))

        if epoch % n_interval == 0:
            valid_loss = 0.
            correct = 0
            idx = np.arange(n_valid_data)
            for i in range(n_valid_data//n_batch):
                X_batch = X_va[idx[n_batch*i:n_batch*(i+1)]].to(device)
                y_batch = y_va[idx[n_batch*i:n_batch*(i+1)]].to(device)
                pred = model(X_batch)
                valid_loss += criterion(pred, y_batch).item()
                correct += (pred.argmax(axis=1) == y_batch).sum().item()
            valid_loss_list.append(valid_loss / (n_valid_data//n_batch))
            acc = correct / ((n_valid_data//n_batch)*n_batch) * 100
            acc_list.append(acc)
            print(f'[{epoch}] train loss: {train_loss / (n_train_data//n_batch):.5f} valid loss: {valid_loss / (n_valid_data//n_batch):.5f} valid acc: {acc:.3f}%')
    scheduler.step()

学習結果

lossとaccuracyのエポック推移は次のようになりました。
image.png
最終的な結果は

train loss: 1.13232 valid loss: 1.21434 valid acc: 55.380%

でした。

ランダムプレイヤーとSLポリシーネットワークを戦わせてみる

学習させたSLポリシーネットワークと、ランダムに手を選ぶプレイヤーを戦わせてみました。

SLポリシーネットワークは、確率最大の手を選ばせるようにしました。

1000回戦わせて、勝率は95.6%となりました。

image.png

次回

次回は「探索」についてやっていこうと思います。

参考サイト・文献

[1] TadaoYamaoka「creversi - 高速なPythonのリバーシライブラリ」(Github)
https://github.com/TadaoYamaoka/creversi

[2] フランスオセロ連盟「WTHOR - オセロの棋譜データベース」
https://www.ffothello.org/informatique/la-base-wthor/

[3] 「オセロ盤面の呼び方」
https://bassy84.net/othello-banmen.html

[4] .wtbから.csvに変換するサイト
https://lavox.github.io/wthor.html

1
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?