Help us understand the problem. What is going on with this article?

FlappyBird で強化学習の練習 その2: Double DQN

※ 2019/04/04 追記: 問題設定を少し修正し、実験をやり直しました。
※ 2019/04/19 追記: モデルを少し変更し、実験をやり直しました。

この記事は何

せっかく Pythonで学ぶ強化学習 をざっと読んだので、手を動かしてみる大作戦です。
FlappyBird という数年前に話題になったゲームがあり、それを強化学習を用いて学習していきたいと思います。
目標は満点である264点を安定して取ることです。
のんびり動かしてみつつ、色々やったことを記録していこうと考えています :muscle:

本記事では、前回の DQN に Double DQN という手法を組み合わせました。
勉強しつつ書いてるので、何か誤りなどあればコメントいただけると助かります :bow:

実装は jupyter notebook 上で行っており、 今回のコードはこちらです。
リポジトリはこちら: cfiken/flappybird-try

目次

  • Double DQN とは
  • FlappyBird の結果
  • 実装の紹介
  • まとめ

Double DQN とは

  • Q-Learning の問題
  • Double Q-Learning
  • Double DQN

Q-Learning の問題

DQN をはじめとする Q-Learning では、次のように価値を更新します。

\tilde{Q}(S_t, a) = R_{t+1}+\gamma \max _{a} Q\left(S_{t+1}, a; \boldsymbol{\theta}_{t}\right)

Q 関数の推定にパラメータを持つ関数を仮定している場合、次のような損失関数を最小化することでパラメータを学習します(MSEの場合)。

L = \frac{1}{2}(Q(S_t, a; \boldsymbol{\theta}_t) - \tilde{Q}(S_t, a))^2

ここで、$\tilde{Q}$の計算式の右辺第二項の計算が問題となります。常に max となる価値をもとに計算しているため、推定のズレが価値が大きくなる方向に偏ってしまい、結果として推定が過大評価されがちになってしまいます。

具体的には、Q の値に一様分布によるランダムな誤差が$[-\epsilon, \epsilon]$の範囲で付加されていると仮定すると、 $\gamma \max_{a} Q\left(S_{t+1}, a;\boldsymbol{\theta}_{t}\right)$ の推定値は、アクション数を $m$ とすると、最大で $\gamma \epsilon \frac{m-1}{m+1}$ の過大評価をしてしまうということが Thrun and Schwartz [1] により示されています。

Double Q-Learning

そこで、van Hasselt [2] によって発表されたのが Double Q-Learning という手法です。
ここでは、1つの同じ関数によってパラメータの更新が行なわれる点が過大評価の原因となっているとし、2つの関数を用意して次のように Q 値を推定しました。
まず、オリジナルの更新式の max の中を次のように(1)行動の選択と、(2)その価値の計算に分解します。

\tilde{Q}(S_t, a) = R_{t+1}+\gamma Q\left(S_{t+1}, \operatorname{argmax}_{a} Q\left(S_{t+1}, a ; \boldsymbol{\theta}_{t}\right) ; \boldsymbol{\theta}_{t}\right)

Q 値が最大となるアクションを選択し、それについて価値を計算しているだけなので、もとの式と変わりません。次に、それぞれを異なる関数で計算します。

\tilde{Q}(S_t, a) = R_{t+1}+\gamma Q\left(S_{t+1}, \operatorname{argmax}_{a} Q\left(S_{t+1}, a ; \boldsymbol{\theta}_{t}\right) ; \boldsymbol{\theta}_{t}'\right)

ここで、$\boldsymbol{\theta}_t, \boldsymbol{\theta}_t'$はそれぞれ2つの関数のパラメータです。Double Q-Learning では、それぞれを交互に学習させ、片方の学習時のターゲットとなる価値の計算にもう片方の関数を用いることで過大評価が起きないようにしました。

Double DQN

Double DQN [3] は、Double Q-Learning を DQN に適用した手法です。DQN では、既に online network と呼ばれる学習対象となるネットワークと、 target network と呼ばれるターゲットの価値を計算するためのネットワークの2つが用意されています。この target network は、一定時間ごとに online network のパラメータが同期されます。

online network, target network のパラメータをそれぞれ $\boldsymbol{\theta} _{t}, \boldsymbol{\theta}^{-}$とすると、

\tilde{Q}(S_t, a) = R_{t+1}+\gamma Q\left(S_{t+1}, \underset{q}{\operatorname{argmax}} Q\left(S_{t+1}, a ; \boldsymbol{\theta}_{t}\right), \boldsymbol{\theta}^{-}\right)

のように、行動の選択には学習中の online network, その価値の計算には target network を使用して更新を行います。これにより、多くの Atari games で DQN が出していた SOTA スコアがさらに更新されました。

論文では、これによりどの程度の過大評価が抑えられたか、なども定量的に実験されています。
下記の図1では、行動の種類の数とそれにより過大評価の量のグラフです。通常の Q-Learning (図の赤)では、取りうる行動が多ければ多いほど推定値の誤差が増えているのが分かる一方、Double Q-Learning (図の青)では、ほとんど変わらないことが分かります。

スクリーンショット 2019-03-23 18.32.43.png
図1: [3] Figure1 より引用

Double DQN は DQN にシンプルな処理の変更を加えるだけで効果があるため魅力的な手法ですが、今回の FlappyBird では、行動が2種類しかないため、そこまで効果は期待できないなと思っていました :eyes:

結果

スコアの計算は次のように行います。

  • モデルを別々に5回学習する
  • それぞれのモデルで50回ずつプレイして、スコア(超えたドカンの数)の平均(と分散)を計算する
  • モデル5個のそれぞれの結果の mean, median を比較する

5つのモデルと最終的な結果は次のようになりました。

mean: 29.8200, std: 26.4497
mean: 31.0800, std: 27.8660
mean: 39.5600, std: 37.1350
mean: 31.0000, std: 27.6695
mean: 31.4200, std: 32.2832
--- total ---
mean: 32.576, median: 31.080

平均スコアは 32.576, 中央値は 31.080 でした。
中央値で見ても DQN での平均スコアから下がってしまいました。原因としては上記でも述べていますが、

  • 行動が2種類しかないため、元々の DQN での過大評価も大きくはなかった

が挙げられるかと思います。
(なお、 Fixed Target でなくてもあまり性能が変わらないという報告を知人より受けたので、なおさら効果は薄そうです...タスクが簡単だからでしょうか...)

下記は50プレイ分の結果です。

学習中の TensorBoard での reward と loss の様子です。それぞれオレンジが DQN ,緑が Double DQN となっています。
結果は悪かったですが、reward のグラフを見ているとそこまで大きな差はなさそうです。これにあわせてハイパーパラメータをチューニングしたらまた良くなるかもしれません。

reward

スクリーンショット 2019-04-04 11.59.16.png

実装の紹介

実行時の notebook はこちらです。

元となっている DQN のコードについては前回記事: FlappyBird で強化学習の練習 その1: DQN に紹介しています。
今回の差分は Agent クラスの update メソッド内の二行だけです。差分の直前にコメントを入れています。

    def update(self, experiences, gamma):
        '''
        与えられた experiences をもとに学習
        '''
        states = np.array([e.state for e in experiences])
        next_states = np.array([e.next_state for e in experiences])

        estimated_values = self.model.predict(states)
        next_state_values = self._teacher_model.predict(next_states)

        # train
        for i, e in enumerate(experiences):
            reward = e.reward
            if not e.done:
                # DQN の場合: reward += gamma * np.max(next_state_values[i])
                # DQN と違い、action の推定は学習中のモデルで、その価値の推定は target network で行う
                next_action = np.argmax(estimated_values[i])
                reward += gamma * next_state_values[i][next_action]
            estimated_values[i][e.action] = reward
        loss = self.model.train_on_batch(states, estimated_values)
        return loss

まとめ

FlappyBird を強化学習で攻略する第一歩として、DQN に Double DQN を追加しました。
学習させた結果、DQN よりも少し性能が下がる結果になりました。
今回の問題設定では、行動が2種類しかないなどが理由で Double DQN が解決する課題がそもそもあまりなかったのだと思われます。
Double DQN に合わせてハイパーパラメータをチューニングしたりするとまた変わるかもしれません。
しかし、実装は簡単なのでとりあえず試す分には良いと思います。

参考文献

[1] S. Thrun and A. Schwartz, Issues in using function approximation for reinforcement learning. Proceedings. 1993.
[2] H. van Hasselt, Double Q-learning. Advances in Neural Information Processing Systems, 23:2613–2621, 2010.
[3] H. van Hasselt and A. Guez and D. Silver, Deep Reinforcement Learning with Double Q-learning, arXiv:1509.06461 [cs.LG]

Why do not you register as a user and use Qiita more conveniently?
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away
Comments
Sign up for free and join this conversation.
If you already have a Qiita account
Why do not you register as a user and use Qiita more conveniently?
You need to log in to use this function. Qiita can be used more conveniently after logging in.
You seem to be reading articles frequently this month. Qiita can be used more conveniently after logging in.
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away