これからDeep Q-Network(DQN)を実装してみたい人向けです。
理論に深入りはせず、少々実装寄りの内容になっています。
また、切り口や表現、説明の仕方が、他の方々の解説とは一味違うかもしれません。
本稿が対象とする読者は、
・Q-Learningについては理解している
・ニューラルネットワークの基本は理解している
方、を想定しています。
Q-LearningからDQNへの発展
Q-Learningではできないこと、そしてそれをQ-Learningの流れを汲んだDQNがどう実現したのか、について。
Q-Learningではできないこと
<復習>TD(Temporal Difference)を用いたQ値の更新の仕方
Q-Learningは、現在のQ(s, a)値を保持する、状態sと行動aの2軸のQ(s, a)テーブルを基盤としています。
Agentを動かし、軌跡と即時報酬を得ながら、そのQ(s, a)テーブル内の各Q値を、「TD(Temporal Difference)」を加味しながら更新していきます。
具体的には以下です。
時刻tに状態s(t)=s1において行動a(t)=a2を取り、状態s(t+1)=s3に遷移して即時報酬r(t+1)を得た、とします。
Q(s, a)テーブル内のQ(s1, a2)値を以下のように更新します。
上式の「maxQ~」のところがQ-Learningのキモで、同一のある状態sについて、定義されている行動a全ての中で最大となるQ(s, a)値を、TD(Temporal Difference)の算出に使用するのでした。
そして、Q値の更新は、上式のごとく、現在のQ(s, a)値に、学習率αで程よく調整したそのTDを加算する、という方法を取ります。
このように、Q-Learningは、現在のQ(s, a)値を保持する、状態sと行動aの2軸のQ(s, a)テーブルを基盤としています。
Q(s, a)テーブル方式の限界
しかし、そもそも状態sと行動aの2軸のQ(s, a)テーブルというものを作ることができるのは、sとaの組合せがある程度の数で収まる場合だけです。
例えば状態sが連続値の場合など、sとaの組み合わせがある程度の数では収まらない場合、Q(s, a)テーブルを作るのは不可能です。
(ちなみに、行動aが連続値の場合、そもそもDQNを含めたQ-Learning系のメソッドを使用することは考えにくいです。定義されている全aに対するQ(s, a)のmaxを取る必要があり、行動aが離散値でないとやりにくいからです。)
Q(s, a)テーブルではなく、Q(s, a)の近似値を出力する関数にする
Q(s, a)の近似値を出力する関数を作り、その関数のパラメーターを最適化することにより、実際にQ(s, a)の近似値を出力できるようにします。
このようにすれば、Q(s, a)のテーブルを保持する必要はありません。
そして、**関数の実体として、表現力が豊かなニューラルネットワーク(以降「NN」)を使用したのが、Deep Q-Network、「DQN」**です。
※冗長なので、以降は「Q(s, a)の近似値」の「近似値」を省略します。
以降、説明の材料として、OpenAIGymの「CartPole」を使用します。
茶色のバーが倒れないように、黒の台車(Agent)を「左」「右」に動かします。
この「左」「右」が、定義されている行動aです。
Deep Q-Network(DQN)
ここからが本論で、DQNそのものについて語りたいと思います。
一般的構成
入力と出力
「関数」なので、入力と出力があります。
今までの話の流れを正しくとらえていると、以下のように発想するのが自然です。
「Q(s, a)テーブルの代わりになるのだから、状態sと行動aのペアを入力とし、そのQ(s, a)値を1個出力する」
しかし、実際にはこのような構成を取りません。
一般的には、以下の構成を取ります。
入力:状態s
出力:(離散値として)定義されている全ての行動aの1つ1つに対する、Q(s, a)値
例えばCartPoleなら、出力層のニューロンは、定義されている全ての行動a「左」「右」の2個であり、1回の順伝播での出力は、Q(s, a=左)とQ(s, a=右)の2個となる、ということです。
なぜこのような出力構成になっているか、は、下の図の通りです。
※DQNのTD(Temporal Difference)については、後で詳しく述べます。とりあえず今は、DQNで使用されるTDの式の中に、$ max_{a'}[Q(s(t+1), a')] $というのがあって、「全aについてのQ(s, a)の最大値を取る必要があるんだな」、くらいの認識でいいです。
要は、**行動aの全てについてのQ(s, a)のmaxを取る必要があるんだったら、その各aについてのQ(s, a)を一度に出してしまう、**ということです。
下の図では、Q(s, a=左)=10 < Q(s, a=右)=100なので、Q(s, a=右)=100がmaxであることが、1度の順伝播でわかります。
DQNの訓練
DQNはパラメーターにより構成されるNNです。
このパラメーターをどう最適化してQ(s, a)の近似値を出力できるようにするか、がテーマです。
訓練の方向性~「TD(Temporal Difference)の最小化」
DQNはQ(s, a)を出力する関数であり、元祖Q-LearningのようにQ(s, a)テーブルを保持していないので「現在のQ(s, a)値」というものは存在しません。
従って、元祖Q-Learningの訓練の方針である
「現在のQ(s, a)値に、学習率αで程よく調整したTDを加える」という方法は取りません(取れません)。
DQNはNNの一種であり、「損失(誤差)を定義し、誤差逆伝播によりパラメーターを最適化していく」ということは他のNNと変わりありません。
DQNでは、**「TDをNNの損失(Loss)に見立てて、それを回帰タスクを解くことで最小化していく(最小化するパラメーターを求める)→結果としてQ(s, a)を近似できるようになる」**ということをします。
※実際にはTDそのものを損失(Loss)とするのではなく、TDをもとに算出した2乗和損失やHuber損失といったものを損失(Loss)として使用します。
回帰タスク
DQNで使用されるTD(Temporal Difference)と損失(Loss)
DQNで使用されるTDは、以下の数式で表されるものです。
r(t+1)+γmax_{a'}[Q(s(t+1), a')]-Q(s(t), a(t))
回帰タスクにおける損失には、教師信号の定義がさらに必要です。
教師信号、TD、損失をまとめたのが、以下の図です。
DQN回帰タスクの損失(Loss)は、TDをもとに算出した2乗和損失やHuber損失です。
例えば、最もよく用いられる2乗和損失は、$ TD^{2}/2 $です。
※実装では、バッチで平均を取ることを忘れないでください。
この回帰タスクの大きな問題
この回帰タスクには、大きな問題が1つあります。
それは、
**教師信号のもとになるQ値を、訓練されるNN自らが生成**する必要がある
ということです。
教師信号
r(t+1)+γmax_{a'}[Q(s(t+1), a')]
の中に「Q(s(t+1), a')」というのがあります。
これはQ値なので、NNの出力値です。
NNの訓練に使用する教師信号(正解として出力値を近づける目標値)がそのNN自身の出力値によって成り立っている、という、不合理なことが起こっているわけです。
一般的な教師あり学習では、不変の正解である「教師データ」が渡されます。
が、DQNには(というより強化学習全般的に)そういうものは無いので、「教師信号」を自ら作り出すしかない、という背景があります。
この問題から、以下の2つの問題が派生します。
問題1:教師信号の値の妥当性
1イテレーションだけで見たときに、「その教師信号を回帰の目的値としていいか」という観点です。
教師信号を構成するQ値単体では、訓練途上の未熟なNNが出力したものなので、値も不正確なものでしょう。
が、同様に教師信号を構成する即時報酬「r(t+1)」を見てわかる通り、この教師信号全体としては、「s(t)で実際にa(t)してs(t+1)に進み実際に即時報酬r(t+1)を得た後の、少し正確になった見積」とみなすことができます。
なので、この”ささやかな”回帰を多数回地道に繰り返し、尺取虫のように「正しいQ値」に近づけていける、と考えることができます。
また、この教師信号の値の”質”の問題については、例えばmulti-step learningなど改善の手法が編み出されており、それらを併用することにより、改善することができます。
問題2:教師信号の値が激しく変動(Moving Target)
NNは訓練途上にあるので、頻繁にパラメーターが更新されます。そのNNが自ら教師信号のもとになるQ値を生成するので、教師信号が揺れる、というのは理解しやすいと思います。
**目的値たる教師信号が激しく変動する(Moving Target)ので、当然訓練は安定しません。結果、NNは、よい近似値のQ値を出力するところまでちゃんと訓練されない、**ということになります。
NNのパラメーター更新頻度そのものが問題だが、パラメーター更新頻度を下げてしまうと、NNの訓練自体が進まなくなってしまう
→では訓練対象のNNのクローンを作ってそれに教師信号のもとになるQ値を出力させ、たまにしかパラメーター更新しないようにすればいいでしょ
という考え方で、「Fixed Target Q-Network(Fixed Target QN)」とか「Target QN固定」と呼ばれる手法が考え出され、定着するに至りました。
Fixed Target Q-Network(Fixed Target QN)
・**教師信号のもとになるQ値を出力するだけが役割のNN「Target QN」**を用意
・「Target QN」は、もともとの訓練対象NN「Main QN」のクローン
レイヤー構成、パラメーター初期値が同じ。
・Main QNのパラメーター更新頻度より低い頻度で、Main QNのパラメーターをTarget QNにコピー
例)Main QNのパラメーター更新が1ステップ毎なら、パラメーターコピーは1エピソード毎
Target QNへのパラメーターコピーの頻度は、Main QNのパラメーター更新頻度より低い頻度(これが「Fixed」)でなければ意味がない。
Main QNのパラメーターを「$ θ $」、Target QNのパラメーターを「$θ^-$」と表記すると、TDの式は以下のように変わります。
DQNで使用されるTD(Temporal Difference)~Fixed Target QN適用
r(t+1)+γmax_{a'}[Q(s(t+1), a';θ^{-})]-Q(s(t), a(t);θ)
上式中、2個あるQ値のうち1個目(左側)のQ値は教師信号を構成するQ値です。つまりTarget QNの出力なので、「 $;θ^-$ 」となっています。
DQN訓練の全体図~Fixed Target QN適用
※字が小さくなってしまったので、クリックして拡大して見た方が良いです。
全体図の補足
・「Experience Replayにより、あらかじめ収集されている経験データ」
DQN訓練前、及び訓練中に渡って、DQNの推論結果に基づきAgent(CartPoleなら黒の台車)を動かし、大量の軌跡と即時報酬の履歴を収集・保管し(これが「経験データ」)、DQN訓練データとして使用する、というものです。
Experience Replayは、DQNとセットで使用される手法、と認識しています。
本稿中のDQNの「s(t)」「a(t)」「s(t+1)」「r(t+1)」というのは、実はこの大量の経験データの一部です。
Experience Replayについては、後で詳しく述べます。
・「a(t)ではないニューロンにあてる教師信号は出力値そのものとする」
NNの教師あり学習の誤差逆伝播の基本なのですが、出力層の全ニューロンに対して、教師信号が必要です。
しかし、TDの式
r(t+1)+γmax_{a'}[Q(s(t+1), a';θ^{-})]-Q(s(t), a(t);θ)
の最後の項、いわゆる算数の「引く数」は、$ Q(s(t), a(t);θ) $ です。
これは、出力層の全ニューロン(定義されている行動aの全て、CartPoleなら「左」「右」)のうち、**行動aがa(t)である1個の出力ニューロンの出力値を回帰で是正する式でしかない、**ということです。
例えば、経験データ中のa(t)=左なら、「左」の出力ニューロンの出力値Q(s(t), a=a(t)=左)のみの是正であり、「右」(≠a(t))の出力ニューロンの出力値Q(s(t), a=右)に対するものではない、ということです。
しかし前述の通り、NN教師あり学習の誤差逆伝播では、**この「右」(≠a(t))の出力ニューロンに充てる教師信号も必要なのです。ではその教師信号は何?という話で、「それは出力値Q(s(t), a=右)と全く同じとする。「右」の出力ニューロンの出力値として何が妥当か不明なので、”出力値は現状維持”しかない」**というのがここの趣旨です。
<補足>DQNの訓練で追加適用される手法
DQNの訓練をする際、必ずorよく適用される手法をいくつか挙げます。
Experience Replay
例えば以下のようにしたとします(1エピソード毎にNNを更新するとする)。
-------
「Agentが状態s(t)でa(t)し、その結果s(t+1)に遷移し即時報酬r(t+1)を得た」
↓
「引き続きAgentが状態s(t+1)でa(t+1)し、その結果s(t+2)に遷移し即時報酬r(t+2)を得た」
:
エピソード終了
このエピソードの全軌跡(一連の状態sと行動a)と即時報酬の履歴を用いて、NNを更新
-------
これら入力データ「一連の状態s」は典型的な時系列データであり、互いに独立ではないです。
相関性の強い入力系列を用いて訓練すると、訓練が安定しないことが知られています。
この問題に対処するための「Experience Replay」という手法があり、DQNでは通常この手法を使用します。
Experience Replayの骨子は以下の通りです。
・訓練前に(だけでなく訓練中でも)Agentを走らせて大量の経験データを収集し蓄積
「Agentを走らせる」とは、訓練対象DQNにQ値を出力させ、そのQ値に基づく推論actionをAgentに取らせる、ということです。
・その蓄積された大量の経験データ群からランダムに経験データを抽出し、訓練に使用
「相関性の強い入力系列」でなくするために、ランダム抽出します。
「経験バッファ」(あるいは「経験メモリ」)と呼ばれる”容器”に、Agentを動かして、経験データという”水”をひたすら注いでいきます。
経験バッファ(容器)が満タンになると、古い経験データ(水)から捨てられていく、という仕組みになっています。
経験バッファ(容器)の容量はけっこう大きめです(1万~数十万件の経験データ)。
時刻tにおいてAgentが1ステップ動く、ということを軌跡と即時報酬で表現すると以下のようになります。
「状態s(t)で行動a(t)を取り、状態s(t+1)に遷移して即時報酬r(t+1)を得た」
この1ステップの軌跡と即時報酬を経験データ1件として経験バッファに格納します。
つまり、**経験データ1件[ s(t)、a(t)、s(t+1)、r(t+1) ]が1万~数十万件、経験バッファに格納されており、訓練データとしてこの中からランダムサンプリングされる、**ということになります。
Multi-step learning
教師信号の”質”を上げる手法です。
本稿の今までのTDは、「1」ステップだけ進め、その遷移先s(t+1)で得た即時報酬r(t+1)のみを使用していました。
r(t+1)+γmax_{a'}[Q(s(t+1), a';θ^{-})]-Q(s(t), a(t);θ)
が、さらに「数」ステップ進め、その間に獲得した即時報酬と遷移先の情報を使用した方が、「実際起きたこと」の情報が多いので、教師信号として正確になるのでは、という考え方によります。
例えば、2ステップ進め、その間の即時報酬を使用する場合、TDは、
r(t+1)+γr(t+2)+γ^{2}max_{a'}[Q(s(t+2), a';θ^{-})]-Q(s(t), a(t);θ)
3ステップの場合、TDは、
r(t+1)+γr(t+2)+γ^{2}r(t+3)+γ^{3}max_{a'}[Q(s(t+3), a';θ^{-})]-Q(s(t), a(t);θ)
従って「nステップ先まで進める」と一般化すると、TDは、
\sum_{k=1}^{n}(γ^{k-1} r(t+k))+γ^{n}max_{a'}[Q(s(t+n), a';θ^{-})]-Q(s(t), a(t);θ)
となります。
n=1、つまり1ステップだけ進めて教師信号を算出する場合の式が、本稿で今まで見てきたTDの式だということがわかると思います。
逆に、突き詰めて最終ステップまで進めると、モンテカルロ法になります。
※上式いずれも、Fixed Target QNを併用している前提で書いています。
DQN訓練の全体図~Fixed Target QNとMulti-step learning適用
※字が小さくなってしまったので、クリックして拡大して見た方が良いです。
Multi-step learningは、原理が単純で理解しやすく、実装も難しくなく効果もあります。
DQN実装時は、Multi-step learningを適用することをオススメします。
ちなみに、n=2~5くらいが最も良い結果が得られる、とのことです。
Multi-step learningの趣旨から、「2~5ステップなんてケチなこと言わないで、最後まで行ってしまえばいいじゃないか」と思ってしまうかもしれません。
しかしそれだとモンテカルロ法と同じで、「バイアス」の問題が発生してしまいます。
訓練途中の(未熟な)DQNが推論した結果の行動に基づきAgentが動き、その軌跡と道中の即時報酬が訓練データとして使用されます。
例えば100ステップまで進めたとすると、その軌跡と即時報酬の履歴は、訓練十分なDQNが推論した結果のそれらとかなり乖離があるでしょうから、教師信号としては不正確なはずです。
しかし2~5ステップなら、それほどの乖離は無いでしょう。
いささか乱暴な例えですが、
酔っ払いの100歩はシラフの人の100歩とは軌跡は随分違うだろう。しかし2~5歩くらいならまあそんなに違わないだろう。
ということだと思います。
報酬のクリッピング
即時報酬を、+1、0,-1の3通り「のみ」とします。
これにより、訓練は安定しスピードが上がります。
一般的には、(本来の即時報酬値が)正の場合は+1,負の場合は-1,0の場合はそのまま0
とします。
ここで誰しもが疑問に思うのは、
「行動の良し悪しの程度を報酬で表現できないのならば、”より良い行動”を学ばなくなるのでは?」
だと思います。
確かにそれはこの手法の短所だと思います。
が、それ以上に訓練の安定性が欲しい、というタスクもあろうかと思います。
そういう場合に使用するのが良いでしょう。