0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

【強化学習】方策学習をもっとシンプルに!シンプルに方策を学習できるオフポリシー型PPOができた

Posted at

はじめに

少し前に書いたターゲットネットワーク不要のDQNをPPOにも応用できないか考えていたらいい感じのものが出来たので紹介です。

PPOのアドバンテージの改良

PPO(TRPO)は以下の代理目的関数を最大化するのが目的でした。

L(\theta) = \mathbb{E} \left[ 
    \frac
        {\pi_{\theta}(a|s)}
        {\pi_{\text{old}}(a|s)}
    A_{\pi_{\text{old}}} \right
]

ここで $A_{\pi_{\text{old}}}$ の計算に注目します。
PPOではGAEで計算されますが、話を簡単にしてTD誤差の計算と見てみます。

方策が重要になるので方策の記号も入れたTD誤差の式は以下です。

$$
A_{\pi_{\text{old}}} = r_{\pi_{\text{old}}} + \gamma V_{\pi_{\text{old}}}(s_{t+1}^{\pi_{\text{old}}}) - V_{\pi_{\text{old}}}(s_t^{\pi_{\text{old}}})
$$

$s_t^{\pi_{\text{old}}}$,$s_{t+1}^{\pi_{\text{old}}}$,$r_{\pi_{\text{old}}}$ はデータ収集時の状態と即時報酬で、$\gamma$ は割引率です。
ここで状態価値 $V$ を最新の状態で計算するように変更したのが本手法となります。

$$
A_{\pi_{\text{old}}} = r_{\pi_{\text{old}}} + \gamma V_{{\color{red}{\pi_\theta}}}(s_{t+1}^{\pi_{\text{old}}}) - V_{{\color{red}{\pi_\theta}}}(s_t^{\pi_{\text{old}}})
$$

ここからはコードベースで見ていきます。

学習(Trainer)

アイデアの本質は、学習時に行動価値を計算する事で、最新の状態価値を学習に利用する事です。
行動価値の計算は以下です。

$$
Q_{\pi_{\text{old}}} = r_{\pi_{\text{old}}} + \gamma V_{{\color{red}{\pi_\theta}}}(s_{t+1}^{\pi_{\text{old}}})
$$

$Q$が$\pi_{\text{old}}$になるのは、状態と報酬が古い方策で選ばれているからです。

discount = 割引率
state, action, next_state, reward, done = 学習用データバッファからランダムに取得

next_v = critic_net(next_state)
q = reward + (1-done) * discount * next_v

critic_net は状態価値を学習するニューラルネットワークで、状態を入力させると状態価値を返すモデルを想定しています。

これで行動価値が計算できたのでアドバンテージを計算します。

$$
A_{\pi_{\text{old}}} =Q_{\pi_{\text{old}}} - V_{{\color{red}{\pi_\theta}}}(s_t^{\pi_{\text{old}}})
$$

v = critic_net(state)
adv = q - v

次に状態価値の学習です。

\begin{align}
V(\theta) &= \mathbb{E} \left[ 
    \frac
        {\pi_{\theta}(a|s)}
        {\pi_{\text{old}}(a|s)}
    Q_{\pi_{\text{old}}} \right
] \\
Q_{\pi_{\text{old}}} &= r_{\pi_{\text{old}}} + \gamma V_{\pi_{\theta}}(s_{t+1}^{\pi_{\text{old}}})
\end{align}

$\frac{\pi_{\theta}(a|s)}{\pi_{\text{old}}(a|s)}$ は重点サンプリングで期待値を修正する項となります。
コード的には以下。

old_pi = 学習データ取得時の方策の確率密度

p_dist = actor_net(state)     # 現在の状態の方策を取得
new_pi = p_dist.prob(action)  # 現在の方策での確率(密度)

# 状態価値は全アクションのQ値の期待値で計算できる
# ただ、Q_old は確率が違うの重点サンプリングで補正
ratio = new_pi / old_pi
loss_value = mse_loss(ratio * q, v)

# 過大評価対策、割引報酬和から離れすぎないように正則化項を入れる
loss_v_align = mse_loss(ratio * total_reward, v)

最後に方策の学習ですがここはPPOと同じなので省略します。
まとめると以下です。

def train():
    state, action, old_pi, next_state, reward, done, total_reward = 学習用データバッファからランダムに取得

    # 現在の方策とISを計算
    p_dist = actor_net(state)
    new_pi = p_dist.prob(action)
    ratio = new_pi / old_pi
    
    # 最新のニューラルネットで、Q値とadvを計算
    next_v = critic_net(next_state)
    q = reward + (1-done) * discount * next_v
    v = critic_net(state)
    adv = q - v

    # Criticを学習
    loss_value = mse_loss(ratio * q, v)
    loss_v_align = mse_loss(ratio * total_reward, v)

    # Actorを学習(PPO clip)
    ratio_clipped = torch.clamp(ratio, 1 - clip_range, 1 + clip_range)
    loss_policy = torch.minimum(ratio * adv, ratio_clipped * adv)
    loss_policy = -torch.mean(loss_policy)

    loss = loss_policy + loss_value + 0.1*loss_v_align

イメージとしては、今の方策での状態価値を学習 → 今の状態価値より得られる価値が大きくなるように方策を更新 → 新しい方策での状態価値を学習、を繰り返す感じです。

データ収集(Worker)

上述の通り、学習時にQ値を計算するのでデータ収集時の方策は自由に設定できます。(Actorにする必要がない)
ただActorは探索用の方策を持っておらず、そのまま使うと探索をしないのである程度ランダム性を持たせる必要があります。

ここではε-greedyを採用しました。
またコード例は連続行動空間の場合を書いています。

class Worker(RLWorker):
    def on_reset(self, worker):
        pass

    def policy(self, worker):
        # 現在の状態から方策を取得
        p_dist = actor_net(state)

        if self.training:
            # ε-greedy
            if random.random() < 0.1:
                # 標準正規分布の乱数
                loc = 0.0
                scale = 1.0
            else:
                loc = p_dist.mean()
                scale = p_dist.stddev()

            # アクションを取得
            self.action = np.random.normal(loc, scale)
            # アクションの確率密度の対数を計算
            self.log_prob = compute_normal_logprob(self.action, loc, scale)
        else:
            # 評価時は平均を使う
            self.action = p_dist.mean()

        env_action = self.actionを環境に合わせて整形
        return env_action

    def on_step(self, worker):
        if not self.training:
            return

        worker.add_tracking(
            {
                "state": worker.state,
                "next_state": worker.next_state,
                "action": self.action,
                "log_prob": self.log_prob,
                "reward": worker.reward,
                "not_done": int(not worker.terminated),
            }
        )
        if worker.done:
            # 累積割引報酬和を計算してメモリに送る
            total_reward = 0
            for b in reversed(worker.get_trackings()):
                total_reward = b[4] + self.config.discount * total_reward
                self.memory.add(b + [total_reward])


def compute_normal_logprob(x, loc, scale):
    """
    -0.5 * log(2*pi) - log(stddev) - 0.5 * ((x - mean) / stddev)^2
    """
    return -0.5 * math.log(2 * math.pi) - np.log(scale) - 0.5 * (((x - loc) / scale) ** 2)

一応この手法はV-PPOと命名しています。
Vは状態価値を表し、状態価値の学習がメインのPPOという意味となります。

学習結果

GymのPendulum-v1を学習させてみました。
ナイーブなPPOはハイパーパラメータの調整が難しく、学習にかなり苦労した記憶があります。
ですが、このV-PPOでは元のPPOの細かいテクニックをほとんど使わずに学習できています。

ナイーブなPPOと今回実装したV-PPO比較は以下です。

PPO V-PPO
方策 オンポリシー オフポリシー
Clipped Surrogate Objective PPOのクリップ
Advantageの計算方法 GAE 学習時にTD誤差を計算
Value Clipping 状態価値もクリップする手法
Reward scaling baselineをAdvantageの標準偏差で割る手法
Orthogonal initialization and layer scaling NNの重みをOrthogonalで初期化する手法
Adam learning rate annealing 学習率の焼きなまし
Reward Clipping 報酬をclipする手法
Observation Normalization 状態を標準化する手法
Observation Clipping 状態をclipする手法
Hyperbolic tan activations 活性化関数にtanhを使う手法
Global Gradient Clipping 勾配をL2でclipする手法
Squashed Gaussian policy 正規分布にtanhを使い、アクションの出力を-1~1にする手法

Colabで実際に学習させた結果は以下です。

ss1.png

ちゃんと学習できていますね。
ナイーブなPPOに比べてだいぶ学習しやすい印象はありますが、結構ぶれる印象です。
また、オフポリシーとはいいつつ古い経験ばかりだとPPOのclipに引っかかって学習しにくくなる印象があります。
(完全ランダムな経験だけで学習できることは確認できているので、オフポリシー扱いでも問題はないかと)

学習後の動画です。

vppo.gif

連続値の学習によるアクションがぬるっと変わる感じがいいですね。

NoTarget SAC

蛇足です。
ついでにSACもターゲットネットワークをなくして実装してみました。
ただ、探索の方策を入れたくなかったのでエントロピー項はなくしています。
(なので(SAC+DDPG)÷2みたいな手法です)

過去記事は以下。

DDPG: https://qiita.com/pocokhc/items/6746df2eb9e7840e6814
SAC: https://qiita.com/pocokhc/items/354a2ddf4cbd742d3191

学習(Trainer)

SAC/DDPGの特徴である、「Q値を学習し、Q値が最大となるように方策を学習する」をそのまま書き起こします。

def train():
    state, action, next_state, reward, done, total_reward = 学習用データバッファからランダムに取得

    # --- 学習用のQ値を計算
    with torch.no_grad():
        # 次の状態の方策から、次取るであろうアクションを取得(実際はランダムガウス)
        n_action = policy_net(n_state)

        # 状態とアクションから次のQ値を計算
        n_q = qnet(n_state, n_action)

        target_q = reward + (1-done) * discount * n_q

    # --- q値の学習
    q = qnet(state, action)
    loss_q = huber_loss(target, q)
    loss_q_align = mse_loss(total_reward, q)  # 過大評価抑止用の正則化項

    loss = loss_q + 0.1 * loss_q_align
    opt_q.zero_grad()
    loss.backward()
    opt_q.step()

    # --- policyの学習
    action = policy_net(state)
    q = qnet(state, action)
    loss_policy = -q.mean()  # Q値が最大になるように学習

    opt_policy.zero_grad()
    loss_policy.backward()
    opt_policy.step()

Worker側ですが、V-PPOに比べてアクションの確率(prob)の計算がなくなる分簡単になります。

学習結果

DDPG(TD3)/SACと今回実装したNoT-SAC比較は以下です。

DDPG(TD3) SAC NoT-SAC
方策 オフポリシー オフポリシー オフポリシー
方策エントロピー項 目的関数にエントロピー項を加え、方策と探索を同時に学習する手法
Target Network TargetNetを用いて最新のQ値と学習用のQ値をずらす手法
探索ノイズ SACは方策に探索が含まれているので不要
Clipped Double Q learning 2つのQ-netから値が小さい方を使い過大評価を抑える手法
Target Policy Smoothing 学習時のTarget方策にノイズを混ぜて頑健性を向上させる手法
Delayed Policy Update 方策の更新頻度を上げる手法
Squashed Gaussian policy 正規分布にtanhを使い、アクションの出力を-1~1にする手法

Colabで実際に学習させた結果は以下です。

ss2.png

ちゃんと学習できています。
SACと同様に、V-PPOより安定しているイメージがやはりあります。(ブレが少ない)

学習後の動画は以下です。

not_sac.gif

コード

GoogleColabに置いてあります。

また、本記事の手法(V-PPO/NoT-SAC)は自作フレームワークにも実装しています。

・フレームワークの記事

・GitHub

さいごに

状態価値/行動価値の過大評価を気にしないだけで考え方がだいぶ簡単になりました。
やはり元のアルゴリズムは考え方だけで学習は可能で、過大評価がなくなるだけで細かいテクニックは使わずにちゃんと学習できることが分かったのは大きかったです。

またオフポシリー型PPOですが、今回のような手法は昔からありそうな気がしますが、完全なオフポリシー型PPOは見つけられませんでした。(ChatGPTに軽く調べさせたが似た論文はでてこなかった)
多分実用上では状態価値の過大評価の学習がノイズになっていて実用まで行きつかなかったのかな?と思います。
アドバンテージの計算で古い状態価値を使っているのがかなりネックになっていた気がします。

この記事が誰かの参考になれば幸いです。

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?