はじめに
DuelingDQNとは
通常のDQNは上図のようにSequencialなNetworkを通じて行動価値関数を予測していきます。それに対し、DuelingNetworkでは下図のように途中で状態価値関数を求めるV(s)と行動価値関数からV(s)を引いたものを求めるAdvantageに分岐します。その後、それらを足し合わせることで最終的な行動価値関数を導出します。
本記事ではシンプルなDQNとDuelingDQNをPytorchを用いて実装し、性能比較していきます。
Github
参考
実装
シンプルなNetworkの実装
ここではネットワークのみの実装を行います。
class DQN(nn.Module):
def __init__(self, num_states, num_actions):
super(DQN, self).__init__()
self.linear_relu_stack = nn.Sequential(
nn.Linear(num_states, 32),
nn.ReLU(),
nn.Linear(32, 32),
nn.ReLU(),
nn.Linear(32, num_actions)
)
def forward(self, x):
x = x.to(device)
logits = self.linear_relu_stack(x)
return logits
全結合層を3つ使用し、活性化関数にはReLU関数を使用しています。
Dueling Networkの実装
class DuelingDQN(nn.Module):
def __init__(self, num_states, num_actions):
super(DuelingDQN, self).__init__()
self.num_states = num_states
self.num_actions = num_actions
self.fc1 = nn.Linear(self.num_states, 32)
self.relu = nn.ReLU()
self.fcV1 = nn.Linear(32, 32)
self.fcA1 = nn.Linear(32, 32)
self.fcV2 = nn.Linear(32, 1)
self.fcA2 = nn.Linear(32, self.num_actions)
def forward(self, x):
x = self.relu(self.fc1(x))
V = self.fcV2(self.fcV1(x))
A = self.fcA2(self.fcA1(x))
averageA = A.mean(1).unsqueeze(1)
return V.expand(-1, self.num_actions) + (A - averageA.expand(-1, self.num_actions))
DuelingDQNでは最初に全結合層を行い、その後AdvantageとVに分岐され、最終的にVとAから出力を算出します。
結果
DQNとDuelingDQNでの実行結果を示します。今回は下記のハイパーパラメータを利用し3000エピソード実行しました。
Parameter | Value | Description |
---|---|---|
episode | 3000 | エピソード数 |
capacity | 10000 | experience replayの容量 |
batch_size | 32 | バッチサイズ |
gamma | 0.97 | 時間割引率 |
target_update_iter | 20 | target networkの更新間隔 |
eps_start | 0.9 | ε-greedyのパラメータ |
eps_end | 0.5 | ε-greedyのパラメータ |
eps_decay | 200 | ε-greedyのパラメータ |
DQN
Dueling DQN
シンプルなDQNを使用した場合は2000エピソード超えたあたりから成績が向上していますが、Dueling DQNを使用した場合は1200エピソードあたりから向上しており、学習効率の向上が見て取れます。
また、Cartpoleは簡単な問題であるのと報酬を200で切っているため、性能についてははっきりわかりませんが、3000エピソードまでで比べるとDuelingDQNの方が成績は良い傾向が見れます。
DuelingDQNの学習結果
エピソード数が少ないだけあって、だいぶ安定していないですね。
[追記]
このあとエピソード6000回行った結果です。
100回の試行すべてで200stepに到達していますが、若干遊びを覚えているような?感じがします。
200stepくらいならこのくらいの雑さでも大丈夫ということを学習したのかもしれません、、、。
(単純に学習が足りてないだけの可能性が高いです)
まとめ
今回はDueling DQNの実装を行い、通常のDQNとCartpoleでの性能実験を行いました。
深層学習は他にも多くの手法があるため、それらも実装していきたいと思います。 (直近はRainbowの実装を目指していきます。)