2
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

【強化学習】拡散モデルで世界モデルを作ったDIAMONDを解説・実装

Posted at

この記事は自作している強化学習フレームワークの解説記事です。

・フレームワークの記事
https://qiita.com/pocokhc/items/a2f1ba993c79fdbd4b4d

・GitHub
https://github.com/pocokhc/simple_distributed_rl

はじめに

WorldModelシリーズの最新作です。
夢の次はダイアモンドです。

前:Dreamer3

DIAMOND (DIffusion As a Model Of eNvironment Dreams)

毎回名前で遊んでいますね。

前々回のDreamerV2では内部表現を連続空間(ガウス分布)から離散空間(カテゴリカル分布)に変更することで、性能の大幅な向上が実現されました。
この改善の要因として、未来の状況を予測する時に生じる「予測誤差の累積」を緩和できたことが挙げられます。

連続空間では、初期の予測誤差がそのまま次のステップへと伝播していくため、長いステップ予測すると精度の劣化が顕著になります。
一方で離散空間では、内部表現が有限なカテゴリに丸められるので、各ステップでの誤差が一定範囲に収まり、誤差の影響が抑えられると考えられます。
この性質により、安定した長期予測が可能になり、結果として高い予測性能が得られました。

しかし、内部表現の離散化は情報の損失というトレードオフがあります。
たとえば、遠くに小さく映る歩行者のような小さい情報は離散化によって捉えにくくなり、こうした情報の欠落が予測精度の低下につながる可能性があります。

この課題に対して近年の画像生成タスクで主流となっている拡散モデル(diffusion models)の適用を考えます。
拡散モデルは以下の特性があります。

  • 画像をモデル化可能(画像が元々離散表現)
  • 条件付けが簡単で、モード崩壊を起こすことなく多様な分布を表現できる

DIAMONDはこの拡散モデルの特性に注目し、WorldModelを別の角度から離散表現したアプローチとなり、従来の離散化がもつ「情報の欠落」という弱点を拡散モデル生成能力によって補った手法となります。

参考

拡散世界モデル(Diffusion world model)

拡散モデルを知っていた方が理解しやすいと思います。
拡散モデルについては以前書いた記事をどうぞ。

[拡散モデル入門] ゼロから理解する拡散モデルの最新理論(図解付き)

DIAMONDでは強化学習の環境として部分観測マルコフ決定過程(POMDP)を仮定しています。(POMDPについては過去の記事を参照)
POMDPではエージェントが環境の完全な状態を直接観測できないため、過去の観測と行動の履歴をもとに次の状態を予測します。

具体的には以下の条件付き確率分布になります。

$$p(x_{t+1}|x_{\leq t}, a_{\leq t})$$

ここで $x_{\leq t}$ および $a_{\leq t}$ はそれぞれ時刻 $t$ までの観測と行動の履歴を表し、この条件の元で次の状態 $x_{t+1}$ を予測します。

イメージ図は以下です。(論文より)

ss1.png

一番上の行が強化学習におけるステップで、状態 $x^0$、ポリシー $\pi_{\phi}$、アクション $a$ となり、横軸が各ステップ $t$ を表しています。
縦が拡散モデルの予測部分で、完全なノイズ $x^{\tau}$ から拡散世界モデル $D_{\theta}$ でノイズを除去し、状態 $x^0$ を生成します。
(生成する際に条件 Conditioning として、過去の履歴を使用)

拡散世界モデル $D_{\theta}$ は拡散モデルと同じ方法で学習され、損失関数は以下です。

ss2.png

拡散過程のステップが $\tau$ となり、次の状態の画像が $x^{\tau}_{t+1}$ となります。

内容としては以上で以降は細かいポイントです。

報酬と終了モデル

報酬と終了はスカラー予測問題(scalar prediction problems)として個別にモデル化します。
報酬は {-1,0,1} のいずれかにクリップされた形式で扱われます。
報酬と終了のヘッダ部分は共有され、POMDPに対応するためにCNN+LSTMレイヤーで構成されます。
学習は報酬と終了共にクロスエントロピー損失で学習されます。

疑似コード例は以下。
(ResBlocksやDownsample等は拡散モデルの記事やGitHubの実装を参照)

class RewardEndModel(keras.Model):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        # --- action embedding
        self.act_emb = kl.Embedding(action_num, 128)

        # --- cnn encoder
        self.conv_in = Conv2D3x3(32)
        self.down_block1 = ResBlocks([32, 32], use_attention=False)
        self.downsample1 = Downsample()
        ×N
        self.down_block4 = ResBlocks([32, 32], use_attention=False)
        self.downsample4 = Downsample()
        self.down_block_last = ResBlocks([32, 32], use_attention=True)

        # --- lstm
        self.lstm = kl.LSTM(512, return_sequences=True, return_state=True)

        # --- reward, done
        self.mid_layer = kl.Dense(512, activation="silu")
        self.reward_layer = kl.Dense(3, use_bias=False)  # -1,0,1のone-hot
        self.done_layer = kl.Dense(2, use_bias=False)  # False,Trueのone-hot

        self.reward_loss = keras.losses.CategoricalCrossentropy(from_logits=True)
        self.done_loss = keras.losses.CategoricalCrossentropy(from_logits=True)


    def call(self, obs, act, next_obs, hidden_state=None):
        condition = self.act_emb(act)

        # obs/next_obsの入力は(batch, timestep, H , W, CH)
        # obsとnext_obsを結合
        x = tf.concat([obs, next_obs], axis=-1)
        # batchとtimestepを結合
        x = tf.reshape(x, (b * t, h, w, c * 2))

        # --- cnn encoder
        x = self.conv_in(x)
        x = self.down_block1(x, condition)
        x = self.downsample1(x)
        ・・・
        x = self.down_block4(x, condition)
        x = self.downsample4(x)
        x = self.down_block_last(x)

        # batchとtimestepを分解して、h,w,cをflatにする
        # (b*t, h, w, c) -> (b, t, -1)
        x = tf.reshape(x, (b, t, -1))

        # LSTM
        x, hx, cx = self.lstm(x, initial_state=hidden_state)

        # 出力層
        x = self.mid_layer(x)
        r = self.reward_layer(x)
        d = self.done_layer(x)
        return r, d, (hx, cx)

    def compute_train_loss(self, obs, act, next_obs, reward, done, hidden_state):
        r, d, hidden_state = self(obs, act, next_obs, hidden_state)

        # rewardとdoneの損失はクロスエントロピーで学習
        loss_r = self.reward_loss(reward, r)
        loss_d = self.done_loss(done, d)
        return loss_r, loss_d, hidden_state

ActorCritic

ポリシーと状態価値も、共有のヘッダを持つCNN+LSTMでモデル化されます。
疑似コード例は以下。

class ActorCritic(keras.Model):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        # --- cnn encoder
        self.encoder_layers = [
            Conv2D3x3(32),
            SmallResBlock(32),
            kl.MaxPool2D(2),
            SmallResBlock(32),
            kl.MaxPool2D(2),
            SmallResBlock(64),
            kl.MaxPool2D(2),
            SmallResBlock(64),
            kl.MaxPool2D(2),
        ]

        # lstm
        self.lstm = kl.LSTM(512, return_sequences=True, return_state=True)

        # output layer
        self.critic_linear = kl.Dense(1)
        self.actor_linear = kl.Dense(action_num)

    def call(self, obs, hidden_state=None):
        # cnn encoder
        for h in self.encoder_layers:
            obs = h(obs)

        # (b,1,h,w,c) -> (b,1,h*w*c)
        obs = tf.reshape(obs, (obs.shape[0], 1, -1))
        # lstm
        obs, hx, cx = self.lstm(obs, initial_state=hidden_state)
        # timestep軸を削除
        obs = tf.squeeze(obs, axis=1)

        act_logits = self.actor_linear(obs)
        v = self.critic_linear(obs)
        return act_logits, v, (hx, cx)

学習サイクル

各モデルは基本的に独立して学習できます。
論文に書かれている疑似コードをpython風で書くと以下です。(Appending.G Algorithm1の内容)

def training_loop():
    for epoch in range(epochs):
        collect_experience(steps_collect) # 軌跡の収集
        for _ in steps_diffusion_model:
            update_diffusion_model()      # 拡散世界モデルの学習
        for _ in steps_reward_end_model:
            update_reward_end_model()     # 報酬・終了モデルの学習
        for _ in steps_actor_critic:
            update_actor_critic()         # ActorCriticの学習

def collect_experience(n):
    x = env.reset()
    for t in range(n):
        act = ActorCriticからアクションをサンプル
        x = env.step(act)
        軌跡をメモリに追加
        if env.done:
            x = env.reset()

def update_diffusion_model():
    batch = メモリから軌跡をランダムに取得denoiserの入力長+1
    obs = batch["state"][-1]         # 現在の状態
    recent_obs = batch["state"][:-1] # 状態の履歴
    recent_act = batch["action"]     # アクションの履歴

    # --- EDM update
    sigma = np.exp(np.random.normal() * noise_std + noise_mean)
    noisy_obs = obs + np.random.normal() * sigma
    with tf.GradientTape() as tape:
        denoised_obs = denoise(noisy_obs, sigma, recent_obs, recent_act)
        loss = MSE(obs, denoised_obs)
    diffutionモデルの更新


def update_reward_end_model():
    batch = メモリから軌跡をランダムに取得burnin+horizon+1の長さ
    seq_obs =  batch["state"]
    seq_act =  batch["action"]
    seq_reward =  batch["reward"]
    seq_done =  batch["done"]

    # burnin
    hc = LSTMの初期隠れ状態
    _, _, hc = reward_end_model([seq_obs[:burnin], seq_ct[:burnin]], hc)

    # burnin以降の内容で学習
    with tf.GradientTape() as tape:
        r, d = reward_end_model([seq_obs[burnin:], seq_act[burnin:]], hc)
        r_loss = CategoricalCrossentropy(seq_reward[burnin:], r)
        d_loss = CategoricalCrossentropy(seq_done[burnin:], d)
    reward_end_modelの更新


def update_actor_critic():
    batch = メモリから軌跡をランダムに取得denoiserの入力長+1
    seq_obs =  batch["state"][:-1]
    seq_next_obs =  batch["state"][1:]
    seq_act =  batch["action"]
    seq_reward =  batch["reward"]
    seq_done =  batch["done"]

    # --- burnin
    hc_rewend = LSTMの初期隠れ状態
    hc_act = LSTMの初期隠れ状態
    _, _, hc_rewend = reward_end_model([seq_obs, seq_act, seq_next_obs], hc_rewend)
    _, _, hc_act = actor_critic_model(seq_obs, hc_act)

    # --- 未来の状態を予測し学習
    recent_obs = seq_obs
    recent_act = seq_act
    with tf.GradientTape() as tape:
        v_list = []
        logpi_list = []
        entropy_list = []
        r_list = []
        d_list = []
        for i in range(horizon):
            # actionとvを予想
            obs = recent_obs[i]
            act_dist, v, hc_act = actor_critic_model(obs, hc_act)
            act = act_dist.rsample()       # 伝播できるアクションをサンプル
            logpi = act_dist.log_prob(act) # policy更新用
            entropy = act_dist.entropy()   # entropy loss用

            # 履歴を更新
            del recent_act[0]
            recent_act.append(act)

            # 世界拡散モデルから次の状態を予測し、履歴に追加
            next_obs = sampler(recent_obs, recent_act)
            del recent_obs[0]
            recent_obs.append(next_obs)

            # rewardとdoneを予測
            r, d, hc_rewend = reward_end_model([obs, act, next_obs], hc=hc_rewend)

            v_list.append(v)
            logpi_list.append(logpi)
            entropy_list.append(entropy)
            r_list.append(r)
            d_list.append(d)

        # 逆方向にたどってラムダリターンとlossを計算
        gamma = 0.985
        lambda_ = 0.95
        G_lambda = v_list[-1]
        for i in reversed(range(horizon)):
            v = v_list[i]
            logpi = logpi_list[i]
            entropy = entropy_list[i]
            r = reward_list[i]
            d = dones[i]

            G_lambda = r + gamma * ((1 - lambda_) * v + lambda_ * G_lambda)
            adv = G_lambda - v
            loss_a += logpi * tf.stop_gradient(adv)
            loss_v += MSE(tf.stop_gradient(G_lambda), v)

    actor_critic_modelの更新

Worker

経験収集時のエピソードの流れをフレームワークベースで書くと以下です。

class Worker(RLWorker):
    # エピソード最初に実行
    def on_reset(self, worker):
        self.actor_hc = ActorCriticのLSTMの隠れ状態の初期化

    # 毎ステップのアクションを返す
    def policy(self, worker) -> int:
        obs = worker.state

        # ActorCriticモデルからアクションを出す
        act_dist, v, self.actor_hc = actor_criticモデル(obs, hidden_state=self.actor_hc)
        action = act_distを元にアクションをランダムに選択

        return action

    # 1ステップ後の処理
    def on_step(self, worker):
        # 報酬は-1,0,1のどれか(onehot化)
        if worker.reward < 0:
            reward = [1, 0, 0]  # -1
        elif worker.reward > 0:
            reward = [0, 0, 1]  # 1
        else:
            reward = [0, 1, 0]  # 0

        # 終了したかどうか(onehot化)
        done =  [0, 1] if worker.terminated else [1, 0]

        # バッチを送信(時系列になるように送信)
        self.memory.add({
            "state": worker.state,
            "action": worker.action,
            "next_state": worker.next_state,
            "reward": reward,
            "done": done,
        })

学習結果

学習の所感

・拡散モデルの学習

拡散モデルの学習に癖があり、想像以上に苦戦しました…。
モデルとしては「次の状態」を学習しないといけないのですが、現在と同じ状態を再現するところまではすぐに学習でき、そこで学習が終わったと勘違いしてました。
しかし実際には、次の状態を再現できるだけの学習には至っておらず、これに気が付くまでかなり時間がかかりました…。
(生成された画像を見ると綺麗に生成されている)
これは単に学習時間が足りていなかっただけで、しっかり学習すれば次の状態も予測できるようになりました。

・スペック
拡散モデルでもそうでしたがUNetがかなり重いです。
更に、おそらくですが horizon ステップごとに UNet が複数展開されているようで、 UNet × horizon 分のメモリが消費されているように感じました。
このPCは12GBのGPUをつんでいますがすぐメモリ不足になったり…。(最終的にバッチサイズを8まで落としました…)

学習結果

_diamond.gif

左上がオリジナル画像、その下がアルゴリズム側へ入力される画像です。
縦に並んでいる画像が拡散モデルで生成している画像となり、一番下の画像が元ノイズ、そこから上に向けて生成途中の画像となり、一番上の画像が生成された画像です。
(生成ステップは3ステップ)
また画像の横軸ですが、左から←↓→↑の行動後の予測画像をそれぞれ出しています。

生成ステップはたった3ステップですが、かなり綺麗に画像を生成していますね。(これはEDMですがDDPMだと生成に数千ステップは必要)

最後に

DIAMONDですが、拡散モデルを試すだけで他の部分(reward-endモデルとActorCritic)は手を加えていない印象でした。
たぶん拡散世界モデルの性能を見たかったからで、他の部分を改善すれば性能はもっと上がりそうです。

ただ拡散モデル部分はすごいですね。
次の状態がかなり綺麗に再現できています。
拡散モデルの性能を考えると単に同じ環境を学習させるだけではなく、複数の環境を学習させたほうがいいかもしれませんね。
複数の環境を学習させて使いまわせるようにすれば、強化学習でもモデルを使いまわせるようになるかもしれません。

最後に投稿に関してですが、記事自体はかなり昔に完成していました。
ただ、実装と学習結果の部分に時間がかかりすぎて…。
今後は記事は解説だけにしてもいいかもしれません。

2
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?