
Playing Tic-Tac-Toe Against a Self-Play-Trained Agent

Posted at 2022-02-24

Overview

We try reinforcement learning through self-play; PPO is used as the learning algorithm.

Tic-Tac-Toe

The classic noughts-and-crosses (marubatsu) game on a 3×3 board.

Code

  • Imports
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.distributions import Categorical
from torch.utils.data import Dataset, DataLoader
torch.manual_seed(41)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

import numpy as np
import random
random.seed(41)

from copy import deepcopy
from tqdm import tqdm

  • Neural models (one for the policy and one for the value function)
  • The two networks could share some weights, but we skip that here (a quick shape check follows the code).
class CNN_policy(nn.Module):
    def __init__(self, square_size, output_dim, hidden_dim=100):
        super().__init__()
        self.square_size = square_size
        self.conv1 = nn.Conv2d(1, hidden_dim, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, padding=1)
        self.output = nn.Linear((square_size**2)*hidden_dim, output_dim)

    def forward(self,x):
        x = x.reshape(-1, self.square_size, self.square_size).unsqueeze(1).float()
        x = F.relu(self.conv1(x)) + x
        x = F.relu(self.conv2(x)) + x
        x = self.output(x.reshape(x.size(0),-1))
        return x


class CNN_value(nn.Module):
    def __init__(self, square_size, hidden_dim=100):
        super().__init__()
        self.square_size = square_size
        self.conv1 = nn.Conv2d(1, hidden_dim, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, padding=1)
        self.output = nn.Linear((square_size**2)*hidden_dim, 1)

    def forward(self,x):
        x = x.reshape(-1, self.square_size, self.square_size).unsqueeze(1).float()
        x = F.relu(self.conv1(x)) + x
        x = F.relu(self.conv2(x)) + x
        x = self.output(x.reshape(x.size(0),-1))
        return x
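
As a quick sanity check of the tensor shapes (my own snippet, not part of the original code, assuming the classes above are already defined): a batch of flattened 3×3 boards goes in, and 9 action logits (policy) or a single value (value network) come out.

import torch

dummy_boards = torch.zeros(2, 9)              # batch of 2 flattened 3x3 boards
print(CNN_policy(3, 9)(dummy_boards).shape)   # torch.Size([2, 9]) -> one logit per cell
print(CNN_value(3)(dummy_boards).shape)       # torch.Size([2, 1]) -> one value per board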

  • Transformer version (defined with the same class names as the CNN version above, so running this cell replaces it)
  • Row and column positions are each embedded and summed onto the cell embedding (a quick check of the position indices follows the code).
import math

from torch.nn.modules.container import ModuleList
from torch.nn.modules.normalization import LayerNorm

class SoftmaxAttention(nn.Module):
    def __init__(self, head_dim):
        super().__init__()
        self.head_dim = head_dim

    def forward(self, Q, K, V, mask=None):
        logit = torch.einsum("bhld,bhmd->bhlm",Q,K)/math.sqrt(self.head_dim)

        if mask is not None:
            logit = logit + mask[:,None,:,:]
        
        attn = F.softmax(logit, dim=-1)
        X = torch.einsum("bhlm,bhmd->bhld",attn,V)

        return X

class SelfAttention(nn.Module):
    def __init__(self, dim, head_dim, num_head):
        super().__init__()

        self.dim = dim
        self.head_dim = head_dim
        self.num_head = num_head

        assert dim == head_dim*num_head, "dim must equal head_dim * num_head"

        self.W_q = nn.Linear(self.dim, self.dim)
        self.W_k = nn.Linear(self.dim, self.dim)
        self.W_v = nn.Linear(self.dim, self.dim)

        self.attn = SoftmaxAttention(head_dim)

    def forward(self, X, sq_mask):
        Q = self.split_heads(self.W_q(X))
        K = self.split_heads(self.W_k(X))
        V = self.split_heads(self.W_v(X))

        attn_out = self.attn(Q.float(), K.float(), V.float(), sq_mask.float())
        attn_out = self.combine_heads(attn_out)
        return attn_out

    def combine_heads(self, X):
        X = X.transpose(1, 2)
        X = X.reshape(X.size(0), X.size(1), self.num_head * self.head_dim)
        return X

    def split_heads(self, X):
        X = X.reshape(X.size(0), X.size(1), self.num_head, self.head_dim)
        X = X.transpose(1, 2)
        return X

class TransformerEncoderLayer(nn.Module):
    def __init__(self,attn_layer,ff_layer,norm_layer,drop_layer):
        super().__init__()
        self.attn_layer = attn_layer
        self.ff_layer = ff_layer
        self.norm_layer = norm_layer
        self.drop_layer = drop_layer
    
    def forward(self,x,mask):
        x = self.drop_layer(self.attn_layer(x,mask)) + x
        x = self.norm_layer(x)
        x = self.drop_layer(self.ff_layer(x)) + x
        x = self.norm_layer(x)
        return x

class CNN_policy(nn.Module):
    def __init__(self, square_size, output_dim, hidden_dim=100,num_head=4,num_layer=3):
        super().__init__()
        self.square_size = square_size
        self.emb_x = nn.Embedding(square_size,hidden_dim)
        self.emb_y = nn.Embedding(square_size,hidden_dim)
        self.emb = nn.Linear(1,hidden_dim)
        self.encoders = ModuleList([TransformerEncoderLayer(SelfAttention(hidden_dim, int(hidden_dim/num_head), num_head),\
                            nn.Sequential(nn.Linear(hidden_dim,hidden_dim*2),nn.ReLU(),nn.Linear(hidden_dim*2,hidden_dim)),\
                                LayerNorm(hidden_dim, eps=1e-5),\
                                    nn.Dropout(0.1)) for i in range(num_layer)])
        self.output = nn.Linear(hidden_dim,output_dim)

    def forward(self,z):
        r = torch.arange(self.square_size).to(z.device)
        x_pos = torch.cat([r.unsqueeze(-1)]*self.square_size,dim=-1).reshape(-1).unsqueeze(0)
        y_pos = torch.cat([r.unsqueeze(0)]*self.square_size,dim=0).reshape(-1).unsqueeze(0)
        x_emb = self.emb_x(x_pos)
        y_emb = self.emb_y(y_pos)
        z = self.emb(z.unsqueeze(-1))
        z = z + x_emb + y_emb
        mask = torch.zeros(z.size(0),z.size(1),z.size(1)).float().to(z.device)
        for i in range(len(self.encoders)):
            z = self.encoders[i](z,mask)
        z = self.output(z.mean(1))
        return z


class CNN_value(nn.Module):
    def __init__(self, square_size, hidden_dim=100,num_head=4,num_layer=3):
        super().__init__()
        self.square_size = square_size
        self.emb_x = nn.Embedding(square_size,hidden_dim)
        self.emb_y = nn.Embedding(square_size,hidden_dim)
        self.emb = nn.Linear(1,hidden_dim)
        self.encoders = ModuleList([TransformerEncoderLayer(SelfAttention(hidden_dim, int(hidden_dim/num_head), num_head),\
                            nn.Sequential(nn.Linear(hidden_dim,hidden_dim*2),nn.ReLU(),nn.Linear(hidden_dim*2,hidden_dim)),\
                                LayerNorm(hidden_dim, eps=1e-5),\
                                    nn.Dropout(0.1)) for i in range(num_layer)])
        self.output = nn.Linear(hidden_dim,1)

    def forward(self,z):
        r = torch.arange(self.square_size).to(z.device)
        x_pos = torch.cat([r.unsqueeze(-1)]*self.square_size, dim=-1).reshape(-1).unsqueeze(0)
        y_pos = torch.cat([r.unsqueeze(0)]*self.square_size, dim=0).reshape(-1).unsqueeze(0)
        x_emb = self.emb_x(x_pos)
        y_emb = self.emb_y(y_pos)
        z = self.emb(z.unsqueeze(-1))
        z = z + x_emb + y_emb
        mask = torch.zeros(z.size(0),z.size(1),z.size(1)).float().to(z.device)
        for i in range(len(self.encoders)):
            z = self.encoders[i](z,mask)
        z = self.output(z.mean(1))
        return z
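
To see the row/column position indices that feed emb_x and emb_y, here is a small check for square_size=3 (my own snippet, reproducing the index construction from forward):

import torch

square_size = 3
r = torch.arange(square_size)
x_pos = torch.cat([r.unsqueeze(-1)] * square_size, dim=-1).reshape(-1)
y_pos = torch.cat([r.unsqueeze(0)] * square_size, dim=0).reshape(-1)
print(x_pos.tolist())  # [0, 0, 0, 1, 1, 1, 2, 2, 2] -> row index of each of the 9 cells
print(y_pos.tolist())  # [0, 1, 2, 0, 1, 2, 0, 1, 2] -> column index of each cell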

  • The value-estimation model and the policy model
  • It's long, but the policy maps a board state to an action and its probability, while the value network maps a board state to a value estimate (a short usage example follows the code).
class PolicyNetwork(torch.nn.Module):
    def __init__(self, input_dim, output_dim):
        super(PolicyNetwork, self).__init__()
        self.mlp = CNN_policy(int(input_dim**0.5), output_dim)

    def forward(self, x, out_softmax=True):

        x = self.mlp(x)

        if out_softmax:
            return F.softmax(x, dim=-1)
        else:
            return x

    def sample_action(self, state, noise=False):

        if not isinstance(state, torch.Tensor):
            state = torch.from_numpy(state).float().to(device)

        if len(state.size()) == 1:
            state = state.unsqueeze(0)

        with torch.no_grad():
            x_logit = self(state, out_softmax=False).squeeze()
        if noise: x_logit = x_logit + torch.randn(x_logit.shape).to(x_logit.device)*1.0
        x_logit = x_logit.masked_fill(state.squeeze()!=0, float("-inf")) # pieces can only go on empty cells
        x = F.softmax(x_logit,dim=-1)
        dist = Categorical(x)
        
        action = dist.sample()
        log_probability = dist.log_prob(action)

        return action.item(), log_probability.item()

    def best_action(self, state):

        if not isinstance(state, torch.Tensor):
            state = torch.from_numpy(state).float().to(device)

        if len(state.size()) == 1:
            state = state.unsqueeze(0)

        with torch.no_grad():
            x_logit = self(state, out_softmax=False).squeeze()
        x_logit = x_logit.masked_fill(state.squeeze()!=0, float("-inf")) # pieces can only go on empty cells
        x = F.softmax(x_logit, dim=-1)
        action = torch.argmax(x)

        return action.item()

    def evaluate_actions(self, states, actions):
        
        x = self(states)

        dist = Categorical(x)
        entropy = dist.entropy()
        log_probabilities = dist.log_prob(actions)

        return log_probabilities, entropy


class ValueNetwork(torch.nn.Module):
    def __init__(self,input_dim):
        super(ValueNetwork, self).__init__()
        self.mlp = CNN_value(int(input_dim**0.5))

    def forward(self, x):
        x = self.mlp(x)
        return x.squeeze(1)

    def state_value(self, state):

        if not isinstance(state, torch.Tensor):
            state = torch.from_numpy(state).float().to(device)

        if len(state.size()) == 1:
            state = state.unsqueeze(0)
        
        with torch.no_grad():
            x = self(state).squeeze()

        return x.item()
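
A short usage example (my own addition, assuming the classes above and device are defined): sample an action for an empty board and query its value. The mask in sample_action guarantees the chosen cell is empty.

import numpy as np

policy = PolicyNetwork(input_dim=9, output_dim=9).to(device)
value = ValueNetwork(input_dim=9).to(device)

board = np.zeros(9, dtype=np.float32)           # empty 3x3 board, flattened
action, log_prob = policy.sample_action(board)  # index 0..8 of an empty cell and its log-probability
print(action, log_prob)
print(value.state_value(board))                 # scalar value estimate of the position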

  • Training the value and policy models (adapted from here; a small numeric check of the PPO clipping follows the code)
def train_value_network(value_model, value_optimizer, data_loader, epochs=4):

    epochs_losses = []
    for i in range(epochs):

        losses = []
        for observations, _, _, _, rewards_to_go in data_loader:

            observations = observations.float().to(device)
            rewards_to_go = rewards_to_go.float().to(device)

            value_optimizer.zero_grad()
            values = value_model(observations)
            loss = F.mse_loss(values, rewards_to_go)
            loss.backward()
            value_optimizer.step()

            losses.append(loss.item())

        mean_loss = np.mean(losses)
        epochs_losses.append(mean_loss)

    return epochs_losses


def ac_loss(new_log_probabilities, old_log_probabilities, advantages, epsilon_clip=0.2):

    probability_ratios = torch.exp(new_log_probabilities - old_log_probabilities)
    clipped_probability_ratios = torch.clamp(probability_ratios, 1 - epsilon_clip, 1 + epsilon_clip)

    surrogate_1 = probability_ratios * advantages
    surrogate_2 = clipped_probability_ratios * advantages

    return -torch.min(surrogate_1, surrogate_2)


def train_policy_network(policy_model, policy_optimizer, data_loader, epochs=4, clip=0.2):
    
    epochs_losses = []
    c1 = 0.01
    for i in range(epochs):

        losses = []
        for observations, actions, advantages, log_probabilities, _ in data_loader:
            observations = observations.float().to(device)
            actions = actions.long().to(device)
            advantages = advantages.float().to(device)
            old_log_probabilities = log_probabilities.float().to(device)

            policy_optimizer.zero_grad()
            new_log_probabilities, entropy = policy_model.evaluate_actions(observations, actions)
            loss = (ac_loss(new_log_probabilities, old_log_probabilities, advantages, epsilon_clip=clip).mean() - c1 * entropy.mean())
            loss.backward()
            policy_optimizer.step()

            losses.append(loss.item())

        mean_loss = np.mean(losses)
        epochs_losses.append(mean_loss)

    return epochs_losses
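
To make the PPO clipping concrete, a tiny numeric check (my own): with a positive advantage and a probability ratio of 1.8, the clipped surrogate caps the objective at 1 + epsilon_clip = 1.2.

import torch

old_lp = torch.log(torch.tensor([0.5]))   # old pi(a|s) = 0.5
new_lp = torch.log(torch.tensor([0.9]))   # new pi(a|s) = 0.9 -> ratio = 1.8
adv = torch.tensor([1.0])
print(ac_loss(new_lp, old_lp, adv))       # tensor([-1.2000]): ratio clipped to 1.2, then negated for minimization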

  • A class for accumulating episodes, a dataset class for training, and related helpers (also adapted; a short worked example follows the code)
def cumulative_sum(array, gamma=1.0):
    curr = 0
    cumulative_array = []

    for a in array[::-1]:
        curr = a + gamma * curr
        cumulative_array.append(curr)

    return cumulative_array[::-1]

class Episode:
    def __init__(self, gamma=0.99, lambd=0.95):
        self.observations = []
        self.actions = []
        self.advantages = []
        self.rewards = []
        self.rewards_to_go = []
        self.values = []
        self.log_probabilities = []
        self.gamma = gamma
        self.lambd = lambd

    def append(self, observation, action, reward, value, log_probability, reward_scale=20):
        self.observations.append(observation)
        self.actions.append(action)
        self.rewards.append(reward / reward_scale)
        self.values.append(value)
        self.log_probabilities.append(log_probability)

    def end_episode(self, last_value):
        rewards = np.array(self.rewards + [last_value])
        values = np.array(self.values + [last_value])

        deltas = rewards[:-1] + self.gamma * values[1:] - values[:-1]
        self.advantages = cumulative_sum(deltas.tolist(), gamma=self.gamma * self.lambd)
        self.rewards_to_go = cumulative_sum(rewards.tolist(), gamma=self.gamma)[:-1]

def normalize_list(array):
    array = np.array(array)
    array = (array - np.mean(array)) / (np.std(array) + 1e-5)
    return array.tolist()

class History(Dataset):
    def __init__(self):
        self.episodes = []
        self.observations = []
        self.actions = []
        self.advantages = []
        self.rewards = []
        self.rewards_to_go = []
        self.log_probabilities = []

    def free_memory(self):
        del self.episodes[:]
        del self.observations[:]
        del self.actions[:]
        del self.advantages[:]
        del self.rewards[:]
        del self.rewards_to_go[:]
        del self.log_probabilities[:]

    def add_episode(self, episode):
        self.episodes.append(episode)

    def build_dataset(self):
        for episode in self.episodes:
            self.observations += episode.observations
            self.actions += episode.actions
            self.advantages += episode.advantages
            self.rewards += episode.rewards
            self.rewards_to_go += episode.rewards_to_go
            self.log_probabilities += episode.log_probabilities

        # all of the lists must have the same length
        assert len({len(self.observations), len(self.actions), len(self.advantages),
                    len(self.rewards), len(self.rewards_to_go), len(self.log_probabilities)}) == 1
        self.advantages = normalize_list(self.advantages)

    def __len__(self):
        return len(self.observations)

    def __getitem__(self, idx):
        return (
            self.observations[idx],
            self.actions[idx],
            self.advantages[idx],
            self.log_probabilities[idx],
            self.rewards_to_go[idx],
        )
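
A small worked example (my own) of the discounted suffix sums that Episode.end_episode relies on:

import numpy as np

print(cumulative_sum([1, 1, 1], gamma=0.5))   # [1.75, 1.5, 1]: each entry is r_t + gamma * (sum of what follows)

ep = Episode(gamma=0.99, lambd=0.95)
ep.append(observation=np.zeros(9), action=4, reward=0,   value=0.0, log_probability=0.0)
ep.append(observation=np.zeros(9), action=0, reward=100, value=0.0, log_probability=0.0)  # a win, divided by reward_scale=20
ep.end_episode(last_value=0)
print(ep.rewards_to_go)   # roughly [4.95, 5.0]: discounted returns of the scaled rewards
print(ep.advantages)      # GAE(lambda)-style advantages built from the value-function deltas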

  • Playing against the previous agent (50 games going first and 50 going second)
def self_play(env, now_model, old_model, state_scale):
    num_play = 50
    results = []

    for i in range(num_play):
        observation = env.reset()

        while True:
            action, _ = now_model.sample_action(observation/state_scale)
            new_observation, _, done, _ = env.step(action)
            observation = new_observation
            
            if done == 2:
                results.append(0)
                break

            if done == 1:
                results.append(1)
                break
            
            action, _ = old_model.sample_action(observation/state_scale)
            new_observation, _, done, _ = env.step(action)
            observation = new_observation

            if done == 2:
                results.append(0)
                break

            if done == 1:
                results.append(-1)
                break

    for i in range(num_play):
        observation = env.reset()

        while True:
            action, _ = old_model.sample_action(observation / state_scale)
            new_observation, _, done, _ = env.step(action)
            observation = new_observation

            if done == 2:
                results.append(0)
                break

            if done == 1:
                results.append(-1)
                break
            
            action, _ = now_model.sample_action(observation / state_scale)
            new_observation, _, done, _ = env.step(action)
            observation = new_observation

            if done == 2:
                results.append(0)
                break

            if done == 1:
                results.append(1)
                break

    print("win:", results.count(1), "lose", results.count(-1), "draw", results.count(0))
    return sum(results) > 0

  • The tic-tac-toe environment
  • Regardless of which side is to move, the acting side's pieces are always represented by "1" (a short usage example follows the code).
class Environment():
    def __init__(self, device, square_size=3, moku=3):
        self.square_size = square_size
        self.device = device
        self.moku = moku
        self.board = torch.zeros(self.square_size**2).to(self.device)
        self.action_n = square_size**2

        self.win_reward = 100
        self.draw_reward = 20

    def reset(self):
        self.board = torch.zeros(self.square_size**2).to(self.device)
        return self.board.cpu().numpy()

    def step(self, action :int):

        self.board += F.one_hot(torch.tensor(action).to(self.device),num_classes=self.action_n)

        reward = 0
        complete = 0  # 0: game continues, 1: the mover wins, 2: draw

        for i in range(self.square_size-self.moku+1):
            for j in range(self.square_size-self.moku+1):
                ker = self.board.reshape(self.square_size, self.square_size)[i:i+self.moku,j:j+self.moku]
                ker = ker*((ker!=-1).float()).long()
                if torch.sum(ker.reshape(self.moku, self.moku).diagonal())==self.moku or\
                    torch.sum(ker.reshape(self.moku, self.moku).fliplr().diagonal())==self.moku or\
                    torch.any(torch.sum(ker.reshape(self.moku, self.moku), dim=0)==self.moku) or\
                    torch.any(torch.sum(ker.reshape(self.moku, self.moku), dim=1)==self.moku):
                    reward = self.win_reward
                    complete = 1

        if torch.all(self.board!=0) and complete == 0:
            reward = self.draw_reward
            complete = 2

        self.board = -1 * self.board + 0  # flip the sign so the next player's pieces are always +1
        return self.board.cpu().numpy(), reward, complete, None
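
A short interaction (my own example) to show the perspective flip: the board returned by step() is already from the next player's point of view, so the piece that was just placed shows up as -1.

env = Environment(torch.device("cpu"), square_size=3, moku=3)
obs = env.reset()
obs, reward, done, _ = env.step(4)   # the current player takes the centre cell
print(obs.reshape(3, 3))
# [[ 0.  0.  0.]
#  [ 0. -1.  0.]
#  [ 0.  0.  0.]]  -> the centre is -1: it belongs to the opponent of whoever moves next
print(reward, done)                  # 0 0: no win yet (done is 1 when the mover wins, 2 on a draw)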

  • Main loop
def main(reward_scale, clip, learning_rate, state_scale):

    env = Environment(device, square_size=3, moku=3) # board size and how many in a row wins
    observation = env.reset()

    n_actions = env.action_n
    feature_dim = observation.size

    value_model = ValueNetwork(input_dim=feature_dim).to(device)
    value_optimizer = optim.Adam(value_model.parameters(), lr=learning_rate)
    policy_model = PolicyNetwork(input_dim=feature_dim, output_dim=n_actions).to(device)
    policy_optimizer = optim.Adam(policy_model.parameters(), lr=learning_rate)

    policy_model_old, value_model_old = deepcopy(policy_model), deepcopy(value_model)

    n_epoch = 1
    max_episodes = 100
    batch_size = 10000
    max_iterations = 20000

    history = History()

    epoch_ite = 0
    episode_ite = 0

    for ite in tqdm(range(max_iterations), desc="[iteration]"):

        if ite % 10 == 0 and ite > 0:
            torch.save(policy_model.state_dict(),"now_policy.pth")
            if self_play(env, policy_model, policy_model_old, state_scale=state_scale):
                print("win step up!")
                policy_model_old,value_model_old = deepcopy(policy_model),deepcopy(value_model)
                torch.save(policy_model.state_dict(),"best_policy.pth")

        for episode_i in (range(max_episodes)):
            observation = env.reset()
            if random.random() > 0.5:
                enemy_action, _ = policy_model.sample_action(observation/state_scale, noise=True)
                new_observation, _, _, _ = env.step(enemy_action)
                observation = new_observation

            episode = Episode()
            while True:
                action, log_probability = policy_model.sample_action(observation/state_scale, noise=True)
                value = value_model.state_value(observation / state_scale)
                new_observation, reward, done, _ = env.step(action)

                episode.append(
                    observation=observation / state_scale,
                    action=action,
                    reward=reward,
                    value=value,
                    log_probability=log_probability,
                    reward_scale=reward_scale,
                )

                if done:
                    episode.end_episode(last_value=0)
                    break

                observation = new_observation

                enemy_action, _ = policy_model.sample_action(observation/state_scale, noise=True)

                new_observation, reward, done, _ = env.step(enemy_action)
                observation = new_observation

                if done==2:
                    episode.rewards[-1] = reward*0.99 / reward_scale
                    episode.end_episode(last_value=0)
                    break

                if done==1:
                    episode.rewards[-1] = -reward*0.99 / reward_scale
                    episode.end_episode(last_value=0)
                    break

            episode_ite += 1
            history.add_episode(episode)

        history.build_dataset()
        data_loader = DataLoader(history, batch_size=batch_size, shuffle=True)

        policy_loss = train_policy_network(policy_model, policy_optimizer, data_loader, epochs=n_epoch, clip=clip)
        value_loss = train_value_network(value_model, value_optimizer, data_loader, epochs=n_epoch)
        history.free_memory()

if __name__ == "__main__":
    main(reward_scale=20.0, clip=0.2, learning_rate=0.001, state_scale=1.0)

  • Playing against the trained model (the agent moves first; our pieces are O). A console alternative is sketched after the GUI code.
play_gui.py
# assumes the definitions above (torch, PolicyNetwork, Environment, device) are importable or in the same file
import time
import tkinter

board_size = 3
moku = 3

policy = PolicyNetwork(board_size**2, board_size**2).to(device)
policy.load_state_dict(torch.load("best_policy.pth"))
policy.eval()

env = Environment(device, square_size=board_size, moku=moku)

action_enemy_first = policy.best_action(env.board.cpu().numpy())
board, reward, done, _ = env.step(action_enemy_first)

def step(_i,_j):
    action_me = board_size*_i+_j
    board, reward, done, _ = env.step(action_me)
    buttons[_i*board_size+_j].config(text="O")
    if done:
        root.update()
        time.sleep(.5)
        exit(0)
    
    action_enemy = policy.best_action(env.board.cpu().numpy())
    board, reward, done_, _ = env.step(action_enemy)
    buttons[action_enemy].config(text="X")
    if done_:
        root.update()
        time.sleep(.5)
        exit(0)

root = tkinter.Tk()
root.title('Tkinter')
root.geometry("400x400")

class Click:
    def __init__(self, _i, _j):
        self.i = _i
        self.j = _j

    def click(self):
        step(self.i, self.j)

buttons = []
for i in range(board_size):
    for j in range(board_size):
        if action_enemy_first//board_size==i and action_enemy_first%board_size==j:
            button = tkinter.Button(root, text="X", command=Click(i,j).click, height = 3, width = 6)
        else:
            button = tkinter.Button(root, text="-", command=Click(i,j).click, height = 3, width = 6)
        buttons.append(button)

for i in range(board_size):
    for j in range(board_size):
        buttons[board_size*i+j].grid(row=i, column=j)

root.mainloop()
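
If tkinter is not available, a console version also works. This is my own sketch under the same assumptions as play_gui.py (PolicyNetwork, Environment, device, and best_policy.pth in scope); moves are entered as "row col".

# console_play.py -- hypothetical tkinter-free alternative to play_gui.py
board_size, moku = 3, 3

policy = PolicyNetwork(board_size**2, board_size**2).to(device)
policy.load_state_dict(torch.load("best_policy.pth"))
policy.eval()

env = Environment(device, square_size=board_size, moku=moku)
env.reset()
board, _, done, _ = env.step(policy.best_action(env.board.cpu().numpy()))  # the agent (X) moves first

while not done:
    print(board.reshape(board_size, board_size))      # our pieces are 1, the agent's are -1
    i, j = map(int, input("row col (0-2): ").split())
    board, _, done, _ = env.step(board_size * i + j)  # our move (O)
    if done:
        break
    board, _, done, _ = env.step(policy.best_action(env.board.cpu().numpy()))
print("game over")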

Results

Playing casually

I got crushed instantly.

Playing seriously

No matter what I try, it ends in a draw.

Summary

  • We trained a tic-tac-toe agent with self-play and PPO.
  • It seems to be working well, but the source code is still being verified.

In closing

If you spot any mistakes, please let me know gently in the comments.
