Overview
In this post I try reinforcement learning through self-play. PPO (Proximal Policy Optimization) is used as the learning algorithm.
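For reference, PPO maximizes the clipped surrogate objective below; the ac_loss function later in this post implements exactly this (with a small entropy bonus added during training).

```math
L^{\mathrm{CLIP}}(\theta) = \mathbb{E}_t\!\left[\min\!\left(r_t(\theta)\,\hat{A}_t,\ \operatorname{clip}\!\left(r_t(\theta),\,1-\epsilon,\,1+\epsilon\right)\hat{A}_t\right)\right],
\qquad r_t(\theta) = \frac{\pi_\theta(a_t \mid s_t)}{\pi_{\theta_{\mathrm{old}}}(a_t \mid s_t)}
```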
Tic-tac-toe
This is the familiar noughts-and-crosses game on a 3×3 board.
Code
- Imports
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.distributions import Categorical
from torch.utils.data import Dataset, DataLoader
torch.manual_seed(41)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
import numpy as np
import random
random.seed(41)
from copy import deepcopy
from tqdm import tqdm
- Neural network models (one for the policy, one for the value)
- The two networks sometimes share part of their weights, but we skip that here (a rough sketch of such a variant follows the code below).
class CNN_policy(nn.Module):
def __init__(self, square_size, output_dim, hidden_dim=100):
super().__init__()
self.square_size = square_size
self.conv1 = nn.Conv2d(1, hidden_dim, kernel_size=3, padding=1)
self.conv2 = nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, padding=1)
self.output = nn.Linear((square_size**2)*hidden_dim, output_dim)
def forward(self,x):
x = x.reshape(-1, self.square_size, self.square_size).unsqueeze(1).float()
x = F.relu(self.conv1(x)) + x
x = F.relu(self.conv2(x)) + x
x = self.output(x.reshape(x.size(0),-1))
return x
class CNN_value(nn.Module):
def __init__(self, square_size, hidden_dim=100):
super().__init__()
self.square_size = square_size
self.conv1 = nn.Conv2d(1, hidden_dim, kernel_size=3, padding=1)
self.conv2 = nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, padding=1)
self.output = nn.Linear((square_size**2)*hidden_dim, 1)
def forward(self,x):
x = x.reshape(-1, self.square_size, self.square_size).unsqueeze(1).float()
x = F.relu(self.conv1(x)) + x
x = F.relu(self.conv2(x)) + x
x = self.output(x.reshape(x.size(0),-1))
return x
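As an aside, the note above mentions that the policy and value networks could share weights. Purely as an illustration (this is not used anywhere in this post), a shared-trunk variant might look roughly like this:

```python
# Hypothetical sketch only: one convolutional trunk feeding both a policy head and a value head.
class SharedTrunkNet(nn.Module):
    def __init__(self, square_size, output_dim, hidden_dim=100):
        super().__init__()
        self.square_size = square_size
        self.conv1 = nn.Conv2d(1, hidden_dim, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, padding=1)
        self.policy_head = nn.Linear((square_size**2) * hidden_dim, output_dim)
        self.value_head = nn.Linear((square_size**2) * hidden_dim, 1)
    def forward(self, x):
        x = x.reshape(-1, self.square_size, self.square_size).unsqueeze(1).float()
        x = F.relu(self.conv1(x)) + x
        x = F.relu(self.conv2(x)) + x
        x = x.reshape(x.size(0), -1)
        return self.policy_head(x), self.value_head(x)  # action logits and state value
```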
- Transformer version (it reuses the class names CNN_policy and CNN_value so it can be swapped in for the CNN version above)
- Row and column positions each get their own embedding, and the two are added together (a small check of the position indices follows the code).
import math
from torch.nn.modules.container import ModuleList
from torch.nn.modules.normalization import LayerNorm
class SoftmaxAttention(nn.Module):
def __init__(self, head_dim):
super().__init__()
self.head_dim = head_dim
def forward(self, Q, K, V, mask=None):
logit = torch.einsum("bhld,bhmd->bhlm",Q,K)/math.sqrt(self.head_dim)
        if mask is not None:
logit = logit + mask[:,None,:,:]
attn = F.softmax(logit, dim=-1)
X = torch.einsum("bhlm,bhmd->bhld",attn,V)
return X
class SelfAttention(nn.Module):
def __init__(self, dim, head_dim, num_head):
super().__init__()
self.dim = dim
self.head_dim = head_dim
self.num_head = num_head
        assert dim == head_dim * num_head, "dim must equal head_dim * num_head"
self.W_q = nn.Linear(self.dim, self.dim)
self.W_k = nn.Linear(self.dim, self.dim)
self.W_v = nn.Linear(self.dim, self.dim)
self.attn = SoftmaxAttention(head_dim)
def forward(self, X, sq_mask):
Q = self.split_heads(self.W_q(X))
K = self.split_heads(self.W_k(X))
V = self.split_heads(self.W_v(X))
attn_out = self.attn(Q.float(), K.float(), V.float(), sq_mask.float())
attn_out = self.combine_heads(attn_out)
return attn_out
def combine_heads(self, X):
X = X.transpose(1, 2)
X = X.reshape(X.size(0), X.size(1), self.num_head * self.head_dim)
return X
def split_heads(self, X):
X = X.reshape(X.size(0), X.size(1), self.num_head, self.head_dim)
X = X.transpose(1, 2)
return X
class TransformerEncoderLayer(nn.Module):
def __init__(self,attn_layer,ff_layer,norm_layer,drop_layer):
super().__init__()
self.attn_layer = attn_layer
self.ff_layer = ff_layer
self.norm_layer = norm_layer
self.drop_layer = drop_layer
def forward(self,x,mask):
x = self.drop_layer(self.attn_layer(x,mask)) + x
x = self.norm_layer(x)
x = self.drop_layer(self.ff_layer(x)) + x
x = self.norm_layer(x)
return x
class CNN_policy(nn.Module):
def __init__(self, square_size, output_dim, hidden_dim=100,num_head=4,num_layer=3):
super().__init__()
self.square_size = square_size
self.emb_x = nn.Embedding(square_size,hidden_dim)
self.emb_y = nn.Embedding(square_size,hidden_dim)
self.emb = nn.Linear(1,hidden_dim)
self.encoders = ModuleList([TransformerEncoderLayer(SelfAttention(hidden_dim, int(hidden_dim/num_head), num_head),\
nn.Sequential(nn.Linear(hidden_dim,hidden_dim*2),nn.ReLU(),nn.Linear(hidden_dim*2,hidden_dim)),\
LayerNorm(hidden_dim, eps=1e-5),\
nn.Dropout(0.1)) for i in range(num_layer)])
self.output = nn.Linear(hidden_dim,output_dim)
def forward(self,z):
r = torch.arange(self.square_size).to(z.device)
x_pos = torch.cat([r.unsqueeze(-1)]*self.square_size,dim=-1).reshape(-1).unsqueeze(0)
y_pos = torch.cat([r.unsqueeze(0)]*self.square_size,dim=0).reshape(-1).unsqueeze(0)
x_emb = self.emb_x(x_pos)
y_emb = self.emb_y(y_pos)
z = self.emb(z.unsqueeze(-1))
z = z + x_emb + y_emb
mask = torch.zeros(z.size(0),z.size(1),z.size(1)).float().to(z.device)
for i in range(len(self.encoders)):
z = self.encoders[i](z,mask)
z = self.output(z.mean(1))
return z
class CNN_value(nn.Module):
def __init__(self, square_size, hidden_dim=100,num_head=4,num_layer=3):
super().__init__()
self.square_size = square_size
self.emb_x = nn.Embedding(square_size,hidden_dim)
self.emb_y = nn.Embedding(square_size,hidden_dim)
self.emb = nn.Linear(1,hidden_dim)
self.encoders = ModuleList([TransformerEncoderLayer(SelfAttention(hidden_dim, int(hidden_dim/num_head), num_head),\
nn.Sequential(nn.Linear(hidden_dim,hidden_dim*2),nn.ReLU(),nn.Linear(hidden_dim*2,hidden_dim)),\
LayerNorm(hidden_dim, eps=1e-5),\
nn.Dropout(0.1)) for i in range(num_layer)])
self.output = nn.Linear(hidden_dim,1)
def forward(self,z):
r = torch.arange(self.square_size).to(z.device)
x_pos = torch.cat([r.unsqueeze(-1)]*self.square_size, dim=-1).reshape(-1).unsqueeze(0)
y_pos = torch.cat([r.unsqueeze(0)]*self.square_size, dim=0).reshape(-1).unsqueeze(0)
x_emb = self.emb_x(x_pos)
y_emb = self.emb_y(y_pos)
z = self.emb(z.unsqueeze(-1))
z = z + x_emb + y_emb
mask = torch.zeros(z.size(0),z.size(1),z.size(1)).float().to(z.device)
for i in range(len(self.encoders)):
z = self.encoders[i](z,mask)
z = self.output(z.mean(1))
return z
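To make the position encoding concrete, this is what x_pos and y_pos evaluate to for square_size=3 (a quick check using the same construction as in forward):

```python
import torch

r = torch.arange(3)
x_pos = torch.cat([r.unsqueeze(-1)] * 3, dim=-1).reshape(-1).unsqueeze(0)
y_pos = torch.cat([r.unsqueeze(0)] * 3, dim=0).reshape(-1).unsqueeze(0)
print(x_pos)  # tensor([[0, 0, 0, 1, 1, 1, 2, 2, 2]]) -> row index of each of the 9 cells
print(y_pos)  # tensor([[0, 1, 2, 0, 1, 2, 0, 1, 2]]) -> column index of each of the 9 cells
```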
- The value-estimation model and the policy model
- It is a bit long, but the policy maps a state to an action and its probability, while the value network maps a state to a scalar value (a short usage check follows the code).
class PolicyNetwork(torch.nn.Module):
def __init__(self, input_dim, output_dim):
super(PolicyNetwork, self).__init__()
self.mlp = CNN_policy(int(input_dim**0.5), output_dim)
def forward(self, x, out_softmax=True):
x = self.mlp(x)
if out_softmax:
return F.softmax(x, dim=-1)
else:
return x
    def sample_action(self, state, noise=False):
        if not isinstance(state, torch.Tensor):
            state = torch.from_numpy(state).float().to(device)
        if len(state.size()) == 1:
            state = state.unsqueeze(0)
        with torch.no_grad():
            x_logit = self(state, out_softmax=False).squeeze()
            if noise:
                x_logit = x_logit + torch.randn(x_logit.shape).to(x_logit.device) * 1.0
            x_logit = x_logit.masked_fill(state.squeeze() != 0, float("-inf"))  # only empty squares are legal
            x = F.softmax(x_logit, dim=-1)
            dist = Categorical(x)
            action = dist.sample()
            log_probability = dist.log_prob(action)
        return action.item(), log_probability.item()
    def best_action(self, state):
        if not isinstance(state, torch.Tensor):
            state = torch.from_numpy(state).float().to(device)
        if len(state.size()) == 1:
            state = state.unsqueeze(0)
        with torch.no_grad():
            x_logit = self(state, out_softmax=False).squeeze()
            x_logit = x_logit.masked_fill(state.squeeze() != 0, float("-inf"))  # only empty squares are legal
            x = F.softmax(x_logit, dim=-1)
            action = torch.argmax(x)
        return action.item()
def evaluate_actions(self, states, actions):
x = self(states)
dist = Categorical(x)
entropy = dist.entropy()
log_probabilities = dist.log_prob(actions)
return log_probabilities, entropy
class ValueNetwork(torch.nn.Module):
def __init__(self,input_dim):
super(ValueNetwork, self).__init__()
self.mlp = CNN_value(int(input_dim**0.5))
def forward(self, x):
x = self.mlp(x)
return x.squeeze(1)
    def state_value(self, state):
        if not isinstance(state, torch.Tensor):
            state = torch.from_numpy(state).float().to(device)
        if len(state.size()) == 1:
            state = state.unsqueeze(0)
        with torch.no_grad():
            x = self(state).squeeze()
        return x.item()
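A quick check of the interface (illustration only; the flattened 3×3 board is a 9-dimensional vector):

```python
policy = PolicyNetwork(input_dim=9, output_dim=9).to(device)
value = ValueNetwork(input_dim=9).to(device)
board = np.zeros(9, dtype=np.float32)            # empty board
action, log_prob = policy.sample_action(board)   # only empty squares can be sampled
print(action, log_prob, value.state_value(board))
```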
- Training the value and policy models (taken from here; a small sanity check of the clipped loss follows the code)
def train_value_network(value_model, value_optimizer, data_loader, epochs=4):
epochs_losses = []
for i in range(epochs):
losses = []
for observations, _, _, _, rewards_to_go in data_loader:
observations = observations.float().to(device)
rewards_to_go = rewards_to_go.float().to(device)
value_optimizer.zero_grad()
values = value_model(observations)
loss = F.mse_loss(values, rewards_to_go)
loss.backward()
value_optimizer.step()
losses.append(loss.item())
mean_loss = np.mean(losses)
epochs_losses.append(mean_loss)
return epochs_losses
def ac_loss(new_log_probabilities, old_log_probabilities, advantages, epsilon_clip=0.2):
probability_ratios = torch.exp(new_log_probabilities - old_log_probabilities)
clipped_probabiliy_ratios = torch.clamp(probability_ratios, 1 - epsilon_clip, 1 + epsilon_clip)
surrogate_1 = probability_ratios * advantages
surrogate_2 = clipped_probabiliy_ratios * advantages
return -torch.min(surrogate_1, surrogate_2)
def train_policy_network(policy_model, policy_optimizer, data_loader, epochs=4, clip=0.2):
epochs_losses = []
c1 = 0.01
for i in range(epochs):
losses = []
for observations, actions, advantages, log_probabilities, _ in data_loader:
observations = observations.float().to(device)
actions = actions.long().to(device)
advantages = advantages.float().to(device)
old_log_probabilities = log_probabilities.float().to(device)
policy_optimizer.zero_grad()
new_log_probabilities, entropy = policy_model.evaluate_actions(observations, actions)
loss = (ac_loss(new_log_probabilities, old_log_probabilities, advantages, epsilon_clip=clip).mean() - c1 * entropy.mean())
loss.backward()
policy_optimizer.step()
losses.append(loss.item())
mean_loss = np.mean(losses)
epochs_losses.append(mean_loss)
return epochs_losses
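A quick sanity check of the clipping behaviour in ac_loss (numbers chosen purely for illustration):

```python
# The probability ratio 0.6 / 0.4 = 1.5 is clipped to 1 + 0.2 = 1.2, so with advantage +1
# the (negated) surrogate loss is about -1.2.
new_lp = torch.log(torch.tensor([0.6]))
old_lp = torch.log(torch.tensor([0.4]))
advantage = torch.tensor([1.0])
print(ac_loss(new_lp, old_lp, advantage, epsilon_clip=0.2))  # ~ tensor([-1.2000])
```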
- A class that accumulates episodes, a dataset class for training, and related helpers (also taken from the same source; the quantities they compute are summarized after the code)
def cumulative_sum(array, gamma=1.0):
curr = 0
cumulative_array = []
for a in array[::-1]:
curr = a + gamma * curr
cumulative_array.append(curr)
return cumulative_array[::-1]
class Episode:
def __init__(self, gamma=0.99, lambd=0.95):
self.observations = []
self.actions = []
self.advantages = []
self.rewards = []
self.rewards_to_go = []
self.values = []
self.log_probabilities = []
self.gamma = gamma
self.lambd = lambd
def append(self, observation, action, reward, value, log_probability, reward_scale=20):
self.observations.append(observation)
self.actions.append(action)
self.rewards.append(reward / reward_scale)
self.values.append(value)
self.log_probabilities.append(log_probability)
def end_episode(self, last_value):
rewards = np.array(self.rewards + [last_value])
values = np.array(self.values + [last_value])
deltas = rewards[:-1] + self.gamma * values[1:] - values[:-1]
self.advantages = cumulative_sum(deltas.tolist(), gamma=self.gamma * self.lambd)
self.rewards_to_go = cumulative_sum(rewards.tolist(), gamma=self.gamma)[:-1]
def normalize_list(array):
array = np.array(array)
array = (array - np.mean(array)) / (np.std(array) + 1e-5)
return array.tolist()
class History(Dataset):
def __init__(self):
self.episodes = []
self.observations = []
self.actions = []
self.advantages = []
self.rewards = []
self.rewards_to_go = []
self.log_probabilities = []
def free_memory(self):
del self.episodes[:]
del self.observations[:]
del self.actions[:]
del self.advantages[:]
del self.rewards[:]
del self.rewards_to_go[:]
del self.log_probabilities[:]
def add_episode(self, episode):
self.episodes.append(episode)
def build_dataset(self):
for episode in self.episodes:
self.observations += episode.observations
self.actions += episode.actions
self.advantages += episode.advantages
self.rewards += episode.rewards
self.rewards_to_go += episode.rewards_to_go
self.log_probabilities += episode.log_probabilities
        # all of the lists must have the same length
assert (len({len(self.observations),len(self.actions),len(self.advantages),len(self.rewards),len(self.rewards_to_go),len(self.log_probabilities),}) == 1)
self.advantages = normalize_list(self.advantages)
def __len__(self):
return len(self.observations)
def __getitem__(self, idx):
return (
self.observations[idx],
self.actions[idx],
self.advantages[idx],
self.log_probabilities[idx],
self.rewards_to_go[idx],
)
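For reference, end_episode computes the generalized advantage estimates (GAE) and the discounted rewards-to-go below, appending last_value as a bootstrap term (gamma = 0.99 and lambd = 0.95 in the code above):

```math
\delta_t = r_t + \gamma V(s_{t+1}) - V(s_t),\qquad
\hat{A}_t = \sum_{l \ge 0} (\gamma\lambda)^l\,\delta_{t+l},\qquad
\hat{R}_t = \sum_{l \ge 0} \gamma^l\, r_{t+l}
```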
- Playing against a past version of the agent (50 games going first and 50 going second); the current model is promoted only if it ends up with more wins than losses, i.e. sum(results) > 0
def self_play(env, now_model, old_model, state_scale):
num_play = 50
results = []
for i in range(num_play):
observation = env.reset()
while True:
action, _ = now_model.sample_action(observation/state_scale)
new_observation, _, done, _ = env.step(action)
observation = new_observation
if done == 2:
results.append(0)
break
if done == 1:
results.append(1)
break
action, _ = old_model.sample_action(observation/state_scale)
new_observation, _, done, _ = env.step(action)
observation = new_observation
if done == 2:
results.append(0)
break
if done == 1:
results.append(-1)
break
for i in range(num_play):
observation = env.reset()
while True:
action, _ = old_model.sample_action(observation / state_scale)
new_observation, _, done, _ = env.step(action)
observation = new_observation
if done == 2:
results.append(0)
break
if done == 1:
results.append(-1)
break
action, _ = now_model.sample_action(observation / state_scale)
new_observation, _, done, _ = env.step(action)
observation = new_observation
if done == 2:
results.append(0)
break
if done == 1:
results.append(1)
break
print("win:", results.count(1), "lose", results.count(-1), "draw", results.count(0))
return sum(results) > 0
- The tic-tac-toe environment
- Regardless of which side is to move, the acting player's stones are always "1": the board is negated after every step (a small example follows the code).
class Environment():
def __init__(self, device, square_size=3, moku=3):
self.square_size = square_size
self.device = device
self.moku = moku
self.board = torch.zeros(self.square_size**2).to(self.device)
self.action_n = square_size**2
self.win_reward = 100
self.draw_reward = 20
def reset(self):
self.board = torch.zeros(self.square_size**2).to(self.device)
return self.board.cpu().numpy()
def step(self, action :int):
self.board += F.one_hot(torch.tensor(action).to(self.device),num_classes=self.action_n)
reward = 0
complete = 0
for i in range(self.square_size-self.moku+1):
for j in range(self.square_size-self.moku+1):
ker = self.board.reshape(self.square_size, self.square_size)[i:i+self.moku,j:j+self.moku]
ker = ker*((ker!=-1).float()).long()
if torch.sum(ker.reshape(self.moku, self.moku).diagonal())==self.moku or\
torch.sum(ker.reshape(self.moku, self.moku).fliplr().diagonal())==self.moku or\
torch.any(torch.sum(ker.reshape(self.moku, self.moku), dim=0)==self.moku) or\
torch.any(torch.sum(ker.reshape(self.moku, self.moku), dim=1)==self.moku):
reward = self.win_reward
complete = 1
if torch.all(self.board!=0) and complete == 0:
reward = self.draw_reward
complete = 2
self.board = -1 * self.board + 0
return self.board.cpu().numpy(), reward, complete, None
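A small example of the sign flip (illustration only): after each move the returned board is negated, so the side to move always sees its own stones as +1 and the opponent's as -1.

```python
env = Environment(device, square_size=3, moku=3)
env.reset()
board, reward, done, _ = env.step(4)   # the current player takes the centre square
print(board.reshape(3, 3), reward, done)
# [[ 0.  0.  0.]
#  [ 0. -1.  0.]
#  [ 0.  0.  0.]] 0 0
```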
- Main loop
def main(reward_scale, clip, learning_rate, state_scale):
    env = Environment(device, square_size=3, moku=3)  # board size and how many in a row wins
observation = env.reset()
n_actions = env.action_n
feature_dim = observation.size
value_model = ValueNetwork(input_dim=feature_dim).to(device)
value_optimizer = optim.Adam(value_model.parameters(), lr=learning_rate)
policy_model = PolicyNetwork(input_dim=feature_dim, output_dim=n_actions).to(device)
policy_optimizer = optim.Adam(policy_model.parameters(), lr=learning_rate)
policy_model_old, value_model_old = deepcopy(policy_model), deepcopy(value_model)
n_epoch = 1
max_episodes = 100
batch_size = 10000
max_iterations = 20000
history = History()
epoch_ite = 0
episode_ite = 0
for ite in tqdm(range(max_iterations), desc="[iteration]"):
if ite % 10 == 0 and ite > 0:
torch.save(policy_model.state_dict(),"now_policy.pth")
if self_play(env, policy_model, policy_model_old, state_scale=state_scale):
print("win step up!")
policy_model_old,value_model_old = deepcopy(policy_model),deepcopy(value_model)
torch.save(policy_model.state_dict(),"best_policy.pth")
for episode_i in (range(max_episodes)):
observation = env.reset()
if random.random() > 0.5:
enemy_action, _ = policy_model.sample_action(observation/state_scale, noise=True)
new_observation, _, _, _ = env.step(enemy_action)
observation = new_observation
episode = Episode()
while True:
action, log_probability = policy_model.sample_action(observation/state_scale, noise=True)
value = value_model.state_value(observation / state_scale)
new_observation, reward, done, _ = env.step(action)
episode.append(
observation=observation / state_scale,
action=action,
reward=reward,
value=value,
log_probability=log_probability,
reward_scale=reward_scale,
)
if done:
episode.end_episode(last_value=0)
break
observation = new_observation
enemy_action, _ = policy_model.sample_action(observation/state_scale, noise=True)
new_observation, reward, done, _ = env.step(enemy_action)
observation = new_observation
if done==2:
episode.rewards[-1] = reward*0.99 / reward_scale
episode.end_episode(last_value=0)
break
if done==1:
episode.rewards[-1] = -reward*0.99 / reward_scale
episode.end_episode(last_value=0)
break
episode_ite += 1
history.add_episode(episode)
history.build_dataset()
data_loader = DataLoader(history, batch_size=batch_size, shuffle=True)
policy_loss = train_policy_network(policy_model, policy_optimizer, data_loader, epochs=n_epoch, clip=clip)
value_loss = train_value_network(value_model, value_optimizer, data_loader, epochs=n_epoch)
history.free_memory()
if __name__ == "__main__":
main(reward_scale=20.0, clip=0.2, learning_rate=0.001, state_scale=1.0)
- Playing against the trained model (the agent moves first; our marks are O). The play_gui.py script below assumes the model and environment classes defined above are available in the same session or imported.
play_gui.py
import time
import tkinter
board_size = 3
moku = 3
policy = PolicyNetwork(board_size**2, board_size**2).to(device)
policy.load_state_dict(torch.load("best_policy.pth"))
policy.eval()
env = Environment(device, square_size=board_size, moku=moku)
action_enemy_first = policy.best_action(env.board.cpu().numpy())
board, reward, done, _ = env.step(action_enemy_first)
def step(_i,_j):
action_me = board_size*_i+_j
board, reward, done, _ = env.step(action_me)
buttons[_i*board_size+_j].config(text="O")
if done:
root.update()
time.sleep(.5)
exit(0)
action_enemy = policy.best_action(env.board.cpu().numpy())
board, reward, done_, _ = env.step(action_enemy)
buttons[action_enemy].config(text="X")
if done_:
root.update()
time.sleep(.5)
exit(0)
root = tkinter.Tk()
root.title('Tkinter')
root.geometry("400x400")
class Click:
def __init__(self, _i, _j):
self.i = _i
self.j = _j
def click(self):
step(self.i, self.j)
buttons = []
for i in range(board_size):
for j in range(board_size):
if action_enemy_first//board_size==i and action_enemy_first%board_size==j:
button = tkinter.Button(root, text="X", command=Click(i,j).click, height = 3, width = 6)
else:
button = tkinter.Button(root, text="-", command=Click(i,j).click, height = 3, width = 6)
buttons.append(button)
for i in range(board_size):
for j in range(board_size):
buttons[board_size*i+j].grid(row=i, column=j)
root.mainloop()
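If you would rather skip the GUI, a minimal console loop (a sketch that reuses best_action and Environment from above; input validation omitted) could look like this:

```python
# The agent plays X and moves first; we type a square index 0-8 for O.
# After every step() the board is negated, so the side to move sees its own stones as +1.
policy = PolicyNetwork(9, 9).to(device)
policy.load_state_dict(torch.load("best_policy.pth", map_location=device))
policy.eval()
env = Environment(device, square_size=3, moku=3)
env.reset()
while True:
    board, _, done, _ = env.step(policy.best_action(env.board.cpu().numpy()))
    print(board.reshape(3, 3))
    if done:
        print("draw" if done == 2 else "the agent wins")
        break
    board, _, done, _ = env.step(int(input("your square (0-8): ")))
    print(board.reshape(3, 3))
    if done:
        print("draw" if done == 2 else "you win")
        break
```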
Results
Playing half-heartedly



I was beaten instantly.
Playing seriously





No matter what I try, it always ends in a draw.
Summary
- I trained a tic-tac-toe agent with self-play and PPO.
- It looks like it is working, but I am still in the middle of verifying the source code.
In closing
If anything here is wrong, please let me know gently in the comments.