強化学習
- 強化学習の基本については、機械学習の学習まとめを参照
深層Qネットワーク(Deep Q-Networks, DQN)
- 深層学習を用いた強化学習モデル
- 畳み込みニューラルネットワークを用いて行動価値関数を推定する
- 機械学習では行動価値関数を表形式で表現したが、DQNではDNNを用いて表現する
- 畳み込みニューラルネットワークを用いて行動価値関数を推定する
- 学習を安定化させるための工夫を行っている
- 体験再生(experience replay)
- 各タイムステップ$ t $におけるエージェントの経験$ e_t = (s_t, a_t, r_t, s_{t+1}) $をデータ集合$ D = \lbrace e_1, \cdots , e_N \rbrace $に蓄積する
- このデータ集合$ D $から取り出された経験を再生記憶(replay memory)という
- 学習時は蓄積されたサンプルの中から、経験をランダムに取り出し、損失の計算に用いる
- 体験再生には通常のオンラインQ学習と比べて幾つかの利点がある
- パラメーターの更新時に、同じ経験を何回も学習に使えるため、計算量の大きなエピソードの進行の回数を抑制することができ、データの効率がよい
- ランダムに取り出された経験を用いて損失を計算するため、入力系列の相関を断ち切ることができ、更新の分散を軽減できる
- 過去の様々な状態で行動分布が平均化されるため、直前に取得したデータが次の行動の決定の決定に及ぼす影響を軽減でき、パラメーターの振動や発散を避けることができる
- 各タイムステップ$ t $におけるエージェントの経験$ e_t = (s_t, a_t, r_t, s_{t+1}) $をデータ集合$ D = \lbrace e_1, \cdots , e_N \rbrace $に蓄積する
- 目標Qネットワークの固定(ターゲットネットワーク)
- 学習の目標値(教師信号)算出に用いるネットワークと、行動価値$ Q $の推定に用いるネットワークが同じ場合、行動価値関数を更新すると目標値(教師信号)も変化してしまい、学習が不安定になる
- これに対し、目標Qネットワークを固定(= 目標値の算出に用いるQネットワークのパラメーターを固定し、一定周期でこれを更新する)することによって、学習を安定させることが可能になった
- 報酬のクリッピング
- 報酬の値を$ \lbrace -1, 0, 1 \rbrace $の3値に制限することによって、報酬のスケールが与えられたタスクによって大きく異なるという問題を解消し、ゲームごとに学習率を調整することを不要とした
- 体験再生(experience replay)
- 損失関数は以下の通り
- $ L(θ) = \mathbb{E}_{s, a, r, s' \sim D} \lbrack (r + γ \underset{a'}{max} Q(s', a'; θ^-) - Q(s, a; θ))^2 \rbrack $
- ここで$ θ^- $が固定されたパラメーターを意味する
- $ L(θ) = \mathbb{E}_{s, a, r, s' \sim D} \lbrack (r + γ \underset{a'}{max} Q(s', a'; θ^-) - Q(s, a; θ))^2 \rbrack $
AlphaGo
- 以後のようなゲームに対し、盤面$ s $に対する最適な状態価値関数$ v^*(s) $を評価するのは一般的に難しい
- $ b $をゲームの幅(状態ごとの合法手の数)、$ d $をゲームの深さ(ゲームの長さ)として、状態価値を探索木で計算するためには$ b^d $のオーダーの計算が必要になるが、これが以後の場合$ 250^{150} $という計算になってしまう
- このため、深さと幅を減らすことが必要になる
- 深さを減らすためには、盤面$ s $の状態価値の評価をうまく近似する必要がある
- 幅を減らすためには、全公道からのサンプリングの仕方を工夫する必要がある
- AlphaGoでは、畳み込みニューラルネットワークでモデル化した状態価値関数及び方策関数を用いることで、探索の際の深さと幅を減らしている
- AlphaGo Lee: 2つのNNで構築された関数をもつ
- PolicyNet(方策関数)
- CNN
- 入力は19×19の交点×48チャンネル(自石の位置、敵石の位置など)
- 畳み込みとReLUの繰り返し
- 出力はSoftmaxで19×19の交点それぞれの着手予想確率を出力する
- CNN
- ValueNet(価値関数)
- CNN
- 構成はPolicyNetに似ているが、入力は49チャンネルとPolicyNetより1チャンネル多い
- 出力は原局面の勝率(Flattenで1次元化されたもの)をtanh関数によって-1〜1で表したものが出力される
- CNN
- 学習のステップ
-
教師あり学習によるRollout policyとSL (Supervised Learning) policy networkの学習
- Rollout policyはNNではなく線形の方策関数
- 探索中に高速に着手確率を出すために使用される
- SL policy networkの教師は人間同士の対局の棋譜データ
- Rollout policyはNNではなく線形の方策関数
-
強化学習によるRL (Reinforcement Learning) Policy networkの学習
-
強化学習によるValue networkの学習
- N手までSL Policy networkに従う
- N+1手はランダム(ただし合法手)
- 以降はRL Policy networkで終局まで
-
- 学習後の実践では、モンテカルロ木探索を用いて相手の手を読む
- 探索の結果、最も訪問回数が多い行動を次の差し手として採用する
- PolicyNet(方策関数)
- AlphaGo Zero
- いくつかAlphaGo Leeから進化
- 教師あり学習を行わず、強化学習のみで構成
- PolicyNetとValueNetを一つのネットワークに統合
- Residual Netを導入
- 一つのネットワークから2つの出力(Policy出力とValue出力)を得る
- Residual Network
- ネットワークにショートカット構造を追加して、勾配の爆発、消失を抑える効果を狙ったもの
- Residual Networkを使うことにより、 100層を超えるネットワークでの安定した学習が可能となった
- Residual Network を使うことにより層数の違うNetworkのアンサンブル効果が得られている
- AlphaGo Zeroでは、Residual Blockを39個つなげて使っている
- ネットワークにショートカット構造を追加して、勾配の爆発、消失を抑える効果を狙ったもの
- いくつかAlphaGo Leeから進化
A3C(Asynchronous Advantage Actor-Critic)
- 強化学習の学習法の一つ
- 特徴は複数のエージェントが同一の環境で非同期に学習すること
- Asynchronous: 複数のエージェントによる非同期な並列学習
- Advantage: 複数ステップ先を考慮して更新する手法
- Actor: 方策によって行動を選択
- Critic: 状態価値関数に応じて方策を修正
- Actor-Criticとは、行動を決めるActor(行動器)を直接改善しながら、方策を評価するCritic(評価器)を同時に学習させるアプローチ
- A3Cによる非同期学習の詳細
- 複数のエージェントが並列に自律的に、rollout(ゲームプレイ)を実行し、勾配計算を行う
- その勾配情報をもって、好き勝手なタイミングで共有ネットワーク(パラメーターサーバー)を更新する
- 各エージェントは定期的に自分のネットワーク(local network) の重みをglobal networkの重みと同期する
- 非同期であるため、方策モデルにRNNモデルやLSTMモデルを適用可能
- 並列分散エージェントで学習を行うA3Cのメリット
- 学習が高速化
- GPUなしでも、より短い訓練期間で学習することができる
- マルチコアCPUマシンで、A3CはDQNよりも短い演算時間で高い性能が得られる
- GPUなしでも、より短い訓練期間で学習することができる
- 学習を安定化
- 経験の自己相関が引き起こす学習の不安定化は、強化学習の長年の課題
- A3Cは方策オン手法(直接方策を評価する)であり、サンプルを集めるエージェントを並列化することで自己相関を低減することに成功
- 学習が高速化
- A3Cの課題
- Python言語の特性上、非同期並行処理を行うのが面倒
- パフォーマンスを最大化するためには、大規模なリソースを持つ環境が必要
- A3Cのロス関数
- 一般的なActor-Criticでは、方策ネットワークと価値ネットワークを別々に定義し、別々のロス関数(方策勾配ロス/価値ロス)でネットワークを更新する
- これに対し、A3Cはパラメーター共有型のActor Criticであり、1つの分岐型のネットワークが、方策と価値の両方を出力し、たった1つの「トータルロス関数」でネットワークを更新
- トータルロス関数は、アドバンテージ方策勾配、価値関数ロス、方策エントロピーの組み合わせ
- A2C
- A3Cの後に、同期処理を行うA2Cが発表された
- 同期処理なので、Pythonでも実装しやすい
- 性能がA3Cに劣らないことがわかったので、その後よく使われるようになった
- A3Cの後に、同期処理を行うA2Cが発表された