$$\newcommand{\argmax}{\mathop{\rm arg~max}\limits}
\newcommand{\argmin}{\mathop{\rm arg~min}\limits}$$
はじめに
深層強化学習の基本である DQN(Deep Q-Network)では、学習を安定させるための工夫として Fixed Target Q-Network というものが使われています。これも含めた DQN のアルゴリズムの概要は以下のようになっています。
また、DQN を改良したものとして Double DQN があります。こちらでは以下のような式が使われます。一つ目が DQN で、二つ目が Double DQN です。
正直、式だけを見ても、Fixed Target Q-Network がどういうものなのか、DQN と Double DQN の違いは何なのか、私にはピンときませんでした。そこで、図を書くことでより直感的に理解できないかと思い、今回の記事を書くことにしました。同じように深層強化学習の理解に困っている方の参考になれば幸いです。
もし誤りなどありましたらご指摘いただければ幸いです。
参考
- DQN 元論文:
Human-level control through deep reinforcement learning - Double DQN 元論文:
Deep Reinforcement Learning with Double Q-learning - DQNからRainbowまで 〜深層強化学習の最新動向〜
step 0: 2つのネットワークを用意する。
まず前提として、DQN では同じ構造のニューラルネットワークを2つ用意して使います。これらを メインネットワーク、ターゲットネットワークと呼びます。ネットワークの詳細は以下。
- 入力:状態
- 入力次元:状態の種類
- 出力:Q 値(行動価値)
- 出力次元:行動の種類
- 重み:それぞれ $\theta, \theta^-$ とする
- 初期値は $\theta^-=\theta$
step 1: 現状態から行動を決定する。
1-1. メインネットワークに現状態 $s_t$ を入力し、出力 $Q_{main}(s_t,a_1;\theta),...,Q_{main}(s_t,a_n;\theta)$ を得ます。
1-2. 出力の中で最大値をとるようなインデックスを行動 $a_t$ とします。
$$a_t = \argmax_{a}Q_{main}(s_t,a;\theta)$$
step 2: step 1 で決定した行動から、次状態、報酬を得る。
図の通り、エージェントは行動 $a_t$ を選択し、次状態 $s_{t+1}$ 、報酬 $R_{t+1}$ を得ます。
step 3: step 2 で得た次状態から、目標値を得て、メインネットワークを学習させる。
3-1. step 1 と同様にして、$Q_{main}(s_t,a_t;\theta)$ を得ます。
$$Q_{main}(s_t,a_t;\theta) = \max_aQ_{main}(s_t,a;\theta)$$
3-2. 次状態 $s_{t+1}$ をターゲットネットワークに入力し、出力 $Q_{target}(s_{t+1},a_1;\theta^-),...,Q_{target}(s_{t+1},a_n;\theta^-)$ を得ます。
3-3. step 2 で得た報酬 $R_{t+1}$ を上記 $Q_{target}$ と合わせて、目標値を求めます。
$$target = R_{t+1} + \gamma \max_aQ_{target}(s_{t+1},a;\theta^-)$$
3-4. 目標値を教師データ(ラベル)として誤差を計算し、メインネットワークを学習させます。ターゲットネットワークは学習させません。
学習時のラベルをより正確に表現すると以下のようになります。
「時刻 $t$ で選んだ行動の Q 値」(図の赤部分)のみを学習させるため、その他は同じ値にしておきます。
step 4: 一定ステップ毎に、メインネットワークの重みをターゲットネットワークに同期する。
メインネットワークを何度か学習させた後、その重みをターゲットネットワークに同期します。同期する間隔は問題などによって適切な値を設定します。
$$\theta^- = \theta$$
以上が DQN の大まかな流れと Fixed Target Q-Network の説明です。次に Double DQN の説明です。
Double DQN(DDQN)
DDQN も基本的な流れは DQN と変わりません。違うのは上記 step 3 です。
step 3: 次状態から、目標値を得て、メインネットワークを学習させる。
3-1. メインネットワークに現状態 $s_t$ を入力して、$Q_{main}(s_t,a_t;\theta)$ を得ます。
$$Q_{main}(s_t,a_t;\theta) = \max_aQ_{main}(s_t,a;\theta)$$
3-2. メインネットワークに次状態 $s_{t+1}$ を入力して、目標値の計算に使う $a_{t+1}$ を求めます。
$$a_{t+1} = \argmax_aQ_{main}(s_{t+1},a;\theta)$$
3-3. ターゲットネットワークに $s_{t+1}$ を入力し、出力 $Q_{target}(s_{t+1},a_1;\theta^-),...,Q_{target}(s_{t+1},a_n;\theta^-)$ を得て、step 2 で得た報酬 $R_{t+1}$ と上記で求めた $a_{t+1}$ を使って、目標値を計算します。
$$target = R_{t+1} + \gamma Q_{target}(s_{t+1},a_{t+1};\theta^-)$$
3-4. 目標値を教師データ(ラベル)として誤差を計算し、メインネットワークを学習させます。
上記の通り、主に異なるのは 3-2 の部分で、メインネットワークでの計算が増えます。図を比較してみると理解しやすいと思います。
まとめ
- DQN の流れに沿って、Fixed Target Q-Network について図を使い説明した。
- DQN と比較しつつ Double DQN について説明した。
以上です。ありがとうございました。