LoginSignup
9
6

【強化学習】Dreamerを解説・実装

Last updated at Posted at 2023-06-18

この記事は自作している強化学習フレームワークの解説記事です。

前:PlaNet
次:DreamerV2

星(Planet)の次は夢(Dream)ですね。

PlaNetの問題点とDreamer

PlaNetではダイナミクスモデルで実環境の振る舞いを学習できましたが、プランニングに時間がかかりすぎました。
そこでプランニング(アクションの探索)とアクションの選択を切り離すことでプランニングにかかる時間を回避した手法がDreamerです。

Dreamer

ダイナミクスモデルはPlaNetと同じですが、プランニング用に新しくAction/Valueネットワーク(Actor/Criticモデル1)を追加しています。
Dreamerではダイナミクスモデルが誤差逆転伝搬可能なのを利用し、ダイナミクスモデル内にAction/Valueネットワークを組み込んで直接方策を学習します。2

参考
・論文:https://arxiv.org/abs/1912.01603
・Official Code: https://github.com/danijar/dreamer
Introducing Dreamer: Scalable Reinforcement Learning Using World Models | Google AI Blog
Dreamer: A State-of-the-art Model-Based Reinforcement Learning Agent | Published in Towards AI
Dreamer:長期視点で考える事が出来る強化学習(1/3) | WebBigdata
Google Open Sources Dreamer : 画像から長期視点のタスクを学習できる強化学習エージェント
DayDreamer: Dreamerがついに実ロボットに! | AI-SCHOLAR
DayDreamer論文

1. ダイナミクスモデル

概要は以下で、基本はPlaNetと同じです。

ss2.png

詳細は以下です。(Official Code から図示しました)

draw-ページ13-1.drawio.png

deter(決定的遷移)とstoch(確率的遷移)とアクションを元に次のdeterとstochを予測しています。
stochは実際に観測した状態を元に出力する部分(obs_step→post stoch)と過去の履歴のみで予測する部分(img_step→prior stoch)に分かれているようです。
prior stoch は学習で post stoch に近づけ、アクションの予測では prior stoch を使います。

PlaNetから変わった点として報酬も明示的に含まれました。
また、PlaNetで議論があった Overshooting は加味されていません。
(論文では特に触れていなく、Overshootingなしで学習してもちゃんと学習できたとの記載がありました)
(多分報酬もlossに含まれたので、stepをまたいだ情報が伝播されたのかな?)

ダイナミクスモデルの学習

学習はPlaNetと同様、次の変分下限(Evidence Lower Bound; ELBO)を最大化することです。
(PlaNetと比べ報酬項が追加されています)

$$ E\Big[ \sum_t \log q(o_t|s_t) + \log q(r_t|s_t) - \beta D_{kl}[p(s_t|s_{t-1}, a_{t-1}, o_t)||q(s_t|s_{t-1},a_{t-1})] \Big] $$

RSSMは逐次VAEと見ることができるので、VAEと同様にELBOを最大化することで学習します。
(ELBOは理解が浅いので理論的な話は止めておきます)
(参考:Variational Autoencoder徹底解説 | Qiita

式ですが、第1項と第3項は以前からある再構築項(reconstruction loss)と正則化の役割を果たすKL距離(KL loss)です。
第2項は報酬を予測する項となります。

2. Action/Valueネットワーク

概要は以下です。

ss3.png

Action/Valueネットワークは方策と価値の両方を学習するActor-Criticネットワークとなります。
潜在空間 $s$ を元に、Actionネットワークがアクションを予測し、Valueネットワークが状態価値を予測します。

  • 数式上はActionとValueは別パラメータなのでネットワークを別々に作っていますが、A3C/A2CやAlphaZeroみたいに1つのネットワークにしてもいい気がします。

  • Actionは方策を学習するため、離散値/連続値どちらの環境にも応用できます。

詳細は以下です。

draw-ページ13-2.drawio.png

1step目は画像データがあるので post stoch を元にアクション、報酬、状態価値を予測します。
2step以降は prior stoch を元に予測していきます。

学習に使う状態価値 $V(t+1)$ は予測した報酬 $r'$ と状態価値 $V'$ を元に計算し、この状態価値 $V(t+1)$ を元に ActionModel と ValueModel を学習させます。

Action/Valueネットワークの学習

学習に使う状態価値の計算は、想像上の軌跡(HorizonStep)を元にダイナミクスモデルで計算します。(実環境は使いません)
ただ、予測する軌跡の初期状態は実環境から収集したデータを使います。

状態価値の計算方法は論文内では2種類記載がありました。

報酬による価値の推定

報酬のみを使う最もシンプルな計算方法です。

ss44.png

各stepに対して、HorizonStepまでの報酬の平均値が状態価値になる方法です。

指数加重平均による価値の推定

バイアスと分散のバランスを取るために、各HorizonStepで指数加重平均(Exponentially Weighted Average)を使用する方法です。

ss4.png

$V^{\lambda}$ は HorizonStep に対して計算された価値 $V^k_N$ を指数加重平均した結果です。
$V^k_N$ は各stepから HorizonStep までのstepに対して、割引報酬と割引予測状態価値の期待値を出しています。

  • 指数加重平均について

EWAは以下のようです。

$$
EWA(t) = \beta EWA(t-1) + (1-\beta) * NewSample
$$

$\beta$ を調整するだけで加重平均を効率的に計算する手法のようです。
$\beta$ を大きくすると過去の値を重要視し、小さくすると現在の値を重要視するようです。

参考
Exponentially Weighted Averages

Action/Valueネットワークの更新

上記で計算した価値 $V$ に対して、Value ネットワークはこの値自体をMSEで学習します。

ss55.png

Actionネットワークはそのまま $V$ が最大になるアクションを学習します。(方策勾配法は使いません3

ss56.png

また、理由は書かれていませんが、Actionネットワークの学習ではエントロピーボーナス(SAC等で取り入れられている方策に制約をかけて探索を促進する正則化項)とターゲットネットワーク(DQN等で取り入れられているニューラルネットを用いた価値関数の学習を安定化させるテクニック)は特に用いなかったとの事でした。
(エントロピーボーナスはDayDreamerでは取り入れられていたりします)

学習のサイクル

論文内の疑似コードを書き直すと以下です。

memory = []
while True:
    # --- 学習
    for epoch in range(epochs):
        batchs = memory.sample(batch_size)

        # --- ダイナミクスモデルの学習
        # 計算の過程で各バッチと各シーケンス毎に stoch と deter が出てくる
        # なので、stoch/deterは batch_size * batch_length のサイズ
        stochs, deters = dynamics_model.train(batchs)

        # --- Action/Valueネットワークの学習
        # horizon step 進ませる(horizon stepはハイパーパラメータ)
        V = []
        for step in horizon:
            policy = action_model([stochs, deters])

            # 予測に dynamics_model を使用(ただし学習させない(勾配を流さない))
            stochs, deters = dynamics_model.img_step(stochs, deters, policy)

            reward = dynamics_model.predict_reward([stochs, deters])
            value = value_model([stochs, deters])

            V.append(学習用の価値Vを計算する)

        action_model.update(Vが最大になるように)
        value_model.update(Vを予測できるように)


    # --- 実環境の実行(1episode例)
    sequence_batch = []  # 1episodeのバッチ
    prev_action = 0  # 論文内では初期値の記載は特になし
    prev_deter = None
    prev_stoch = None

    state = env.reset()
    for step in range(env.max_episode_steps):
        # 1つ前のstepの情報から次の状態を予測(ダイナミクスモデル)
        embed = dynamics_model.encode(state)
        prior_stoch, deter = dynamics_model.img_step(prev_stoch, prev_deter, prev_action)
        post_stoch = dynamics_model.obs_step(embed, deter)
        
        # アクションモデルからアクションを予想
        action = action_model.policy([post_stoch, deter])
        prev_stoch = post_stoch
        prev_deter = deter

        # env step
        action += noise  # ノイズを追加して探索を促す
        reward, state = env.step(action)
        prev_action = action

        sequence_batch.append([state, action, reward])

    memory.append(sequence_batch)

問題は終了状態の扱いですが、論文内ではシーケンスサイズ=エピソードサイズのようで可変の場合の記載がありません。

Official Code ではエピソードをまたいで指定stepになるように収集していました。
ただこれだとエピソード途中から始まる状態を学習できないような…(R2D2あたりを参考)

フレームワークでは1エピソードを単位とし、足りない場合はエピソード終了状態を付け足すようにしました。
(なのでシーケンスがエピソード未満だとエピソード全体を学習できません)

コード

実装コードはgithubを見てください。
フレームワーク上はDreamerV1/V2/V3を統合してV3だけにしています。

学習

学習コードはgithubを見てください。

遅くなった理由の一つに学習に時間がかかることがありました…
WorldModelからですが、どうしても画像から環境自体を学習するので時間がかかってしまいます…

_dreamer.gif

画像は、図の左上がオリジナルの環境で original とあるのがWorkerが受け取る状態です。(64×64にリサイズされた後を受け取っています)
decode は original 画像を VAE を通して復元した結果です。
action の下にある画像は RSSM を通して予測された次の状態 z を復元したものです。
復元結果は一番上がmean(平均)を用いた画像で、下2つはランダムに出力した画像です。

おわりに

実装に手間取りました。。。

公式ページですが、軒並み DreamerV2 を進めてきますね。
DreamerV2 もそのうち記事にしたいですね。

  1. 論文内ではAction/Valueという用語が使われていますが、機能としてはActor/Criticと同じです

  2. これはモデルフリー強化学習にはできない方法です

  3. 方策勾配法を使わず直接状態価値を最大化できるのは、ダイナミクスモデルが微分可能で複数ステップに渡る勾配を計算できるからとなります

9
6
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
9
6