1
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

強化学習の勉強 (2) Deep Q-Network

Last updated at Posted at 2019-02-01

これの続きです

5章の Deep Q-Network (DQN) についての実装を試してみた。

Q学習のときには q_table (離散化されたstate × action) の表を作っておいて、state から次の action を決めるのはその表を元にやっていたが、DQN では state を入力、action を出力とするニューラルネットワークでやるらしい。

120エピソード程度の学習でこんな感じになった。
cartpole_dqn.gif

DQN の実装上の工夫は以下の4点。

  1. Experience Replay
    各ステップごとに学習をしてしまうと、時間的に相関が高いデータを連続して学習することになり学習が不安定化する。代わりに、各ステップの情報をメモリに保存しておき、そこからサンプリングしたものを使って学習を行う。

  2. Fixed Target Q-Network
    行動を決定する main-network と行動価値を計算する target-network を分ける。ただし今回は main-network をミニバッチ学習する形で簡便な実装を行う。

  3. 報酬のクリッピング
    各ステップの報酬は -1、0、1 のいずれかとする。

  4. Huber 関数を用いた誤差
    二乗誤差を使うと誤差関数の出力が大きくなりすぎて学習が不安定化することがあるらしい。

L_1(x) =  \left\{
\begin{array}{ll}
\frac{1}{2}x^2 & (|x| \leq 1) \\
|x| - \frac{1}{2} & (|x| \gt 1)
\end{array}
\right.

以下コードのメモ。実際はJupyter Notebookで実行している。

import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
import gym

from JSAnimation.IPython_display import display_animation
from matplotlib import animation 
from IPython.display import display

# 動画を出力する関数です
def display_frames_as_gif(frames):
    plt.figure(figsize=(frames[0].shape[1]/72.0, frames[0].shape[0]/72.0), dpi=72)
    patch = plt.imshow(frames[0])
    plt.axis('off')
    
    def animate(i):
        patch.set_data(frames[i])
    
    anim = animation.FuncAnimation(plt.gcf(), animate, frames=len(frames), interval=50)
    anim.save('movie_cart_ple_dqn.mp4')
    display(display_animation(anim, default_mode='loop'))

from collections import namedtuple

# 各ステップでの情報を保持するための namedtuple です
Transition = namedtuple('Transition', ('state', 'action', 'next_state', 'reward'))

# 定数の設定
ENV = 'CartPole-v0'
GAMMA = 0.99
MAX_STEPS = 200
NUM_EPISODES = 500

# 経験を保存するメモリクラスを定義します。

class ReplayMemory:
    
    def __init__(self, CAPACITY):
        self.capacity = CAPACITY # メモリの最大長さ
        self.memory = []
        self.index = 0
    
    def push(self, state, action, state_next, reward):
        if len(self.memory) < self.capacity:
            self.memory.append(None) #メモリが満タンじゃないときには追加
            
        self.memory[self.index] = Transition(state, action, state_next, reward)
        self.index = (self.index + 1) % self.capacity
    
    def sample(self, batch_size):
        return random.sample(self.memory, batch_size)
    
    def __len__(self):
        return len(self.memory)


import random
import torch
from torch import nn
from torch import optim
import torch.nn.functional as F

BATCH_SIZE = 32
CAPACITY = 10000

# エージェントの行動方針を決めるためのクラスです。
class Brain:
    def __init__(self, num_states, num_actions):
        self.num_actions = num_actions
        
        # メモリオブジェクトの生成
        self.memory = ReplayMemory(CAPACITY)
        
        # ニューラルネットワークの構築
        self.model = nn.Sequential()
        self.model.add_module('fc1', nn.Linear(num_states, 32))
        self.model.add_module('relu1', nn.ReLU())
        self.model.add_module('fc2', nn.Linear(32, 32))
        self.model.add_module('relu2', nn.ReLU())
        self.model.add_module('fc3', nn.Linear(32, num_actions))
        
        print(self.model)
        
        # オプティマイザの設定
        self.optimizer = optim.Adam(self.model.parameters(), lr=0.0001)
        
    def replay(self):
        '''Experience Replay'''
        
        # メモリサイズの確認
        # メモリサイズがミニバッチサイズより小さい間は何もしない。
        if len(self.memory) < BATCH_SIZE:
            return
        
        # ミニバッチの作成
        transitions = self.memory.sample(BATCH_SIZE)
        batch = Transition(*zip(*transitions))
        
        state_batch = torch.cat(batch.state)
        action_batch = torch.cat(batch.action)
        reward_batch = torch.cat(batch.reward)
        non_final_next_states = torch.cat([s for s in batch.next_state if s is not None])
        
        # ネットワークを推論モードにする
        self.model.eval()
        
        # state_batchをモデルに与え、推論結果からaction_batchで行った行動に対応するQ値を取ってくる
        # つまりあるstateにおける行ったactionの価値をとってきている。
        state_action_value = self.model(state_batch).gather(1, action_batch)
        
        # max{Q(s_t+1, a)}を求める
        non_final_mask = torch.ByteTensor(
        tuple(map(lambda s: s is not None, batch.next_state)))
        
        next_state_values = torch.zeros(BATCH_SIZE)
        
        next_state_values[non_final_mask] = self.model(non_final_next_states).max(1)[0].detach()
        
        # Q(s_t, a_t)をQ学習の式から求める。
        expected_state_action_values = reward_batch + GAMMA * next_state_values
        
        # ネットワークパラメータの更新
        self.model.train()
        loss = F.smooth_l1_loss(state_action_value, expected_state_action_values.unsqueeze(1))
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        
    def decide_action(self, state, episode):
        '''state に応じて行動を決定する関数'''
        epsilon = 0.5 * (1 / (episode + 1))
        
        if epsilon <= np.random.uniform(0, 1):
            self.model.eval()
            with torch.no_grad():
                action = self.model(state).max(1)[1].view(1, 1)
        else:
            action = torch.LongTensor([[random.randrange(self.num_actions)]])
        
        return action
        
# エージェントクラスです。state に応じて action を行います。
class Agent:
    def __init__(self, num_states, num_actions):
        self.brain = Brain(num_states, num_actions)
    
    def update_q_function(self):
        self.brain.replay()
    
    def get_action(self, state, episode):
        action = self.brain.decide_action(state, episode)
        return action
    
    def memorize(self, state, action, state_next, reward):
        self.brain.memory.push(state, action, state_next, reward)

# CartPole 実行環境のクラスです。
class Environment:
    def __init__(self):
        self.env = gym.make(ENV)
        self.num_states = self.env.observation_space.shape[0]
        self.num_actions = self.env.action_space.n
        
        self.agent = Agent(self.num_states, self.num_actions)
        
    def run(self):
        
        episode_10_list = np.zeros(10) #直近10エピソードで振り子が立ち続けたステップ数を記録
        
        complete_episodes = 0
        episode_final = False
        frames = []
        
        for episode in range(NUM_EPISODES):
            observation = self.env.reset()
            
            state = observation
            state = torch.from_numpy(state).type(torch.FloatTensor)
            state = torch.unsqueeze(state, 0)
            
            for step in range(MAX_STEPS):
                
                if episode_final is True:
                    frames.append(self.env.render(mode='rgb_array'))
                
                action = self.agent.get_action(state, episode)
                
                #行動actionの実行によってs_t+1とdoneフラグを取得
                observation_next, _, done, _ = self.env.step(action.item())
                
                if done:
                    state_next = None
                    
                    episode_10_list = np.hstack((episode_10_list[1:], step + 1))
                    
                    if step < 195:
                        reward = torch.FloatTensor([-1.0]) # 195ステップ未満で倒れたら報酬-1
                        complete_episodes = 0
                    else:
                        reward = torch.FloatTensor([1.0])
                        complete_episodes = complete_episodes + 1
                        
                else:
                    reward = torch.FloatTensor([0.0])
                    state_next = observation_next
                    state_next = torch.from_numpy(state_next).type(torch.FloatTensor)
                    state_next = torch.unsqueeze(state_next, 0)
                    
                self.agent.memorize(state, action, state_next, reward)
                self.agent.update_q_function()
                
                state = state_next
                
                if done:
                    print("%d episode: Finished after %d steps: 10試行の平均step数 = %.lf"%(
                    episode, step + 1, episode_10_list.mean()))
                    observation = self.env.reset()
                    break
            
            if episode_final is True:
                display_frames_as_gif(frames)
                break
            
            if complete_episodes >= 10:
                print("10回連続成功")
                episode_final = True

cartpole_env = Environment()
cartpole_env.run()
1
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?