LoginSignup
4
8

More than 5 years have passed since last update.

強化学習の勉強 (3) Double DQN と Dueling Network

Posted at

これの続きです

深層強化学習 PyTorchによる実践プログラミング の6章の内容です。

前回の Deep Q-Netowork の発展版として、Double DQN と Dueling Network の実装が紹介されていました。

Double DQN の結果がこれ。
ddqn.gif
Dueling Network の結果がこれ。
DuelingNetwork.gif
この課題だともう学習後の上手さの違いなどはよくわからない。上記の本では Dueling Network を使うと少ない試行数でも学習が進むという話が書いてあったけどそのような傾向は見られなかった。

Double DQN

Main Q-Network: 次のステップで最大のQ値を持つようなaction、$a_m$を求めるネットワーク
Target Q-Network: $a_m$ の Q 値を評価するネットワーク
行動の決定と評価をそれぞれ別のネットワーク (ただし構造は同一) で行うことで学習を安定化させられるらしい。学習自体は Main Q-Network に対して行うが、たまに (下記の例では 2 エピソードに 1 回) Main Q-Network の weight を Target Q-Network にコピーしている。

Dueling Network

行動価値関数 $ Q(s, a) $ には、$s$だけで決まってしまう要素 $V(s)$ と行動次第で決まる要素 $A(s,a)$ があると考える。例えば cartpole では棒がもう倒れそうな状態だったらそこから右に押そうが左に押そうがあまり関係ない、など。そこで Q-Network の出力を $V(s)$ を出力する部分と各行動に対する $A(s,a)$ を出力する部分に分岐させ、$V(s) + A(s,a) = Q(s, a)$ として Q 値を求める。

DDQN.py
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
import gym

from JSAnimation.IPython_display import display_animation
from matplotlib import animation 
from IPython.display import display

def display_frames_as_gif(frames):
    plt.figure(figsize=(frames[0].shape[1]/72.0, frames[0].shape[0]/72.0), dpi=72)
    patch = plt.imshow(frames[0])
    plt.axis('off')

    def animate(i):
        patch.set_data(frames[i])

    anim = animation.FuncAnimation(plt.gcf(), animate, frames=len(frames), interval=50)
    anim.save('movie_cart_ple_ddqn.mp4')
    display(display_animation(anim, default_mode='loop'))

from collections import namedtuple

Transition = namedtuple('Transition', ('state', 'action', 'next_state', 'reward'))

# 定数の設定
ENV = 'CartPole-v0'
GAMMA = 0.99
MAX_STEPS = 200
NUM_EPISODES = 500


# 経験を保存するメモリクラスを定義します。

class ReplayMemory:

    def __init__(self, CAPACITY):
        self.capacity = CAPACITY # メモリの最大長さ
        self.memory = []
        self.index = 0

    def push(self, state, action, state_next, reward):
        if len(self.memory) < self.capacity:
            self.memory.append(None) #メモリが満タンじゃないときには追加

        self.memory[self.index] = Transition(state, action, state_next, reward)
        self.index = (self.index + 1) % self.capacity

    def sample(self, batch_size):
        return random.sample(self.memory, batch_size)

    def __len__(self):
        return len(self.memory)

# ニューラルネットワークの定義します。
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self, n_in, n_mid, n_out):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(n_in, n_mid)
        self.fc2 = nn.Linear(n_mid, n_mid)
        self.fc3 = nn.Linear(n_mid, n_out)

    def forward(self, x):
        h1 = F.relu(self.fc1(x))
        h2 = F.relu(self.fc2(h1))
        output = self.fc3(h2)
        return output

import random
import torch
from torch import nn
from torch import optim
import torch.nn.functional as F

BATCH_SIZE = 32
CAPACITY = 10000

# 行動決定を行うためのクラスです。
class Brain:
    def __init__(self, num_states, num_actions):
        self.num_actions = num_actions

        # メモリオブジェクトの生成
        self.memory = ReplayMemory(CAPACITY)

        # ニューラルネットワークの構築
        n_in, n_mid, n_out = num_states, 32, num_actions
        self.main_q_network = Net(n_in, n_mid, n_out)
        self.target_q_network= Net(n_in, n_mid, n_out)

        print(self.main_q_network)

        # オプティマイザの設定
        self.optimizer = optim.Adam(self.main_q_network.parameters(), lr=0.0001)

    def replay(self):
        '''Experience Replay'''

        # メモリサイズの確認
        # メモリサイズがミニバッチサイズより小さい間は何もしない。
        if len(self.memory) < BATCH_SIZE:
            return

        # ミニバッチの作成        
        self.make_minibatch()

        # 教師信号Q(s_t, a_t)を求める
        self.expected_state_action_values = self.get_expected_state_action_values()

        # ネットワークのパラメータ更新
        self.update_main_q_network()

    def decide_action(self, state, episode):
        '''state に応じて行動を決定する関数'''
        epsilon = 0.5 * (1 / (episode + 1))

        if epsilon <= np.random.uniform(0, 1):
            self.main_q_network.eval()
            with torch.no_grad():
                action = self.main_q_network(state).max(1)[1].view(1, 1)
        else:
            action = torch.LongTensor([[random.randrange(self.num_actions)]])

        return action

    def make_minibatch(self):
        # ミニバッチの作成
        transitions = self.memory.sample(BATCH_SIZE)

        batch = Transition(*zip(*transitions))
        self.batch = batch
        self.state_batch = torch.cat(batch.state)
        self.action_batch = torch.cat(batch.action)
        self.reward_batch = torch.cat(batch.reward)
        self.non_final_next_states = torch.cat([s for s in batch.next_state if s is not None])

    def get_expected_state_action_values(self):

        self.main_q_network.eval()
        self.target_q_network.eval()

        # state_batchをモデルに与え、推論結果からaction_batchで行った行動に対応するQ値を取ってくる
        # つまりあるstateにおける行ったactionの価値をとってきている。
        self.state_action_value = self.main_q_network(self.state_batch).gather(1, self.action_batch)

        # max{Q(s_t+1, a)}を求める
        non_final_mask = torch.ByteTensor(
        tuple(map(lambda s: s is not None, self.batch.next_state)))

        next_state_values = torch.zeros(BATCH_SIZE)

        a_m = torch.zeros(BATCH_SIZE).type(torch.LongTensor)

        # 次の状態での最大Q値のa_mをmain_q_networkから求める
        a_m[non_final_mask] = self.main_q_network(self.non_final_next_states).detach().max(1)[1]

        # 次の状態があるものだけをフィルターし、sizeをBATCH_SIZEからBATCH_SIZE*1へ
        a_m_non_final_next_states = a_m[non_final_mask].view(-1, 1)

        #行動a_mのQ値をtarget_q_networkで推定する。
        next_state_values[non_final_mask] = self.target_q_network(self.non_final_next_states).gather(
            1, a_m_non_final_next_states).detach().squeeze()

        expected_state_action_values = self.reward_batch + GAMMA * next_state_values

        return expected_state_action_values

    def update_main_q_network(self):
        # main_q_networkのネットワークパラメータの更新
        self.main_q_network.train()
        loss = F.smooth_l1_loss(self.state_action_value, self.expected_state_action_values.unsqueeze(1))
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

    def update_target_q_network(self):
        #target_q_network のパラメータをmain_q_networkと同じにする。
        self.target_q_network.load_state_dict(self.main_q_network.state_dict())

# エージェントを定義するクラスです。
class Agent:
    def __init__(self, num_states, num_actions):
        self.brain = Brain(num_states, num_actions)

    def update_q_function(self):
        self.brain.replay()

    def get_action(self, state, episode):
        action = self.brain.decide_action(state, episode)
        return action

    def memorize(self, state, action, state_next, reward):
        self.brain.memory.push(state, action, state_next, reward)

    def update_target_q_function(self):
        self.brain.update_target_q_network()

# 環境を定義するクラスです。
class Environment:
    def __init__(self):
        self.env = gym.make(ENV)
        self.num_states = self.env.observation_space.shape[0]
        self.num_actions = self.env.action_space.n

        self.agent = Agent(self.num_states, self.num_actions)

    def run(self):

        episode_10_list = np.zeros(10) #直近10エピソードで振り子が立ち続けたステップ数を記録

        complete_episodes = 0
        episode_final = False
        frames = []

        for episode in range(NUM_EPISODES):
            observation = self.env.reset()

            state = observation
            state = torch.from_numpy(state).type(torch.FloatTensor)
            state = torch.unsqueeze(state, 0)

            for step in range(MAX_STEPS):

                if episode_final is True:
                    frames.append(self.env.render(mode='rgb_array'))

                action = self.agent.get_action(state, episode)

                #行動actionの実行によってs_t+1とdoneフラグを取得
                observation_next, _, done, _ = self.env.step(action.item())

                if done:
                    state_next = None

                    episode_10_list = np.hstack((episode_10_list[1:], step + 1))

                    if step < 195:
                        reward = torch.FloatTensor([-1.0]) # 195ステップ未満で倒れたら報酬-1
                        complete_episodes = 0
                    else:
                        reward = torch.FloatTensor([1.0])
                        complete_episodes = complete_episodes + 1

                else:
                    reward = torch.FloatTensor([0.0])
                    state_next = observation_next
                    state_next = torch.from_numpy(state_next).type(torch.FloatTensor)
                    state_next = torch.unsqueeze(state_next, 0)

                self.agent.memorize(state, action, state_next, reward)
                self.agent.update_q_function()

                state = state_next

                if done:
                    print("%d episode: Finished after %d steps: 10試行の平均step数 = %.lf"%(
                    episode, step + 1, episode_10_list.mean()))

                    if episode % 2 == 0:
                        self.agent.update_target_q_function()

                    observation = self.env.reset()
                    break

            if episode_final is True:
                display_frames_as_gif(frames)
                break

            if complete_episodes >= 10:
                print("10回連続成功")
                episode_final = True


cartpole_env = Environment()
cartpole_env.run()

Dueling Network は Double DQN のネットワーク構造を以下のように変えるだけでよい。

DuelingNetwork.py
class Net(nn.Module):
    def __init__(self, n_in, n_mid, n_out):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(n_in, n_mid)
        self.fc2 = nn.Linear(n_mid, n_mid)
        self.fc3_adv = nn.Linear(n_mid, n_out)
        self.fc3_v = nn.Linear(n_mid, 1)

    def forward(self, x):
        h1 = F.relu(self.fc1(x))
        h2 = F.relu(self.fc2(h1))
        adv = self.fc3_adv(h2)
        val = self.fc3_v(h2).expand(-1, adv.size(1))
        output = val + adv - adv.mean(1, keepdim=True).expand(-1, adv.size(1))

        return output
4
8
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
4
8