深層強化学習 PyTorchによる実践プログラミング の6章の内容です。
前回の Deep Q-Netowork の発展版として、Double DQN と Dueling Network の実装が紹介されていました。
Double DQN の結果がこれ。
Dueling Network の結果がこれ。
この課題だともう学習後の上手さの違いなどはよくわからない。上記の本では Dueling Network を使うと少ない試行数でも学習が進むという話が書いてあったけどそのような傾向は見られなかった。
Double DQN
Main Q-Network: 次のステップで最大のQ値を持つようなaction、$a_m$を求めるネットワーク
Target Q-Network: $a_m$ の Q 値を評価するネットワーク
行動の決定と評価をそれぞれ別のネットワーク (ただし構造は同一) で行うことで学習を安定化させられるらしい。学習自体は Main Q-Network に対して行うが、たまに (下記の例では 2 エピソードに 1 回) Main Q-Network の weight を Target Q-Network にコピーしている。
Dueling Network
行動価値関数 $ Q(s, a) $ には、$s$だけで決まってしまう要素 $V(s)$ と行動次第で決まる要素 $A(s,a)$ があると考える。例えば cartpole では棒がもう倒れそうな状態だったらそこから右に押そうが左に押そうがあまり関係ない、など。そこで Q-Network の出力を $V(s)$ を出力する部分と各行動に対する $A(s,a)$ を出力する部分に分岐させ、$V(s) + A(s,a) = Q(s, a)$ として Q 値を求める。
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
import gym
from JSAnimation.IPython_display import display_animation
from matplotlib import animation
from IPython.display import display
def display_frames_as_gif(frames):
plt.figure(figsize=(frames[0].shape[1]/72.0, frames[0].shape[0]/72.0), dpi=72)
patch = plt.imshow(frames[0])
plt.axis('off')
def animate(i):
patch.set_data(frames[i])
anim = animation.FuncAnimation(plt.gcf(), animate, frames=len(frames), interval=50)
anim.save('movie_cart_ple_ddqn.mp4')
display(display_animation(anim, default_mode='loop'))
from collections import namedtuple
Transition = namedtuple('Transition', ('state', 'action', 'next_state', 'reward'))
# 定数の設定
ENV = 'CartPole-v0'
GAMMA = 0.99
MAX_STEPS = 200
NUM_EPISODES = 500
# 経験を保存するメモリクラスを定義します。
class ReplayMemory:
def __init__(self, CAPACITY):
self.capacity = CAPACITY # メモリの最大長さ
self.memory = []
self.index = 0
def push(self, state, action, state_next, reward):
if len(self.memory) < self.capacity:
self.memory.append(None) #メモリが満タンじゃないときには追加
self.memory[self.index] = Transition(state, action, state_next, reward)
self.index = (self.index + 1) % self.capacity
def sample(self, batch_size):
return random.sample(self.memory, batch_size)
def __len__(self):
return len(self.memory)
# ニューラルネットワークの定義します。
import torch.nn as nn
import torch.nn.functional as F
class Net(nn.Module):
def __init__(self, n_in, n_mid, n_out):
super(Net, self).__init__()
self.fc1 = nn.Linear(n_in, n_mid)
self.fc2 = nn.Linear(n_mid, n_mid)
self.fc3 = nn.Linear(n_mid, n_out)
def forward(self, x):
h1 = F.relu(self.fc1(x))
h2 = F.relu(self.fc2(h1))
output = self.fc3(h2)
return output
import random
import torch
from torch import nn
from torch import optim
import torch.nn.functional as F
BATCH_SIZE = 32
CAPACITY = 10000
# 行動決定を行うためのクラスです。
class Brain:
def __init__(self, num_states, num_actions):
self.num_actions = num_actions
# メモリオブジェクトの生成
self.memory = ReplayMemory(CAPACITY)
# ニューラルネットワークの構築
n_in, n_mid, n_out = num_states, 32, num_actions
self.main_q_network = Net(n_in, n_mid, n_out)
self.target_q_network= Net(n_in, n_mid, n_out)
print(self.main_q_network)
# オプティマイザの設定
self.optimizer = optim.Adam(self.main_q_network.parameters(), lr=0.0001)
def replay(self):
'''Experience Replay'''
# メモリサイズの確認
# メモリサイズがミニバッチサイズより小さい間は何もしない。
if len(self.memory) < BATCH_SIZE:
return
# ミニバッチの作成
self.make_minibatch()
# 教師信号Q(s_t, a_t)を求める
self.expected_state_action_values = self.get_expected_state_action_values()
# ネットワークのパラメータ更新
self.update_main_q_network()
def decide_action(self, state, episode):
'''state に応じて行動を決定する関数'''
epsilon = 0.5 * (1 / (episode + 1))
if epsilon <= np.random.uniform(0, 1):
self.main_q_network.eval()
with torch.no_grad():
action = self.main_q_network(state).max(1)[1].view(1, 1)
else:
action = torch.LongTensor([[random.randrange(self.num_actions)]])
return action
def make_minibatch(self):
# ミニバッチの作成
transitions = self.memory.sample(BATCH_SIZE)
batch = Transition(*zip(*transitions))
self.batch = batch
self.state_batch = torch.cat(batch.state)
self.action_batch = torch.cat(batch.action)
self.reward_batch = torch.cat(batch.reward)
self.non_final_next_states = torch.cat([s for s in batch.next_state if s is not None])
def get_expected_state_action_values(self):
self.main_q_network.eval()
self.target_q_network.eval()
# state_batchをモデルに与え、推論結果からaction_batchで行った行動に対応するQ値を取ってくる
# つまりあるstateにおける行ったactionの価値をとってきている。
self.state_action_value = self.main_q_network(self.state_batch).gather(1, self.action_batch)
# max{Q(s_t+1, a)}を求める
non_final_mask = torch.ByteTensor(
tuple(map(lambda s: s is not None, self.batch.next_state)))
next_state_values = torch.zeros(BATCH_SIZE)
a_m = torch.zeros(BATCH_SIZE).type(torch.LongTensor)
# 次の状態での最大Q値のa_mをmain_q_networkから求める
a_m[non_final_mask] = self.main_q_network(self.non_final_next_states).detach().max(1)[1]
# 次の状態があるものだけをフィルターし、sizeをBATCH_SIZEからBATCH_SIZE*1へ
a_m_non_final_next_states = a_m[non_final_mask].view(-1, 1)
#行動a_mのQ値をtarget_q_networkで推定する。
next_state_values[non_final_mask] = self.target_q_network(self.non_final_next_states).gather(
1, a_m_non_final_next_states).detach().squeeze()
expected_state_action_values = self.reward_batch + GAMMA * next_state_values
return expected_state_action_values
def update_main_q_network(self):
# main_q_networkのネットワークパラメータの更新
self.main_q_network.train()
loss = F.smooth_l1_loss(self.state_action_value, self.expected_state_action_values.unsqueeze(1))
self.optimizer.zero_grad()
loss.backward()
self.optimizer.step()
def update_target_q_network(self):
#target_q_network のパラメータをmain_q_networkと同じにする。
self.target_q_network.load_state_dict(self.main_q_network.state_dict())
# エージェントを定義するクラスです。
class Agent:
def __init__(self, num_states, num_actions):
self.brain = Brain(num_states, num_actions)
def update_q_function(self):
self.brain.replay()
def get_action(self, state, episode):
action = self.brain.decide_action(state, episode)
return action
def memorize(self, state, action, state_next, reward):
self.brain.memory.push(state, action, state_next, reward)
def update_target_q_function(self):
self.brain.update_target_q_network()
# 環境を定義するクラスです。
class Environment:
def __init__(self):
self.env = gym.make(ENV)
self.num_states = self.env.observation_space.shape[0]
self.num_actions = self.env.action_space.n
self.agent = Agent(self.num_states, self.num_actions)
def run(self):
episode_10_list = np.zeros(10) #直近10エピソードで振り子が立ち続けたステップ数を記録
complete_episodes = 0
episode_final = False
frames = []
for episode in range(NUM_EPISODES):
observation = self.env.reset()
state = observation
state = torch.from_numpy(state).type(torch.FloatTensor)
state = torch.unsqueeze(state, 0)
for step in range(MAX_STEPS):
if episode_final is True:
frames.append(self.env.render(mode='rgb_array'))
action = self.agent.get_action(state, episode)
#行動actionの実行によってs_t+1とdoneフラグを取得
observation_next, _, done, _ = self.env.step(action.item())
if done:
state_next = None
episode_10_list = np.hstack((episode_10_list[1:], step + 1))
if step < 195:
reward = torch.FloatTensor([-1.0]) # 195ステップ未満で倒れたら報酬-1
complete_episodes = 0
else:
reward = torch.FloatTensor([1.0])
complete_episodes = complete_episodes + 1
else:
reward = torch.FloatTensor([0.0])
state_next = observation_next
state_next = torch.from_numpy(state_next).type(torch.FloatTensor)
state_next = torch.unsqueeze(state_next, 0)
self.agent.memorize(state, action, state_next, reward)
self.agent.update_q_function()
state = state_next
if done:
print("%d episode: Finished after %d steps: 10試行の平均step数 = %.lf"%(
episode, step + 1, episode_10_list.mean()))
if episode % 2 == 0:
self.agent.update_target_q_function()
observation = self.env.reset()
break
if episode_final is True:
display_frames_as_gif(frames)
break
if complete_episodes >= 10:
print("10回連続成功")
episode_final = True
cartpole_env = Environment()
cartpole_env.run()
Dueling Network は Double DQN のネットワーク構造を以下のように変えるだけでよい。
class Net(nn.Module):
def __init__(self, n_in, n_mid, n_out):
super(Net, self).__init__()
self.fc1 = nn.Linear(n_in, n_mid)
self.fc2 = nn.Linear(n_mid, n_mid)
self.fc3_adv = nn.Linear(n_mid, n_out)
self.fc3_v = nn.Linear(n_mid, 1)
def forward(self, x):
h1 = F.relu(self.fc1(x))
h2 = F.relu(self.fc2(h1))
adv = self.fc3_adv(h2)
val = self.fc3_v(h2).expand(-1, adv.size(1))
output = val + adv - adv.mean(1, keepdim=True).expand(-1, adv.size(1))
return output