2
1

【強化学習】好奇心による探索RNDを改良したSNDを解説・実装してみた

Last updated at Posted at 2024-07-06

この記事は自作している強化学習フレームワークの解説記事です。

はじめに

ふと以下のAtariゲームのベンチマークを見ていたら Montezuma's Revenge でかなり高いスコアを出していたSND-Vというアルゴリズムがあったので見てみました。(Go-Exploreっていうのも興味がある)

どうやらこれはAgent57でも使われていたRNDという手法を改良したアルゴリズムのようです。

SND論文: https://arxiv.org/abs/2302.11563

2023年2月にv1でv4が2024年6月と新しめの論文ですね。

SND(Self-supervised Network Distillation)

日本語だと自己教師ありネットワーク蒸留でしょうか。(ぱっと見日本語の記事はなさそうでした)

強化学習の課題として報酬が疎(全く手に入らない)環境では学習が進まない問題がありました。
一般的な強化学習のアプローチでは、最初の報酬が手に入るまでは学習目標がないため、学習に意味がなくなります。
最初の報酬が手に入るかどうかは運しだいとなり、報酬が疎な環境では学習が絶望的になります。
このような環境に対するアプローチの一つに内発的動機付けによる探索の促進があります。

ベンチマークでのSNDの位置づけは以下です。(論文にはなかったので上記サイトから取得、多分非公式)

aa.drawio.png

SNDは3種類あるようで、V/VIC/STDがあります。
SND-VはAgent57を超えてますね。

内発的動機付け、内的モチベーション (Intrinsic Motivation; IM)

外発的・内発的動機付けは元は心理学の分野の話で、それを強化学習に応用した形となります。
(ディープラーニングも人間の脳の動きを模倣していたりと人から学ぶことは多いですね)

内発的動機付けとは自分自身の内側から湧き出る興味や楽しさ、満足感や好奇心などによって行動を起こす動機付けのことです。
例えば、絵を描くことが好きで、その行為自体が楽しいから絵を描く場合、これは内発的動機付けによって行われていると言えます。
反対に、外的な要因(例えば、お金や賞賛など)によって行動する場合は外発的動機付けとなります。

強化学習では環境からの報酬は外発的動機付けとなるので、エージェントに内発的動機付けを足して探索を促そうという試みは自然な流れですね。

SND論文内で参照されているIMの論文:https://www.ncbi.nlm.nih.gov/pmc/articles/PMC4158798/

RND(Random Network Distillation model)

SNDの元となっている手法です。
RNDでは内発的動機付けとして、未知の状態を高く評価する好奇心的探索を行っています。
具体的には、初めて訪れた状態には高い報酬を与え、よく訪れる状態は報酬を低くするという手法をとっています。
ここでRNDが自ら生成した報酬をIntrinsicReward(内発的報酬・内部報酬)、反対に環境の報酬をExtrinsicReward(外発的報酬・外部報酬)と言います。
※用語ですが、本記事では主に概念に関する場合は内発的報酬・外発的報酬とし、主にアルゴリズムに関する場合は内部報酬・外部報酬としています。

RND論文:https://arxiv.org/abs/1810.12894

RNDのアルゴリズム

ss1.png
図はSND側の論文より

まずランダムに初期化されたターゲットモデル(TargetModel)と予測モデル(PredictorModel)の2つのモデル(ニューラルネットワーク)を用意します。
同じ状態を入力とし、この2つのモデルからの出力を得ます。
それぞれのモデルの出力の差(L2ノルムの2乗)を内部報酬とするのがRNDとなります。

また、ここでターゲットモデルは固定し、予測モデルの出力をターゲットモデルに近づけるように学習します。(モデルの蒸留)
この学習により、初めて見る状態はモデル間の差が大きく(内部報酬が大きい)なり、よく見る状態はモデル間の差が小さく(内部報酬が小さい)なります。

RNDの問題点

SND論文内では2点指摘されています。

1.ランダムネットワークを適切に初期化する必要がある
2.時間の経過と共に予測モデルが十分に学習されるため、内部報酬が消える

この問題を解決しているのがSNDとなります。

SNDアルゴリズムの概要

SNDの概要は以下です。

ss2.png

SNDは大きく2つの構成要素からなり、上部が自己教師あり学習によるターゲットモデルの学習、下部がRNDと同様のターゲットモデルの蒸留と内部報酬の計算になります。

ターゲットモデルは自己教師あり学習で学習されるため、学習後の分布は(アルゴリズムが学習した)特徴空間の分布になります。
それに比べ予測モデルは状態空間の分布となります。(画像を直接入力しているので)
この分布の違いによりランダムな初期化よりも離れた分布を表現でき、かつ常に更新されるので時間の経過とともに内部報酬がなくなることもありません。
これによりRNDの問題点を解決しています。

自己教師あり学習のアルゴリズムは特に決まっていないようで、論文ではを3つ試しています。

SND-V

1つ目は対照学習(Contrastive Learning)によって学習する方法です。
この方法は、ランダムに選択された2つの状態に対してMSE(平均二乗誤差)でターゲット距離を計算します。
2つの状態は50%の確率で状態→同じ状態と状態→違う状態のバッチを作成し、同じ状態→ターゲット距離0、違う状態→ターゲット距離1で学習します。
(ここの違う状態は連続した状態ではなく、ランダムに選んだ画像を指していると思います)

損失関数は以下です。

ss4.png

$n$ がバッチサイズとなり、$Z_n$ と $Z'_n$ がターゲットモデルから出力された値で $||.||^2_2$ がL2ノルムの2乗です。
$\tau_n$ は $Z_n$ と $Z'_n$ の出力元の画像が同じなら0、違うなら1となり、50%の割合で混ざります。

この損失関数は、同じ状態($\tau_n=0$)の特徴ベクトルを互いに近づけます。
また、違う状態($\tau_n=1$)では特徴ベクトルの距離が1を超える場合にのみ近づける性質があります。

画像の前処理

SND-Vでは状態(画像)には画像変換の前処理を加え、ロバスト性を高めます。
以下の3段階を実施しているようです。

ss3.png

  1. ランダム畳み込みフィルター(random convolution filter)
  2. ランダムタイルマスク(タイルサイズ:2,4,8,12,16)
  3. [-0.2,0.2]の範囲の一様乱数を追加

ランダム畳み込みフィルターはよく分かりませんでした。(ランダムな重みのConv2d層を通すだけ?)
ランダムタイルマスクは黒いタイルでマスクするようです。

SND-STD

2つ目は、Spatio-Temporal DeepInfoMax (ST-DIM)アルゴリズムによって学習する方法です。
ST-DIMアルゴリズムの基本的な考え方は以下です。

  • 2つの連続する状態の出力層の表現と、位置情報を持つ中間層の表現を近づける事(グローバル-ローカル目標)
  • 2つの連続する状態の中間層の表現を近づける事(ローカル-ローカル目標)

教師データは、2つの連続する状態を正解、同じミニバッチの連続しない状態を不正解として学習します。
損失は以下です。

ss5.png

$H$と$W$は画像のサイズ、$g_{h,w}(s_t,s_{t+1})$ はスコア関数で中間層の出力と出力層の出力の非正規化コサイン類似度(the unnormalized cosine similarity)です。
$f_{h,w}(s_t,s_{t+1})$ も同様に中間層同士の非正規化コサイン類似度となります。
($g$,$f$については論文にもう少し詳しい式がのっているので詳細はそちらをどうぞ)

また以下の正則化項を追加しています。

$f$と$g$のL2ノルムの最小化

ss6.png

特徴空間が増加し続ける傾向にあるようで、それを抑えるために追加されたようです。

出力の標準偏差の最大化

ss7.png

$\Phi$ はターゲットモデルを表します。(表記ミスを恐れずに書くなら $\Phi(s)=Z$ )
出力値の分散を最大化することで、全ての次元の特徴空間が使用されるようにしたとの事です。

最終的な損失は以下です。

ss8.png

$\beta$ はハイパーパラメータです。( $\beta_1=0.0001$ , $\beta_2=0.0001$ )

SND-VIC

3つ目は、非対照自己教師あり学習法に属するVICReg(Variance-Invariance-Covariance Regularization)アルゴリズムによって学習する方法です。
データとして、元のアルゴリズム(VICReg)で使う画像は1つですが、SND-VICは今の画像 $s_t$ と次の画像 $s_{t+1}$ の2つを使い学習します。

損失は3つの項から成ります。

1.分散正規化項

ss9.png

$D$ はバッチサイズを表し、$Z$ はターゲットモデルの出力値、$\sigma(Z_d)$ は各バッチの標準偏差を表します。
$\tau$ は標準偏差の目標値で1で固定です。

この項は、各バッチの標準偏差が各次元に沿って $\tau$ と等しくなるように促すことで、全ての状態が同じ特徴ベクトルにマッピングされる事を防ぐ役割があります。

2.共分散正則化項

ss10.png

$C(Z)$ の非対角要素の二乗和として計算されます。(the sum of the squared off-diagonal coefficients of the covariance matrix C(Z))

この項は特徴ベクトルの相関をなくし、各特徴ベクトルが固有の情報を学習できるようにする役割があります。

3.特徴ベクトルを近づける不変項

ss11.png

今の状態と次の状態の特徴量のL2ノルムの2乗ですね。

最終的な損失は以下です。

ss12.png

ハイパーパラメータは、$\lambda=1$ , $\mu=1$ , $\nu=\frac{1}{25}$ です。

性能

論文の4.4の箇所となります。
論文内ではもう少しいろいろ分析していますがきりがないので一部のみ抜粋です。

内発的報酬の有効性

ss13.png

横軸がステップ数、縦軸が累積外部報酬を表しています。
どの手法も比較的RNDより早い段階で学習が始まっていそる感じです。
(最終的な報酬も多そう)

次の図は新規性の比較です。

ss14.png

※論文にはSTDとVICも載っています

図の横軸は"Montezuma’s Revenge"の学習過程を表し、縦軸が各エージェントが収集した状態の内部報酬を表したものです。
(赤い点は上の画像に対応)

RNDは最終的に学習が収束するので最初しか信号を出力していませんね。
SNDは継続的にターゲットモデルが更新されるので学習全体を通して信号が出力されています。

特徴空間の性質

ss15.png

図は"Montezuma’s Revenge"の学習済みモデルに対する出力をt-SNEで可視化したものになります。
色はゲーム内の部屋に対応しているとの事(部屋が変わると大きく状態が変わる)

RNDは同じ部屋では円が小さいですね、これは異なる部屋同士はうまく表現できますが、同じ部屋内では分散が小さいので探索能力が低くなっています。
SNDは同じ部屋でも円が大きい(分散が高い)ので、同じ部屋の中でもより細かい探索が出来そうです。

実装

論文ではRLアルゴリズムとしてPPOを実装していましたが、PPOはあまり安定しているイメージがないのでDQNベースで実装しました。
また実装はSND-Vのみとなります。

※SRLがv0.16.1のコードです。バージョンが進むと動かない可能性があります。

Config

内部報酬で探索するのでε-greedyは試しになくしてみました。(あってもいいと思います)
また、入力画像はDQNと同じにして(84,84)のサイズです。
論文ではAtariは(96,96)、Procgenは(64,64)です。

@dataclass
class Config(RLConfig):
    lr: Union[float, SchedulerConfig] = 0.001
    discount: float = 0.99
    target_model_update_interval: int = 1000
    int_reward_scale: float = 0.5  # 内部報酬の割合

    def get_processors(self) -> List[Optional[RLProcessor]]:
        return [ImageProcessor(SpaceTypes.GRAY_2ch, (84, 84), enable_norm=True)]

ハイパーパラメータがほとんどありませんね…

Memory

SND用のメモリとQ学習用のメモリをそれぞれ用意して使い分けます。
どちらもランダムに取得するExperienceMemoryBufferです。

class Memory(RLMemory):
    def __init__(self, *args):
        super().__init__(*args)

        self.memory_snd = []
        self.memory_q = []

    def add(self, mode: str, batch) -> None:
        if mode == "snd":
            self.memory_snd.append(batch)
            if len(self.memory_snd) > self.config.memory_capacity:
                self.memory_snd.pop(0)
        else:
            self.memory_q.append(batch)
            if len(self.memory_q) > self.config.memory_capacity:
                self.memory_q.pop(0)

    def sample_snd(self, batch_size):
        return random.sample(self.memory_snd, batch_size)

    def sample_q(self):
        return random.sample(self.memory_q, self.config.batch_size)

NNモデル

SNDNetwork

入力層はDQNと同じです。

class SNDNetwork(keras.Model):
    def __init__(self, in_shape, **kwargs):
        super().__init__(**kwargs)

        self.h_layers = [
            kl.Conv2D(32, (8, 8), strides=(4, 4), padding="same", activation="relu"),
            kl.Conv2D(64, (4, 4), strides=(2, 2), padding="same", activation="relu"),
            kl.Conv2D(64, (3, 3), strides=(1, 1), padding="same", activation="relu"),
            kl.Flatten(),
            kl.Dense(512),
        ]

    def call(self, x, training=False):
        for h in self.h_layers:
            x = h(x, training=training)
        return x

QNetwork

DQNと同じです。

class QNetwork(keras.Model):
    def __init__(self, in_shape, action_num, **kwargs):
        super().__init__(**kwargs)

        self.h_layers = [
            kl.Conv2D(32, (8, 8), strides=(4, 4), padding="same", activation="relu"),
            kl.Conv2D(64, (4, 4), strides=(2, 2), padding="same", activation="relu"),
            kl.Conv2D(64, (3, 3), strides=(1, 1), padding="same", activation="relu"),
            kl.Flatten(),
            kl.Dense(512, activation="relu"),
            kl.Dense(action_num),
        ]

    def call(self, x, training=False):
        for h in self.h_layers:
            x = h(x, training=training)
        return x

Parameter

class Parameter(RLParameter):
    def __init__(self, *args):
        super().__init__(*args)

        in_shape = self.config.observation_space.shape
        action_num = self.config.action_space.n
        self.q_online = QNetwork(in_shape, action_num, name="Q_online")
        self.q_target = QNetwork(in_shape, action_num, name="Q_target")
        self.q_target.set_weights(self.q_online.get_weights())

        self.snd_target = SNDNetwork(in_shape, name="target")
        self.snd_predictor = SNDNetwork(in_shape, name="predictor")

Worker

class Worker(RLWorker):
    def on_reset(self, worker):
        # 初期状態を送ります
        self.memory.add("snd", worker.state)

    def policy(self, worker) -> int:
        # Q値が最大の行動を選びます。
        # 内部報酬があるのでε-greedyは省略しています。
        state = worker.state[np.newaxis, ...]
        q = self.parameter.q_online(state)[0].numpy()
        action = int(np.argmax(q))
        return action

    def on_step(self, worker):
        if not self.training:
            return

        # 報酬を計算
        # 報酬 = 外部報酬 + int_scale * 内部報酬
        r_ext = worker.reward
        r_int = self._calc_intrinsic_reward(worker.state)
        reward = r_ext + self.config.int_reward_scale * r_int

        batch = [
            worker.prev_state,
            worker.state,
            funcs.one_hot(worker.action, self.config.action_space.n),
            reward,
            int(not worker.terminated),
        ]
        self.memory.add("q", batch)
        self.memory.add("snd", worker.state)

    def _calc_intrinsic_reward(self, state):
        # 内部報酬を計算
        state = state[np.newaxis, ...]
        z1 = self.parameter.snd_target(state)[0]
        z2 = self.parameter.snd_predictor(state)[0]

        # L2ノルムの2乗
        distance = np.sum(np.square(z1 - z2))

        return distance

Trainer


class Trainer(RLTrainer):
    def __init__(self, *args):
        super().__init__(*args)

        self.opt_q = keras.optimizers.Adam(self.config.lr)
        self.opt_snd_target = keras.optimizers.Adam(self.config.lr)
        self.opt_snd_predictor = keras.optimizers.Adam(self.config.lr)

        self.loss_mse = keras.losses.MeanSquaredError()
        self.loss_huber = keras.losses.Huber()

    def train(self) -> None:
        self._train_snd()
        self._train_q()

    def _train_snd(self):
        if len(self.memory.memory_snd) < self.config.memory_warmup_size:
            return

        # --- 対照学習、画像の前処理は実施していません
        # (s1, s1) -> tau=0
        # (s1, s2) -> tau=1
        batch_half = int(self.config.batch_size / 2)
        state1 = self.memory.sample_snd(self.config.batch_size)
        state2 = self.memory.sample_snd(batch_half)
        state1 = np.asarray(state1)

        # state2は半分はstate1を使う
        state2 = np.concatenate([state1[:batch_half], state2], axis=0)
        tau = np.concatenate(
            [np.zeros((batch_half, 1)), np.ones((batch_half, 1))],
            axis=0,
            dtype=self.config.dtype,
        )

        with tf.GradientTape() as tape:
            z1 = self.parameter.snd_target(state1, training=True)
            z2 = self.parameter.snd_target(state2, training=True)
            # L2ノルムの2乗
            loss = tau - tf.reduce_sum((z1 - z2) ** 2, axis=-1, keepdims=True)
            loss = tf.reduce_sum(loss**2)
        grad = tape.gradient(loss, self.parameter.snd_target.trainable_variables)
        self.opt_snd_target.apply_gradients(
            zip(grad, self.parameter.snd_target.trainable_variables)
        )

    def _train_q(self):
        if len(self.memory.memory_q) < self.config.memory_warmup_size:
            return
        batchs = self.memory.sample_q()
        state, n_state, action, reward, undone = zip(*batchs)
        state = np.asarray(state, self.config.dtype)
        n_state = np.asarray(n_state, self.config.dtype)
        action = np.asarray(action, self.config.dtype)
        reward = np.array(reward, self.config.dtype)
        undone = np.array(undone, self.config.dtype)

        # --- distillation
        z1 = self.parameter.snd_target(n_state)
        with tf.GradientTape() as tape:
            z2 = self.parameter.snd_predictor(state, training=True)
            loss = self.loss_mse(z1, z2)
        grad = tape.gradient(loss, self.parameter.snd_predictor.trainable_variables)
        self.opt_snd_predictor.apply_gradients(
            zip(grad, self.parameter.snd_predictor.trainable_variables)
        )

        # --- calc next q
        n_q = self.parameter.q_online(n_state)
        n_q_target = self.parameter.q_target(n_state).numpy()
        n_act_idx = np.argmax(n_q, axis=-1)
        maxq = n_q_target[np.arange(self.config.batch_size), n_act_idx]
        target_q = reward + undone * self.config.discount * maxq
        target_q = target_q[..., np.newaxis]

        # --- train q
        with tf.GradientTape() as tape:
            q = self.parameter.q_online(state, training=True)
            q = tf.reduce_sum(q * action, axis=1)
            loss = self.loss_huber(target_q, q)
        grad = tape.gradient(loss, self.parameter.q_online.trainable_variables)
        self.opt_q.apply_gradients(
            zip(grad, self.parameter.q_online.trainable_variables)
        )

        # --- targetと同期
        if self.train_count % self.config.target_model_update_interval == 0:
            self.parameter.q_target.set_weights(self.parameter.q_online.get_weights())

        self.train_count += 1

実行結果

AtariのPongで動かしてみました。比較はDQNです。
学習しやすいように以下の操作をしています。

・上部のスコア等、不要な部分をクリッピング
・スコアの上限を20ではなく5に変更
・グレー化+白黒の2値化を実施
・フレームスキップは8フレーム

学習結果

DQNは何もないのがepsilon=0.1、"DQN e0"がepsilon=0です。
SNDもepsilon=0なので比較で見てみました。
(そしてやって見てわかりましたがPongってepsilon=0でも学習できるんですね…)

Pong.png

Pongだとあまり差は見られませんでした。
学習対象としてはあまり良くなかったかも…。

実際の動き、左下はRLアルゴリズムが実際に見えている画像です。

Pong_SND.gif

コード全体

※SRLがv0.16.1で動作確認しています。

import os
import random
from dataclasses import dataclass
from typing import List, Optional, Union

import numpy as np
import tensorflow as tf
from matplotlib import pyplot as plt
from tensorflow import keras

import srl
from srl.algorithms import dqn
from srl.base.define import SpaceTypes
from srl.base.rl.algorithms.base_dqn import RLConfig, RLWorker
from srl.base.rl.memory import RLMemory
from srl.base.rl.parameter import RLParameter
from srl.base.rl.processor import RLProcessor
from srl.base.rl.registration import register
from srl.base.rl.trainer import RLTrainer
from srl.rl import functions as funcs
from srl.rl.memories.experience_replay_buffer import RLConfigComponentExperienceReplayBuffer
from srl.rl.models.config.framework_config import RLConfigComponentFramework
from srl.rl.processors.atari_processor import AtariPongProcessor
from srl.rl.processors.image_processor import ImageProcessor
from srl.rl.schedulers.scheduler import SchedulerConfig
from srl.utils import common

kl = keras.layers


@dataclass
class Config(
    RLConfig,
    RLConfigComponentExperienceReplayBuffer,
    RLConfigComponentFramework,
):
    lr: Union[float, SchedulerConfig] = 0.001
    discount: float = 0.99
    target_model_update_interval: int = 1000
    int_reward_scale: float = 0.5

    def get_processors(self) -> List[Optional[RLProcessor]]:
        return [ImageProcessor(SpaceTypes.GRAY_2ch, (84, 84), enable_norm=True)]

    def get_framework(self) -> str:
        return "tensorflow"

    def get_name(self) -> str:
        return "SND"

    def assert_params(self) -> None:
        super().assert_params()
        self.assert_params_memory()
        self.assert_params_framework()
        assert self.batch_size % 2 == 0


register(
    Config(),
    __name__ + ":Memory",
    __name__ + ":Parameter",
    __name__ + ":Trainer",
    __name__ + ":Worker",
)


class Memory(RLMemory[Config]):
    def __init__(self, *args):
        super().__init__(*args)

        self.memory_snd = []
        self.memory_q = []

    def length(self):
        return len(self.memory_q)

    def add(self, mode: str, batch) -> None:
        if mode == "snd":
            self.memory_snd.append(batch)
            if len(self.memory_snd) > self.config.memory_capacity:
                self.memory_snd.pop(0)
        else:
            self.memory_q.append(batch)
            if len(self.memory_q) > self.config.memory_capacity:
                self.memory_q.pop(0)

    def sample_snd(self, batch_size):
        return random.sample(self.memory_snd, batch_size)

    def sample_q(self):
        return random.sample(self.memory_q, self.config.batch_size)


class SNDNetwork(keras.Model):
    def __init__(self, in_shape, **kwargs):
        super().__init__(**kwargs)

        self.h_layers = [
            kl.Conv2D(32, (8, 8), strides=(4, 4), padding="same", activation="relu"),
            kl.Conv2D(64, (4, 4), strides=(2, 2), padding="same", activation="relu"),
            kl.Conv2D(64, (3, 3), strides=(1, 1), padding="same", activation="relu"),
            kl.Flatten(),
            kl.Dense(512),
        ]

        # build
        self(np.zeros((1,) + in_shape))

    def call(self, x, training=False):
        for h in self.h_layers:
            x = h(x, training=training)
        return x


class QNetwork(keras.Model):
    def __init__(self, in_shape, action_num, **kwargs):
        super().__init__(**kwargs)

        self.h_layers = [
            kl.Conv2D(32, (8, 8), strides=(4, 4), padding="same", activation="relu"),
            kl.Conv2D(64, (4, 4), strides=(2, 2), padding="same", activation="relu"),
            kl.Conv2D(64, (3, 3), strides=(1, 1), padding="same", activation="relu"),
            kl.Flatten(),
            kl.Dense(512, activation="relu"),
            kl.Dense(action_num),
        ]

        # build
        self(np.zeros((1,) + in_shape))

    def call(self, x, training=False):
        for h in self.h_layers:
            x = h(x, training=training)
        return x


class Parameter(RLParameter[Config]):
    def __init__(self, *args):
        super().__init__(*args)

        in_shape = self.config.observation_space.shape
        action_num = self.config.action_space.n
        self.q_online = QNetwork(in_shape, action_num, name="Q_online")
        self.q_target = QNetwork(in_shape, action_num, name="Q_target")
        self.q_target.set_weights(self.q_online.get_weights())

        self.snd_target = SNDNetwork(in_shape, name="target")
        self.snd_predictor = SNDNetwork(in_shape, name="predictor")

    def call_restore(self, data, **kwargs):
        self.q_online.set_weights(data[0])
        self.q_target.set_weights(data[0])
        self.snd_target.set_weights(data[1])
        self.snd_predictor.set_weights(data[2])

    def call_backup(self, **kwargs):
        return [
            self.q_online.get_weights(),
            self.snd_target.get_weights(),
            self.snd_predictor.get_weights(),
        ]

    def summary(self, **kwargs):
        self.q_online.summary(**kwargs)
        self.snd_target.summary(**kwargs)


class Trainer(RLTrainer[Config, Parameter, Memory]):
    def __init__(self, *args):
        super().__init__(*args)

        self.opt_q = keras.optimizers.Adam(self.config.lr)
        self.opt_snd_target = keras.optimizers.Adam(self.config.lr)
        self.opt_snd_predictor = keras.optimizers.Adam(self.config.lr)

        self.loss_mse = keras.losses.MeanSquaredError()
        self.loss_huber = keras.losses.Huber()

        self.sync_count = 0

    def train(self) -> None:
        self._train_snd()
        self._train_q()

    def _train_snd(self):
        if len(self.memory.memory_snd) < self.config.memory_warmup_size:
            return

        # 対照学習
        # (s1, s1) -> tau=0
        # (s1, s2) -> tau=1
        batch_half = int(self.config.batch_size / 2)
        state1 = self.memory.sample_snd(self.config.batch_size)
        state2 = self.memory.sample_snd(batch_half)
        state1 = np.asarray(state1)
        state2 = np.asarray(state2)

        # random convolution filter skip...
        # random tile mask skip...
        # random
        state1 += np.random.uniform(-0.2, 0.2, size=state1.shape)
        state2 += np.random.uniform(-0.2, 0.2, size=state2.shape)

        # state2は半分はstate1を使う
        state2 = np.concatenate([state1[:batch_half], state2], axis=0)
        tau = np.concatenate(
            [np.zeros((batch_half, 1)), np.ones((batch_half, 1))],
            axis=0,
            dtype=self.config.dtype,
        )

        with tf.GradientTape() as tape:
            z1 = self.parameter.snd_target(state1, training=True)
            z2 = self.parameter.snd_target(state2, training=True)
            # L2ノルムの2乗
            loss = tau - tf.reduce_sum((z1 - z2) ** 2, axis=-1, keepdims=True)
            loss = tf.reduce_sum(loss**2)
        grad = tape.gradient(loss, self.parameter.snd_target.trainable_variables)
        self.opt_snd_target.apply_gradients(zip(grad, self.parameter.snd_target.trainable_variables))
        self.info["loss_snd_target"] = loss.numpy()

    def _train_q(self):
        if len(self.memory.memory_q) < self.config.memory_warmup_size:
            return
        batchs = self.memory.sample_q()
        state, n_state, action, reward, undone = zip(*batchs)
        state = np.asarray(state, self.config.dtype)
        n_state = np.asarray(n_state, self.config.dtype)
        action = np.asarray(action, self.config.dtype)
        reward = np.array(reward, self.config.dtype)
        undone = np.array(undone, self.config.dtype)

        # --- distillation
        z1 = self.parameter.snd_target(n_state)
        with tf.GradientTape() as tape:
            z2 = self.parameter.snd_predictor(n_state, training=True)
            loss = self.loss_mse(z1, z2)
        grad = tape.gradient(loss, self.parameter.snd_predictor.trainable_variables)
        self.opt_snd_predictor.apply_gradients(zip(grad, self.parameter.snd_predictor.trainable_variables))
        self.info["loss_snd_predictor"] = loss.numpy()

        # --- calc next q
        n_q = self.parameter.q_online(n_state)
        n_q_target = self.parameter.q_target(n_state).numpy()
        n_act_idx = np.argmax(n_q, axis=-1)
        maxq = n_q_target[np.arange(self.config.batch_size), n_act_idx]
        target_q = reward + undone * self.config.discount * maxq
        target_q = target_q[..., np.newaxis]

        # --- train q
        with tf.GradientTape() as tape:
            q = self.parameter.q_online(state, training=True)
            q = tf.reduce_sum(q * action, axis=1)
            loss = self.loss_huber(target_q, q)
        grad = tape.gradient(loss, self.parameter.q_online.trainable_variables)
        self.opt_q.apply_gradients(zip(grad, self.parameter.q_online.trainable_variables))
        self.info["loss_q"] = loss.numpy()

        # --- targetと同期
        if self.train_count % self.config.target_model_update_interval == 0:
            self.parameter.q_target.set_weights(self.parameter.q_online.get_weights())
            self.sync_count += 1
            self.info["sync"] = self.sync_count

        self.train_count += 1


class Worker(RLWorker[Config, Parameter]):
    def on_reset(self, worker):
        self.memory.add("snd", worker.state)

    def policy(self, worker) -> int:
        state = worker.state[np.newaxis, ...]
        q = self.parameter.q_online(state)[0].numpy()
        action = int(np.argmax(q))
        return action

    def on_step(self, worker):
        if not self.training:
            return

        r_ext = worker.reward
        r_int = self._calc_intrinsic_reward(worker.state)
        reward = r_ext + self.config.int_reward_scale * r_int

        batch = [
            worker.prev_state,
            worker.state,
            funcs.one_hot(worker.action, self.config.action_space.n),
            reward,
            int(not worker.terminated),
        ]
        self.memory.add("q", batch)
        self.memory.add("snd", worker.state)

    def _calc_intrinsic_reward(self, state):
        state = state[np.newaxis, ...]
        z1 = self.parameter.snd_target(state)[0]
        z2 = self.parameter.snd_predictor(state)[0]

        # L2ノルムの2乗
        distance = np.sum(np.square(z1 - z2))

        return distance

    def render_terminal(self, worker, **kwargs) -> None:
        # policy -> render -> env.step -> on_step

        # --- int reward
        r_int = self._calc_intrinsic_reward(worker.state)
        print(f"intrinsic reward: {r_int:.6f}")

        q = self.parameter.q_online(worker.state[np.newaxis, ...])[0]
        maxa = np.argmax(q)

        def _render_sub(a: int) -> str:
            return f"{q[a]:7.5f}"

        funcs.render_discrete_action(int(maxa), self.config.action_space, worker.env, _render_sub)


def train(name):
    common.logger_print()

    env_config = srl.EnvConfig(
        "ALE/Pong-v5",
        kwargs=dict(frameskip=7, repeat_action_probability=0, full_action_space=False),
    )
    env_config.processors = [AtariPongProcessor()]

    if name in ["DQN", "DQN_e0"]:
        rl_config = dqn.Config(
            target_model_update_interval=2_000,
            epsilon=0.1,
            discount=0.99,
            lr=0.0005,
            enable_reward_clip=False,
            enable_double_dqn=True,
            enable_rescale=False,
            memory_warmup_size=1000,
            memory_capacity=10_000,
            memory_compress=False,
            window_length=4,
        )
        rl_config.input_image_block.set_dqn_block()
        rl_config.hidden_block.set((512,))
        if name == "DQN_e0":
            rl_config.epsilon = 0
    elif name == "SND":
        rl_config = Config(
            target_model_update_interval=2_000,
            discount=0.99,
            lr=0.0005,
            memory_warmup_size=1000,
            memory_capacity=10_000,
            window_length=4,
        )

    runner = srl.Runner(env_config, rl_config)
    runner.model_summary()

    # history setting
    his_path = os.path.join(os.path.dirname(__file__), f"Pong_{name}")
    runner.set_history_on_file(his_path, interval_mode="time", interval=5, enable_eval=True)

    # train
    runner.train(max_train_count=100_000)

    # animation
    runner.animation_save_gif(os.path.join(os.path.dirname(__file__), f"Pong_{name}.gif"))


def history_plot():
    base_dir = os.path.dirname(__file__)
    his = srl.Runner.load_histories(
        [
            os.path.join(base_dir, "Pong_DQN"),
            os.path.join(base_dir, "Pong_DQN_e0"),
            os.path.join(base_dir, "Pong_SND"),
        ]
    )
    his.plot("train", "eval_reward0", aggregation_num=20, _no_plot=True)
    plt.savefig(os.path.join(os.path.dirname(__file__), "Pong.png"))


if __name__ == "__main__":
    train("DQN")
    train("DQN_e0")
    train("SND")
    history_plot()

おわりに

SND-V/SND-STD/SND-VICでどのアルゴリズムがいいかみたいな話は論文にはなさそうでした。
LLM系が出てきて強化学習は下火っぽくなっていると思いましたが、まだ新しいアルゴリズムが誕生しているんですね。

2
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
1