1
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

[深層強化学習] Dueling Networkの実装とcartpoleにおける性能比較

Last updated at Posted at 2021-09-12

はじめに

DuelingDQNとは

スクリーンショット 2021-09-12 15.17.50.png

通常のDQNは上図のようにSequencialなNetworkを通じて行動価値関数を予測していきます。それに対し、DuelingNetworkでは下図のように途中で状態価値関数を求めるV(s)と行動価値関数からV(s)を引いたものを求めるAdvantageに分岐します。その後、それらを足し合わせることで最終的な行動価値関数を導出します。

本記事ではシンプルなDQNとDuelingDQNをPytorchを用いて実装し、性能比較していきます。

Github

参考

実装

シンプルなNetworkの実装

ここではネットワークのみの実装を行います。

class DQN(nn.Module):

    def __init__(self, num_states, num_actions):
        super(DQN, self).__init__()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(num_states, 32),
            nn.ReLU(),
            nn.Linear(32, 32),
            nn.ReLU(),
            nn.Linear(32, num_actions)
        )

    def forward(self, x):
        x = x.to(device)
        logits = self.linear_relu_stack(x)
        return logits

全結合層を3つ使用し、活性化関数にはReLU関数を使用しています。

Dueling Networkの実装

class DuelingDQN(nn.Module):
    def __init__(self, num_states, num_actions):
        super(DuelingDQN, self).__init__()
        self.num_states = num_states
        self.num_actions = num_actions

        self.fc1 = nn.Linear(self.num_states, 32)
        self.relu = nn.ReLU()
        self.fcV1 = nn.Linear(32, 32)
        self.fcA1 = nn.Linear(32, 32)
        self.fcV2 = nn.Linear(32, 1)
        self.fcA2 = nn.Linear(32, self.num_actions)

    def forward(self, x):
        x = self.relu(self.fc1(x))

        V = self.fcV2(self.fcV1(x))
        A = self.fcA2(self.fcA1(x))

        averageA = A.mean(1).unsqueeze(1)
        return V.expand(-1, self.num_actions) + (A - averageA.expand(-1, self.num_actions))

DuelingDQNでは最初に全結合層を行い、その後AdvantageとVに分岐され、最終的にVとAから出力を算出します。

結果

DQNとDuelingDQNでの実行結果を示します。今回は下記のハイパーパラメータを利用し3000エピソード実行しました。

Parameter Value Description
episode 3000 エピソード数
capacity 10000 experience replayの容量
batch_size 32 バッチサイズ
gamma 0.97 時間割引率
target_update_iter 20 target networkの更新間隔
eps_start 0.9 ε-greedyのパラメータ
eps_end 0.5 ε-greedyのパラメータ
eps_decay 200 ε-greedyのパラメータ

DQN

スクリーンショット 2021-09-12 14.46.10.png

Dueling DQN

スクリーンショット 2021-09-12 14.51.20.png

シンプルなDQNを使用した場合は2000エピソード超えたあたりから成績が向上していますが、Dueling DQNを使用した場合は1200エピソードあたりから向上しており、学習効率の向上が見て取れます。

また、Cartpoleは簡単な問題であるのと報酬を200で切っているため、性能についてははっきりわかりませんが、3000エピソードまでで比べるとDuelingDQNの方が成績は良い傾向が見れます。

DuelingDQNの学習結果

ezgif.com-gif-maker.gif

エピソード数が少ないだけあって、だいぶ安定していないですね。

[追記]

このあとエピソード6000回行った結果です。
100回の試行すべてで200stepに到達していますが、若干遊びを覚えているような?感じがします。
200stepくらいならこのくらいの雑さでも大丈夫ということを学習したのかもしれません、、、。
(単純に学習が足りてないだけの可能性が高いです)

ezgif.com-gif-maker (1).gif

スクリーンショット 2021-09-12 16.04.57.png

まとめ

今回はDueling DQNの実装を行い、通常のDQNとCartpoleでの性能実験を行いました。
深層学習は他にも多くの手法があるため、それらも実装していきたいと思います。 (直近はRainbowの実装を目指していきます。)

1
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?