5章の Deep Q-Network (DQN) についての実装を試してみた。
Q学習のときには q_table (離散化されたstate × action) の表を作っておいて、state から次の action を決めるのはその表を元にやっていたが、DQN では state を入力、action を出力とするニューラルネットワークでやるらしい。
DQN の実装上の工夫は以下の4点。
-
Experience Replay
各ステップごとに学習をしてしまうと、時間的に相関が高いデータを連続して学習することになり学習が不安定化する。代わりに、各ステップの情報をメモリに保存しておき、そこからサンプリングしたものを使って学習を行う。 -
Fixed Target Q-Network
行動を決定する main-network と行動価値を計算する target-network を分ける。ただし今回は main-network をミニバッチ学習する形で簡便な実装を行う。 -
報酬のクリッピング
各ステップの報酬は -1、0、1 のいずれかとする。 -
Huber 関数を用いた誤差
二乗誤差を使うと誤差関数の出力が大きくなりすぎて学習が不安定化することがあるらしい。
L_1(x) = \left\{
\begin{array}{ll}
\frac{1}{2}x^2 & (|x| \leq 1) \\
|x| - \frac{1}{2} & (|x| \gt 1)
\end{array}
\right.
以下コードのメモ。実際はJupyter Notebookで実行している。
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
import gym
from JSAnimation.IPython_display import display_animation
from matplotlib import animation
from IPython.display import display
# 動画を出力する関数です
def display_frames_as_gif(frames):
plt.figure(figsize=(frames[0].shape[1]/72.0, frames[0].shape[0]/72.0), dpi=72)
patch = plt.imshow(frames[0])
plt.axis('off')
def animate(i):
patch.set_data(frames[i])
anim = animation.FuncAnimation(plt.gcf(), animate, frames=len(frames), interval=50)
anim.save('movie_cart_ple_dqn.mp4')
display(display_animation(anim, default_mode='loop'))
from collections import namedtuple
# 各ステップでの情報を保持するための namedtuple です
Transition = namedtuple('Transition', ('state', 'action', 'next_state', 'reward'))
# 定数の設定
ENV = 'CartPole-v0'
GAMMA = 0.99
MAX_STEPS = 200
NUM_EPISODES = 500
# 経験を保存するメモリクラスを定義します。
class ReplayMemory:
def __init__(self, CAPACITY):
self.capacity = CAPACITY # メモリの最大長さ
self.memory = []
self.index = 0
def push(self, state, action, state_next, reward):
if len(self.memory) < self.capacity:
self.memory.append(None) #メモリが満タンじゃないときには追加
self.memory[self.index] = Transition(state, action, state_next, reward)
self.index = (self.index + 1) % self.capacity
def sample(self, batch_size):
return random.sample(self.memory, batch_size)
def __len__(self):
return len(self.memory)
import random
import torch
from torch import nn
from torch import optim
import torch.nn.functional as F
BATCH_SIZE = 32
CAPACITY = 10000
# エージェントの行動方針を決めるためのクラスです。
class Brain:
def __init__(self, num_states, num_actions):
self.num_actions = num_actions
# メモリオブジェクトの生成
self.memory = ReplayMemory(CAPACITY)
# ニューラルネットワークの構築
self.model = nn.Sequential()
self.model.add_module('fc1', nn.Linear(num_states, 32))
self.model.add_module('relu1', nn.ReLU())
self.model.add_module('fc2', nn.Linear(32, 32))
self.model.add_module('relu2', nn.ReLU())
self.model.add_module('fc3', nn.Linear(32, num_actions))
print(self.model)
# オプティマイザの設定
self.optimizer = optim.Adam(self.model.parameters(), lr=0.0001)
def replay(self):
'''Experience Replay'''
# メモリサイズの確認
# メモリサイズがミニバッチサイズより小さい間は何もしない。
if len(self.memory) < BATCH_SIZE:
return
# ミニバッチの作成
transitions = self.memory.sample(BATCH_SIZE)
batch = Transition(*zip(*transitions))
state_batch = torch.cat(batch.state)
action_batch = torch.cat(batch.action)
reward_batch = torch.cat(batch.reward)
non_final_next_states = torch.cat([s for s in batch.next_state if s is not None])
# ネットワークを推論モードにする
self.model.eval()
# state_batchをモデルに与え、推論結果からaction_batchで行った行動に対応するQ値を取ってくる
# つまりあるstateにおける行ったactionの価値をとってきている。
state_action_value = self.model(state_batch).gather(1, action_batch)
# max{Q(s_t+1, a)}を求める
non_final_mask = torch.ByteTensor(
tuple(map(lambda s: s is not None, batch.next_state)))
next_state_values = torch.zeros(BATCH_SIZE)
next_state_values[non_final_mask] = self.model(non_final_next_states).max(1)[0].detach()
# Q(s_t, a_t)をQ学習の式から求める。
expected_state_action_values = reward_batch + GAMMA * next_state_values
# ネットワークパラメータの更新
self.model.train()
loss = F.smooth_l1_loss(state_action_value, expected_state_action_values.unsqueeze(1))
self.optimizer.zero_grad()
loss.backward()
self.optimizer.step()
def decide_action(self, state, episode):
'''state に応じて行動を決定する関数'''
epsilon = 0.5 * (1 / (episode + 1))
if epsilon <= np.random.uniform(0, 1):
self.model.eval()
with torch.no_grad():
action = self.model(state).max(1)[1].view(1, 1)
else:
action = torch.LongTensor([[random.randrange(self.num_actions)]])
return action
# エージェントクラスです。state に応じて action を行います。
class Agent:
def __init__(self, num_states, num_actions):
self.brain = Brain(num_states, num_actions)
def update_q_function(self):
self.brain.replay()
def get_action(self, state, episode):
action = self.brain.decide_action(state, episode)
return action
def memorize(self, state, action, state_next, reward):
self.brain.memory.push(state, action, state_next, reward)
# CartPole 実行環境のクラスです。
class Environment:
def __init__(self):
self.env = gym.make(ENV)
self.num_states = self.env.observation_space.shape[0]
self.num_actions = self.env.action_space.n
self.agent = Agent(self.num_states, self.num_actions)
def run(self):
episode_10_list = np.zeros(10) #直近10エピソードで振り子が立ち続けたステップ数を記録
complete_episodes = 0
episode_final = False
frames = []
for episode in range(NUM_EPISODES):
observation = self.env.reset()
state = observation
state = torch.from_numpy(state).type(torch.FloatTensor)
state = torch.unsqueeze(state, 0)
for step in range(MAX_STEPS):
if episode_final is True:
frames.append(self.env.render(mode='rgb_array'))
action = self.agent.get_action(state, episode)
#行動actionの実行によってs_t+1とdoneフラグを取得
observation_next, _, done, _ = self.env.step(action.item())
if done:
state_next = None
episode_10_list = np.hstack((episode_10_list[1:], step + 1))
if step < 195:
reward = torch.FloatTensor([-1.0]) # 195ステップ未満で倒れたら報酬-1
complete_episodes = 0
else:
reward = torch.FloatTensor([1.0])
complete_episodes = complete_episodes + 1
else:
reward = torch.FloatTensor([0.0])
state_next = observation_next
state_next = torch.from_numpy(state_next).type(torch.FloatTensor)
state_next = torch.unsqueeze(state_next, 0)
self.agent.memorize(state, action, state_next, reward)
self.agent.update_q_function()
state = state_next
if done:
print("%d episode: Finished after %d steps: 10試行の平均step数 = %.lf"%(
episode, step + 1, episode_10_list.mean()))
observation = self.env.reset()
break
if episode_final is True:
display_frames_as_gif(frames)
break
if complete_episodes >= 10:
print("10回連続成功")
episode_final = True
cartpole_env = Environment()
cartpole_env.run()