4
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

【強化学習】全く新しい手法で最難関ゲームMontezuma's Revengeを攻略したGo-Exploreを解説・実装

Last updated at Posted at 2024-07-12

この記事は自作している強化学習フレームワークの解説記事です。

はじめに

ふと以下のAtariゲームのベンチマークを見ていたら Montezuma's Revenge でかなり高いスコアを出していたGo-Exploreというアルゴリズムがあったので見てみました。
https://paperswithcode.com/task/atari-games

Go-Explore論文
https://arxiv.org/abs/1901.10995
https://arxiv.org/abs/2004.12919

背景が分かりませんがなぜか2種類あります。2020年頃の論文ですね。
ちなみにこの論文は概念的な話がメインで具体的な話はあまりなかったり…、記事も読み物っぽくなっています。

また私の所感は引用分で書いています。

Go-Explore

強化学習の課題として報酬が疎(全く手に入らない)環境で学習が進まない問題がありました。
この問題に対するアプローチの一つとして内発的動機付けによる報酬を追加し、探索を促進させるなどの方法がとられています。(Agent57のRND前回のSNDの記事など)
Go-Exploreではこれらの方法とは全く違う新しいアプローチでこの問題を解決します。

Go-Exploreの位置づけは以下です。

ss11.png

Montezuma’s Revenge のスコアになります。
Agent57が出る前の手法です、RNDやApe-Xが霞んで見えます。
Human Avg. どころか HumanExpertも超えて圧倒的性能です。

内発的動機付け、内的モチベーション (Intrinsic Motivation; IM)

(前回のSNDの記事と同じ内容です)

報酬が疎な環境でのアプローチの一つで、内発的動機付けによって報酬を追加する手法を指します。
外発的・内発的動機付けは元は心理学の分野の話で、それを強化学習に応用した形となります。

内発的動機付けとは自分自身の内側から湧き出る興味や楽しさ、満足感や好奇心などによって行動を起こす動機付けのことです。
例えば、絵を描くことが好きで、その行為自体が楽しいから絵を描く場合、これは内発的動機付けによって行われていると言えます。
反対に、外的な要因(例えば、お金や賞賛など)によって行動する場合は外発的動機付けとなります。

強化学習では環境からの報酬は外発的動機付けとなり、エージェントに内発的動機付けをして探索を促そうという試みとなります。

※用語ですが、本記事では内部報酬・内発的報酬、外部報酬・外発的報酬を同じ意味として扱います。
(内部報酬が言いやすいので…)

内発動機付け(IM)の問題点

論文内ではIMで動くエージェントが行動する際の問題点を2点指摘しています。

1. 高い内部報酬からの分離

IMによって動くエージェントは高い内部報酬のエリアから分離される可能性があるという問題です。
まず内部報酬は消費され続ける資源であると考えます。
ここで内部報酬が高い複数のエリアがあると、エージェントは以下のように動作します。

  1. エージェントはあるエリアAを重点的に探索します
  2. エリアAの内部報酬は取りつくされます
  3. 次にエージェントはエリアBを重点的に探索します

ここでエリアAの内部報酬は取りつくされているのでエージェントがエリアAを再度探索することはありません。
また一般なRLアルゴリズムは新しく訪れた領域を優先して学習するので(オンポリシーな方策なら尚更)、エリアAにたどり着く方法も忘れる可能性があります。
もしエリアAに重要な報酬がまだ残っていた場合、これにたどり着くことはほぼありえなくなります。

sss1.png

画像が具体例です。
スタートが真ん中で、緑が未探索エリア(内部報酬が高いエリア)、白が探索済み(内部報酬なし)、紫が探索中のエリアです。
画像の左下(3)では左のエリアはスタート地点付近は白エリア(探索済み)なのでこれ以上探索することはありません。
これ以降は右エリアのみが探索され、左エリアの残っている緑のエリアは探索されることがありません。

この分離問題に対して、DQNで使われるようなリプレイバッファではサイズが無限なら防ぐことができますが、実際にはサイズに上限があるため学習する前に破棄される可能性があり、解決しきれているとは言えていないというのが論文の主張です。

2. 新しい領域への脱線

ここの脱線とは、エージェントが有望な状態を発見した後に、その状態に再度戻って探索したい場合に発生する可能性があります。
一般なRLアルゴリズムは、初期状態からポリシーを再度実行することで有望な状態に再度訪れますが、わずかに確率的な変動(探索の促進など)を混ぜるためそこに到達しない場合があり、これを脱線といいます。

これはIMエージェントが2つの探索メカニズムで動いているからと考えられます。

  1. 新しい状態に高い内部報酬を付与することによる動機付け
  2. RLの基本的な探索アルゴリズム(ε-Greedy法、NNパラメータへの空間ノイズ、アクション空間ノイズの追加など)

IMエージェントは2の探索により高い内部報酬を見つけ、1によりその状態に戻るという動作を繰り返します。
ただ高い内部報酬が初期状態から遠い場合、2の確率的な行動により高い内部状態に行きつくことが難しくなります。
(この行きつけない現象を脱線といっていると思います)

Go-Exploreはこの"分離"と"脱線"に対処したアルゴリズムとなります。

Go-Exploreアルゴリズム概要

Go-Exploreは以下の3つの原則で動きます。

  1. 以前に訪れた状態を記憶する
  2. (記憶した状態から)有望な状態に戻り、そこから探索する
  3. 最適な手段を模倣学習により堅牢化する

以前に訪れた状態を記憶することで"分離"を、記憶した状態に戻ることで"脱線"を解決します。
実際にはGo-Exploreは2つのフェーズに分かれて学習します。

ss1.png

・詳細

sss2.png

Phase1

  1. 過去に訪れた状態(アーカイブ)から有望な状態を選びます
  2. その状態を復元します
  3. その状態から探索します
  4. 新しい状態を見つけたらアーカイブに追加します

Phase2

アーカイブ中のベストな履歴を模倣学習し、方策を堅牢化する。

ここまで読んで私は、いやこれが出来れば苦労しないよ、と思いました。
古典強化学習の範囲なら可能に見えますが、これは状態が連続なディープラーニング時代の話をしています。
とりあえず思いついた疑問を書いておきます。

  • 状態は無限にあるので、訪れた状態を全部記憶していたらメモリが足りない
  • 過去の状態に戻るのはどうやって?

では論文解説の続きになります。

1.セルの表現

セルはアーカイブで各状態を区別するインデックスとなります。
理論上はセルと状態を1対1で対応したいですが現実的ではありません。
セルの要件としては、違う意味の状態は違うセルに、似た意味の状態は同じセルで表現してほしいです。

セルの表現方法については様々な研究が既にあり、例えば以下のようなものがあります。

  • 従来のRLアルゴリズムでトレーニングされたNNの中央から潜在コードを取得する
  • AutoEncoder
  • 将来の状態を予測する教師なし学習
  • pixel control(画像が大きく変化する動きを学習する手法らしい)
  • オプションで報酬の予測などの補助タスクを追加

しかし、Go-Exploreの最初の実験では以下の単純な2つの方法を試します。

  • 単純な画像のダウンサンプリング
  • ドメイン固有の情報(マップ情報やx,y等の座標等)

ダウンサンプリングは単純な方法ですが、"Montezuma’s Revenge" で驚くほど好成績がでました。

本記事ではドメイン固有の情報は省略します。
ここは将来の研究で改善の余地がありそうです。

画像のダウンサンプリング

以下の3stepでダウンサンプリングします。

  1. グレースケールに変換
  2. 11×8にリサイズ(値は領域内の平均ピクセル値を使用)
  3. [0,255]から[0,8]の整数にスケール

ss2.png

状態数は11×8×9=792通りですね

2.セルの選択

アーカイブから次に探索の候補となるセルを選択します。
一番簡単な方法はランダムに選ぶことです。
次に論文内の選択方法について説明します。(A.5 Cell selection details の内容)

まず以下の値 $a$ を使います。

a1. このセルが選択された回数
a2. 探索フェーズ中にこのセルが訪問された回数
a3. 新しいセルまたは優れたセルが発見されてからこのセルが選択された回数

これを元に以下でカウントスコアを出します。

ss3.png

$c$ は対象のセル、$w_a$ と $p_a$ は $a$ に関するハイパーパラメータ、$v(c,a)$ は $a$ の回数、$\epsilon_1$ は0除算回避用、$\epsilon_2$ は確率0を回避するハイパーパラメータで $\epsilon_1=0.001$、$\epsilon_2=0.00001$ です。

次に、ドメイン知識がある場合に隣接セルが存在するかどうかを表す隣接スコアを出します。

ss4.png

$n$ は隣接するセルで、$w_n$ はハイパーパラメータ、$HasNeighbor(c,n)$ は隣接していれば1を返し、違う場合は0を返します。
ドメイン知識がない場合、$NeighScore(c,n)=0$ です。

最後に、ドメイン知識がある場合にレベルに対して補正をかけます。

ss5.png

ドメイン知識がない場合この値は1です。

上記3種類の値を元にセルスコア及びセルの選択確率は以下で計算されます。

ss6.png

ss7.png

論文では基本的にランダムよりヒューリスティックなセル選択の方が結果が良いとあり、ここの改善も将来の研究テーマですね

3.状態の復元

選択されたセルの状態に戻る方法ですが3種類の場合を分けて考えます。

シミュレータで復元させる

一番簡単な方法です。
この方法がRLの研究として許容されるかどうかは議論の対象で、論文内にも言及があります。

ざっくりまとめると、一般的にロボット工学などで強化学習を使う場合、実物に使う前にシミュレータ上でトレーニングする場合が多いです。
この場合、シミュレータに復元機能(または決定論的な動作)を持たせ、これを活用してパフォーマンスを向上させるのは悪くない選択肢だという話です。

個人的にはこの選択肢がRLとしてあるのか…というのが衝撃だった

決定論的に復元する

復元と似ていますがもしステップ間の状態遷移が決定論的なら、初期状態から同じアクションを実行すれば元の状態を復元できます。
確率的に遷移する環境では使えませんが、もし環境が決定論的なら使える手法です。

その他、確率的に状態遷移する環境

初期状態から目標の状態に戻る目標条件付きポリシー(Goal-Conditioned Policies)を学習する事で対処できると考えていますが、Go-Exploreがこの条件で役立つかどうかの議論は今後の研究に残しておきます。

え…

Go-Exploreに関するブログの投稿前に確率性をトレーニングする必要があるかの議論はなかったようです。
Atariゲームは確率的なゲームと決定論的なゲームが混在しているため、それぞれでベンチマークが必要である可能性を主張していました。

4.セルからの探索

探索は任意の方法を使えます。
論文では最初にランダムなアクションを選択し、95%の確率で同じアクション、それ以外でランダムな別のアクションに変更する行動をしているそうです。
また探索は、100stepまたはエピソード終了までとしています。
探索中に見つかった新しいセルはアーカイブに保存されます。

ここまででニューラルネットワークを使っていないという事実、ただのランダムな探索で非常にうまく機能するという事実は、Go-Exploreの潜在的な性能を示しています。

ここも将来の研究で改善の余地がありそう

5.アーカイブの更新

2つの条件で更新されます。

  1. アーカイブにセルがない場合
  2. 新しいセルが既存のセルより優れている場合

アーカイブにセルがない場合

以下のデータを保存します。

  • セルに到達した方法(初期状態からセルまでの軌跡)
  • (復元する場合)状態を復元するための情報
  • 軌跡のスコア
  • 軌跡の長さ
  • その他スコアや選択の計算に必要な情報

既存セルの更新

以下の条件で更新します。

  • 軌跡のスコアが高い場合
  • 軌跡のスコアが同じでかつ軌跡の長さが短い場合

更新は以下の情報となります。

  • セルの軌跡
  • セルの選択に影響する情報のリセット(セルの選択回数など)
    →セルに到達する新しい方法は探索に有望の可能性が高いため優先したい
  • セルの訪問回数はリセットしない
    →発見されたセルか更新されたセルか判断できないため
     発見されたセルの方が探索価値は高い

Phase2.方策の堅牢化

Phase1で確率論を排除した学習をし、テスト時に確率論的な動作になる場合、このPhase2を実行します。
(例えばロボットで、学習はシミュレーション、テスト時は実際のロボットなど)

Phase1が終わるとアーカイブに有用な軌跡が集まります。
この軌跡を利用して模倣学習を行います。
模倣学習をする事で、軌跡自体と軌跡に存在しない状況に対処する方法を学習します。
ただ、環境の確率性によってはこれは非常に困難になる場合があります。
ですが、スパース報酬問題を最初から解決するよりは簡単になる場合の方が多いはずです。

アーカイブのどの軌跡を使うかという話は見つけられませんでした…
また、模倣学習ではなく、オフライン強化学習でもいいかも???

模倣学習ではエキスパートな軌跡を厳密に再現しようとするアルゴリズムよりも、エキスパートな軌跡を改善できるアルゴリズムの方が良いです。

論文ではバックワードアルゴリズム(the Backward Algorithm from Salimans and Chen)で試しています。(オンポリシーなRLアルゴリズム前提のアルゴリズムかな?)
これはエージェントを軌跡の最後から最初に向かって通常のRLアルゴリズム(例えばPPO)を学習する手法となります。

このフェーズによる利点は以下です。

  1. (Phase1で集めた)エキスパートな軌跡を更に最適化し、それを超えて一般化する
  2. バックワードアルゴリズムを利用することで、フォワード的なアルゴリズムに比べて報酬の伝播効率が上がる
  3. 不要なアクション(行き止まりを訪れて戻るなど)を排除した状態で学習できる

実装

論文では模倣学習をバックワードアルゴリズムとRLをPPOで書いてますが、PPOはハイパーパラメータの影響が多く安定しないイメージなのでDQN系列で実装します。
以下は、模倣学習としては実装が簡単なR2D3ベースのDQNを実装します。(R2D3が厳密に模倣学習かは置いておきます…)
内容は、DQNでバッチを作る際にエキスパートな軌跡を混ぜてバッチを作成します。

模倣学習と逆強化学習
かなり似ている概念なので説明します。(私の認識も間違っているかもしれませんが…)
模倣学習は行動データ(例えばベテランの人の行動等)を模倣して方策を学習する方法です。
逆強化学習は方策(行動データ)から報酬を推定する方法です。(強化学習は報酬から方策)
ただ、逆強化学習は推定した報酬を元に更に最適な法則を学習するまで含まれている事が多いようです。
その場合、逆強化学習は「行動データ→報酬→方策」という学習ですが、模倣学習は「行動データ→方策」とみることができます。
ですので模倣学習は、逆強化学習の報酬の学習を省略し、直接方策を学習する手法とも解釈できます。

※SRLがv0.16.3のコードです。バージョンが進むと動かない可能性があります。

Config

@dataclass
class Config(RLConfig):
    # --- archive parameters
    action_change_rate: float = 0.05
    explore_max_step: int = 100
    demo_batch_rate: float = 0.1
    w_visit: float = 0.3
    w_select: float = 0
    w_total_select: float = 0.1
    eps1: float = 0.001
    eps2: float = 0.00001

    # --- DQN parameters
    test_epsilon: float = 0
    epsilon: float = 0.01
    memory_warmup_size: int = 1_000
    memory_capacity: int = 10_000
    lr: float = 0.0005
    batch_size: int = 32
    discount: float = 0.99
    target_model_update_interval: int = 2000

Memory

アーカイブ側とQ用のメモリーとデモメモリーの3種類があります。

class Memory(RLMemory):
    def __init__(self, *args):
        super().__init__(*args)

        self.archive = {}
        self.memory_q = []
        self.memory_demo = []

    def length(self):
        return len(self.memory_q)

    def add(self, mode: str, batch) -> None:
        if mode == "archive":
            self.archive_update(batch)
        elif mode == "q":
            self.memory_q.append(batch)
            if len(self.memory_q) > self.config.memory_capacity:
                self.memory_q.pop(0)
        elif mode == "demo":
            self.memory_demo.append(batch)
            if len(self.memory_demo) > self.config.memory_capacity:
                self.memory_demo.pop(0)

addの引数で度のメモリーに保存するか区別します。
まずはアーカイブ側の実装です。

class Memory(RLMemory):
    # --- 状態からcellを区別する文字列を生成します。
    def create_cell(self, state):
        space = self.config.observation_space
        # (1) color -> gray
        if space.stype == SpaceTypes.COLOR:
            state = cv2.cvtColor(state, cv2.COLOR_RGB2GRAY)
        elif space.stype == SpaceTypes.GRAY_3ch:
            state = np.squeeze(state, axis=-1)

        # (2) down sampling
        state = cv2.resize(state, (11, 8), interpolation=cv2.INTER_NEAREST)

        # (3) 255->8
        state = np.round(state * 8.0 / 255.0)

        return "".join([str(int(n)) for n in state.flatten().tolist()])

    def archive_update(self, batch):
        state = batch[0]
        states = batch[1]
        actions = batch[2]
        rewards = batch[3]
        undone = batch[4]
        step = batch[5]
        total_reward = batch[6]
        backup = batch[7]

        cell_key = self.create_cell(state)
        if cell_key not in self.archive:
            self.archive[cell_key] = {
                "step": np.inf,
                "total_reward": -np.inf,
                "score": -np.inf,
                "visit": 0,
                "select": 0,
                "total_select": 0,
            }
        cell = self.archive[cell_key]
        cell["visit"] += 1

        # --- update archive
        _update = False
        if cell["total_reward"] < total_reward:
            _update = True
        elif (cell["total_reward"] == total_reward) and (cell["step"] > step):
            _update = True
        if _update:
            cell["score"] = self._calc_score(cell)
            cell["step"] = step
            cell["select"] = 0
            cell["total_reward"] = total_reward
            cell["states"] = states
            cell["actions"] = actions
            cell["rewards"] = rewards
            cell["undone"] = undone
            cell["backup"] = backup

    def _calc_score(self, cell):
        cnt_score1 = self.config.w_visit * (1 / (cell["visit"] + self.config.eps1)) + self.config.eps2
        cnt_score2 = self.config.w_select * (1 / (cell["select"] + self.config.eps1)) + self.config.eps2
        cnt_score3 = self.config.w_total_select * (1 / (cell["total_select"] + self.config.eps1)) + self.config.eps2
        neigh_score = 0
        level_weight = 1
        score = level_weight * (neigh_score + cnt_score1 + cnt_score2 + cnt_score3 + 1)
        return score

    # --- アーカイブからランダムにセルを選択する
    def archive_select(self):
        if len(self.archive) == 0:
            return None

        # 累積和による乱数
        total = sum([c["score"] for c in self.archive.values()])
        if total == 0:
            return None
        r = random.random() * total
        n = 0
        for cell in self.archive.values():
            n += cell["score"]
            if r <= n:
                break

        cell["select"] += 1
        cell["total_select"] += 1
        cell["score"] = self._calc_score(cell)
        return cell

次にデモメモリーの実装です。
create_demo_memoryでアーカイブからデモメモリー用のbatchを作成します。

class Memory(RLMemory):
    def create_demo_memory(self):
        for cell in self.archive.values():
            for i in range(cell["step"]):
                batch = [
                    cell["states"][i],
                    cell["states"][i + 1],
                    cell["actions"][i],
                    cell["rewards"][i],
                    cell["undone"][i],
                ]
                self.memory_demo.append(batch)

    def sample_demo(self, batch_size):
        return random.sample(self.memory_demo, batch_size)

最後にDQNのQメモリーです。
ランダムに選ぶだけです。

class Memory(RLMemory):
    def sample_q(self, batch_size):
        return random.sample(self.memory_q, batch_size)

Qネットワーク

DQNのシンプルな実装です。
入力画像は0-255で正規化していないので255で割っています。

class QNetwork(keras.Model):
    def __init__(self, in_shape, action_num, **kwargs):
        super().__init__(**kwargs)

        self.h_layers = [
            kl.Conv2D(32, (8, 8), strides=(4, 4), padding="same", activation="relu"),
            kl.Conv2D(64, (4, 4), strides=(2, 2), padding="same", activation="relu"),
            kl.Conv2D(64, (3, 3), strides=(1, 1), padding="same", activation="relu"),
            kl.Flatten(),
            kl.Dense(512, activation="relu"),
            kl.Dense(action_num),
        ]

        # build
        self(np.zeros((1,) + in_shape))
        self.loss_func = keras.losses.Huber()

    def call(self, x, training=False):
        x = x / 255
        for h in self.h_layers:
            x = h(x, training=training)
        return x

    @tf.function
    def compute_train_loss(self, state, onehot_action, target_q):
        q = self(state, training=True)
        q = tf.reduce_sum(q * onehot_action, axis=1)
        loss = self.loss_func(target_q, q)
        return loss

Parameter

DQNです。

class Parameter(RLParameter):
    def __init__(self, *args):
        super().__init__(*args)

        in_shape = self.config.observation_space.shape
        action_num = self.config.action_space.n
        self.q_online = QNetwork(in_shape, action_num, name="Q_online")
        self.q_target = QNetwork(in_shape, action_num, name="Q_target")
        self.q_target.set_weights(self.q_online.get_weights())

Trainer

DQNと同じですが、バッチにデモメモリーも混ぜて学習します。

class Trainer(RLTrainer):
    def __init__(self, *args):
        super().__init__(*args)
        self.opt_q = keras.optimizers.Adam(self.config.lr)

    def train(self) -> None:
        if len(self.memory.memory_q) < self.config.memory_warmup_size:
            return
        batchs = self.memory.sample_q(self.config.batch_size)
        state = []
        n_state = []
        action = []
        reward = []
        undone = []
        for b in batchs:
            state.append(b[0])
            n_state.append(b[1])
            action.append(b[2])
            reward.append(b[3])
            undone.append(b[4])

        # デモメモリもバッチに追加
        demo_size = int(self.config.batch_size * self.config.demo_batch_rate)
        demo_size = demo_size if demo_size > 0 else 1
        if len(self.memory.memory_demo) > demo_size:
            batchs = self.memory.sample_demo(demo_size)
            for b in batchs:
                state.append(b[0])
                n_state.append(b[1])
                action.append(b[2])
                reward.append(b[3])
                undone.append(b[4])
        state = np.asarray(state, self.config.dtype)
        n_state = np.asarray(n_state, self.config.dtype)
        action = np.asarray(action, self.config.dtype)
        reward = np.array(reward, self.config.dtype)
        undone = np.array(undone, self.config.dtype)

        # --- calc next q
        batch_size = n_state.shape[0]
        n_q = self.parameter.q_online(n_state)
        n_q_target = self.parameter.q_target(n_state).numpy()
        n_act_idx = np.argmax(n_q, axis=-1)
        maxq = n_q_target[np.arange(batch_size), n_act_idx]
        target_q = reward + undone * self.config.discount * maxq
        target_q = target_q[..., np.newaxis]

        # --- train q
        with tf.GradientTape() as tape:
            loss = self.parameter.q_online.compute_train_loss(state, action, target_q)
        grad = tape.gradient(loss, self.parameter.q_online.trainable_variables)
        self.opt_q.apply_gradients(zip(grad, self.parameter.q_online.trainable_variables))
        self.info["loss"] = loss.numpy()

        # --- targetと同期
        if self.train_count % self.config.target_model_update_interval == 0:
            self.parameter.q_target.set_weights(self.parameter.q_online.get_weights())
            self.sync_count += 1
            self.info["sync"] = self.sync_count

        self.train_count += 1

Worker


class Worker(RLWorker):
    def on_start(self, worker, context):
        # 学習の最初にアーカイブからデモメモリのバッチを作成
        if self.training and not self.rollout:
            self.memory.create_demo_memory()

    def on_reset(self, worker):
        if self.rollout:
            # rolloutはアーカイブを作成するフェーズとしました

            # rolloutの行動は再保にランダムに行動を決め、95%で行動を変える
            self.action = self.sample_action()

            # 初期状態をアーカイブに入れる
            batch = [
                worker.state,
                [worker.state],
                [],
                [],
                [],
                0,
                0,
                worker.backup(),
            ]
            self.memory.add("archive", batch)

            # アーカイブから目標のセルを選んでrestore
            self.cell_step = 0
            cell = self.memory.archive_select()
            if cell is not None:
                worker.restore(cell["backup"])
                self.episode_step = cell["step"]
                self.episode_reward = cell["total_reward"]
                self.recent_states = cell["states"][:]
                self.recent_actions = cell["actions"][:]
                self.recent_rewards = cell["rewards"][:]
                self.recent_undone = cell["undone"][:]
            else:
                self.episode_step = 0
                self.episode_reward = 0
                self.recent_states = [worker.state]
                self.recent_actions = []
                self.recent_rewards = []
                self.recent_undone = []

    def policy(self, worker) -> int:
        # rolloutの行動は再保にランダムに行動を決め、95%で行動を変える
        # それ以外の行動はDQN準拠
        if self.rollout:
            
            if random.random() < 0.05:
                self.action = self.sample_action()
            return self.action
        elif self.training:
            epsilon = self.config.epsilon
        else:
            epsilon = self.config.test_epsilon

        if random.random() < epsilon:
            action = self.sample_action()
        else:
            state = worker.state[np.newaxis, ...]
            q = self.parameter.q_online(state)[0].numpy()
            action = int(np.argmax(q))
        return action

    def on_step(self, worker):
        if not self.training:
            return

        # rollout中はアーカイブ用の情報を集めて保存
        # それ以外はDQN準拠
        if self.rollout:
            self.episode_step += 1
            self.episode_reward += worker.reward
            self.recent_states.append(worker.state)
            self.recent_actions.append(funcs.one_hot(worker.action, self.config.action_space.n))
            self.recent_rewards.append(worker.reward)
            self.recent_undone.append(int(not worker.terminated))
            batch = [
                worker.state,
                self.recent_states[:],
                self.recent_actions[:],
                self.recent_rewards[:],
                self.recent_undone[:],
                self.episode_step,
                self.episode_reward,
                worker.backup(),
            ]
            self.memory.add("archive", batch)
            self.info["archive_size"] = len(self.memory.archive)

            # 一定数行動したらエピソードを終了させる
            self.cell_step += 1
            if self.cell_step > self.config.explore_max_step:
                worker.env.abort_episode()

        else:
            batch = [
                worker.prev_state,
                worker.state,
                funcs.one_hot(worker.action, self.config.action_space.n),
                worker.reward,
                int(not worker.terminated),
            ]
            self.memory.add("q", batch)

実行結果

学習結果は以下です。

Pong_Go-Explore.gif

学習過程はDQNとみてみましたがPongだとあまり変わりませんね….
(SNDに続き例が悪い…)

Pong.png

コード全体

※SRLがv0.16.3で動作確認しています。

import os
import random
from dataclasses import dataclass

import cv2
import numpy as np
import tensorflow as tf
from matplotlib import pyplot as plt
from tensorflow import keras

import srl
from srl.algorithms import dqn
from srl.base.define import RLBaseTypes, SpaceTypes
from srl.base.rl.algorithms.base_dqn import RLConfig, RLWorker
from srl.base.rl.memory import RLMemory
from srl.base.rl.parameter import RLParameter
from srl.base.rl.registration import register
from srl.base.rl.trainer import RLTrainer
from srl.rl import functions as funcs
from srl.rl.models.config.input_config import RLConfigComponentInput
from srl.rl.processors.atari_processor import AtariPongProcessor
from srl.utils import common

kl = keras.layers


@dataclass
class Config(
    RLConfig,
    RLConfigComponentInput,
):
    test_epsilon: float = 0
    epsilon: float = 0.01

    # --- archive parameters
    explore_max_step: int = 100
    demo_batch_rate: float = 0.1
    w_visit: float = 0.3
    w_select: float = 0
    w_total_select: float = 0.1
    eps1: float = 0.001
    eps2: float = 0.00001

    # --- q parameters
    memory_warmup_size: int = 1_000
    memory_capacity: int = 10_000
    lr: float = 0.0005
    batch_size: int = 32
    discount: float = 0.99
    target_model_update_interval: int = 2000

    def get_base_observation_type(self) -> RLBaseTypes:
        return RLBaseTypes.IMAGE

    def get_framework(self) -> str:
        return "tensorflow"

    def get_name(self) -> str:
        return "Go-Explore"

    def assert_params(self) -> None:
        super().assert_params()
        self.assert_params_input()


register(
    Config(),
    __name__ + ":Memory",
    __name__ + ":Parameter",
    __name__ + ":Trainer",
    __name__ + ":Worker",
)


class Memory(RLMemory[Config]):
    def __init__(self, *args):
        super().__init__(*args)

        self.archive = {}
        self.memory_q = []
        self.memory_demo = []

    def length(self):
        return len(self.memory_q)

    def add(self, mode: str, batch) -> None:
        if mode == "archive":
            self.archive_update(batch)
        elif mode == "q":
            self.memory_q.append(batch)
            if len(self.memory_q) > self.config.memory_capacity:
                self.memory_q.pop(0)
        elif mode == "demo":
            self.memory_demo.append(batch)
            if len(self.memory_demo) > self.config.memory_capacity:
                self.memory_demo.pop(0)

    def create_cell(self, state):
        space = self.config.observation_space
        # (1) color -> gray
        if space.stype == SpaceTypes.COLOR:
            state = cv2.cvtColor(state, cv2.COLOR_RGB2GRAY)
        elif space.stype == SpaceTypes.GRAY_3ch:
            state = np.squeeze(state, axis=-1)

        # (2) down sampling
        state = cv2.resize(state, (11, 8), interpolation=cv2.INTER_NEAREST)

        # (3) 255->8
        state = np.round(state * 8.0 / 255.0)

        return "".join([str(int(n)) for n in state.flatten().tolist()])

    def archive_update(self, batch):
        state = batch[0]
        states = batch[1]
        actions = batch[2]
        rewards = batch[3]
        undone = batch[4]
        step = batch[5]
        total_reward = batch[6]
        backup = batch[7]

        cell_key = self.create_cell(state)
        if cell_key not in self.archive:
            self.archive[cell_key] = {
                "step": np.inf,
                "total_reward": -np.inf,
                "score": -np.inf,
                "visit": 0,
                "select": 0,
                "total_select": 0,
            }
        cell = self.archive[cell_key]
        cell["visit"] += 1

        # --- update archive
        _update = False
        if cell["total_reward"] < total_reward:
            _update = True
        elif (cell["total_reward"] == total_reward) and (cell["step"] > step):
            _update = True
        if _update:
            cell["score"] = self._calc_score(cell)
            cell["step"] = step
            cell["select"] = 0
            cell["total_reward"] = total_reward
            cell["states"] = states
            cell["actions"] = actions
            cell["rewards"] = rewards
            cell["undone"] = undone
            cell["backup"] = backup

    def _calc_score(self, cell):
        cnt_score1 = self.config.w_visit * (1 / (cell["visit"] + self.config.eps1)) + self.config.eps2
        cnt_score2 = self.config.w_select * (1 / (cell["select"] + self.config.eps1)) + self.config.eps2
        cnt_score3 = self.config.w_total_select * (1 / (cell["total_select"] + self.config.eps1)) + self.config.eps2
        neigh_score = 0
        level_weight = 1
        score = level_weight * (neigh_score + cnt_score1 + cnt_score2 + cnt_score3 + 1)
        return score

    def archive_select(self):
        if len(self.archive) == 0:
            return None

        # 累積和による乱数
        total = sum([c["score"] for c in self.archive.values()])
        if total == 0:
            return None
        r = random.random() * total
        n = 0
        for cell in self.archive.values():
            n += cell["score"]
            if r <= n:
                break

        cell["select"] += 1
        cell["total_select"] += 1
        cell["score"] = self._calc_score(cell)
        return cell

    def create_demo_memory(self):
        for cell in self.archive.values():
            for i in range(cell["step"]):
                batch = [
                    cell["states"][i],
                    cell["states"][i + 1],
                    cell["actions"][i],
                    cell["rewards"][i],
                    cell["undone"][i],
                ]
                self.memory_demo.append(batch)

    def sample_q(self, batch_size):
        return random.sample(self.memory_q, batch_size)

    def sample_demo(self, batch_size):
        return random.sample(self.memory_demo, batch_size)


class QNetwork(keras.Model):
    def __init__(self, in_shape, action_num, **kwargs):
        super().__init__(**kwargs)

        self.h_layers = [
            kl.Conv2D(32, (8, 8), strides=(4, 4), padding="same", activation="relu"),
            kl.Conv2D(64, (4, 4), strides=(2, 2), padding="same", activation="relu"),
            kl.Conv2D(64, (3, 3), strides=(1, 1), padding="same", activation="relu"),
            kl.Flatten(),
            kl.Dense(512, activation="relu"),
            kl.Dense(action_num),
        ]

        # build
        self(np.zeros((1,) + in_shape))
        self.loss_func = keras.losses.Huber()

    def call(self, x, training=False):
        x = x / 255
        for h in self.h_layers:
            x = h(x, training=training)
        return x

    @tf.function
    def compute_train_loss(self, state, onehot_action, target_q):
        q = self(state, training=True)
        q = tf.reduce_sum(q * onehot_action, axis=1)
        loss = self.loss_func(target_q, q)
        return loss


class Parameter(RLParameter[Config]):
    def __init__(self, *args):
        super().__init__(*args)

        in_shape = self.config.observation_space.shape
        action_num = self.config.action_space.n
        self.q_online = QNetwork(in_shape, action_num, name="Q_online")
        self.q_target = QNetwork(in_shape, action_num, name="Q_target")
        self.q_target.set_weights(self.q_online.get_weights())

    def call_restore(self, data, **kwargs):
        self.q_online.set_weights(data)
        self.q_target.set_weights(data)

    def call_backup(self, **kwargs):
        return self.q_online.get_weights()

    def summary(self, **kwargs):
        self.q_online.summary(**kwargs)


class Trainer(RLTrainer[Config, Parameter, Memory]):
    def __init__(self, *args):
        super().__init__(*args)

        self.opt_q = keras.optimizers.Adam(self.config.lr)
        self.sync_count = 0

    def train(self) -> None:
        if len(self.memory.memory_q) < self.config.memory_warmup_size:
            return
        batchs = self.memory.sample_q(self.config.batch_size)
        state = []
        n_state = []
        action = []
        reward = []
        undone = []
        for b in batchs:
            state.append(b[0])
            n_state.append(b[1])
            action.append(b[2])
            reward.append(b[3])
            undone.append(b[4])
        demo_size = int(self.config.batch_size * self.config.demo_batch_rate)
        demo_size = demo_size if demo_size > 0 else 1
        if len(self.memory.memory_demo) > demo_size:
            batchs = self.memory.sample_demo(demo_size)
            for b in batchs:
                state.append(b[0])
                n_state.append(b[1])
                action.append(b[2])
                reward.append(b[3])
                undone.append(b[4])
        state = np.asarray(state, self.config.dtype)
        n_state = np.asarray(n_state, self.config.dtype)
        action = np.asarray(action, self.config.dtype)
        reward = np.array(reward, self.config.dtype)
        undone = np.array(undone, self.config.dtype)

        # --- calc next q
        batch_size = n_state.shape[0]
        n_q = self.parameter.q_online(n_state)
        n_q_target = self.parameter.q_target(n_state).numpy()
        n_act_idx = np.argmax(n_q, axis=-1)
        maxq = n_q_target[np.arange(batch_size), n_act_idx]
        target_q = reward + undone * self.config.discount * maxq
        target_q = target_q[..., np.newaxis]

        # --- train q
        with tf.GradientTape() as tape:
            loss = self.parameter.q_online.compute_train_loss(state, action, target_q)
        grad = tape.gradient(loss, self.parameter.q_online.trainable_variables)
        self.opt_q.apply_gradients(zip(grad, self.parameter.q_online.trainable_variables))
        self.info["loss"] = loss.numpy()

        # --- targetと同期
        if self.train_count % self.config.target_model_update_interval == 0:
            self.parameter.q_target.set_weights(self.parameter.q_online.get_weights())
            self.sync_count += 1
            self.info["sync"] = self.sync_count

        self.train_count += 1


class Worker(RLWorker[Config, Parameter]):
    def on_start(self, worker, context):
        assert not self.distributed
        self.memory: Memory = self.memory

        if self.training and not self.rollout:
            self.memory.create_demo_memory()
            self.info["demo_size"] = len(self.memory.memory_demo)

    def on_reset(self, worker):
        if self.rollout:
            self.action = self.sample_action()
            batch = [
                worker.state,
                [worker.state],
                [],
                [],
                [],
                0,
                0,
                worker.backup(),
            ]
            self.memory.add("archive", batch)

            self.cell_step = 0
            cell = self.memory.archive_select()
            if cell is not None:
                worker.restore(cell["backup"])
                self.episode_step = cell["step"]
                self.episode_reward = cell["total_reward"]
                self.recent_states = cell["states"][:]
                self.recent_actions = cell["actions"][:]
                self.recent_rewards = cell["rewards"][:]
                self.recent_undone = cell["undone"][:]
            else:
                self.episode_step = 0
                self.episode_reward = 0
                self.recent_states = [worker.state]
                self.recent_actions = []
                self.recent_rewards = []
                self.recent_undone = []

    def policy(self, worker) -> int:
        if self.rollout:
            if random.random() < 0.05:
                self.action = self.sample_action()
            return self.action
        elif self.training:
            epsilon = self.config.epsilon
        else:
            epsilon = self.config.test_epsilon

        if random.random() < epsilon:
            action = self.sample_action()
        else:
            state = worker.state[np.newaxis, ...]
            q = self.parameter.q_online(state)[0].numpy()
            action = int(np.argmax(q))
        return action

    def on_step(self, worker):
        if not self.training:
            return

        if self.rollout:
            self.episode_step += 1
            self.episode_reward += worker.reward
            self.recent_states.append(worker.state)
            self.recent_actions.append(funcs.one_hot(worker.action, self.config.action_space.n))
            self.recent_rewards.append(worker.reward)
            self.recent_undone.append(int(not worker.terminated))
            batch = [
                worker.state,
                self.recent_states[:],
                self.recent_actions[:],
                self.recent_rewards[:],
                self.recent_undone[:],
                self.episode_step,
                self.episode_reward,
                worker.backup(),
            ]
            self.memory.add("archive", batch)
            self.info["archive_size"] = len(self.memory.archive)

            self.cell_step += 1
            if self.cell_step > self.config.explore_max_step:
                worker.env.abort_episode()

        else:
            batch = [
                worker.prev_state,
                worker.state,
                funcs.one_hot(worker.action, self.config.action_space.n),
                worker.reward,
                int(not worker.terminated),
            ]
            self.memory.add("q", batch)

    def render_terminal(self, worker, **kwargs) -> None:
        # policy -> render -> env.step -> on_step

        # --- archive
        print(f"size: {len(self.memory.archive)}")
        key = self.memory.create_cell(worker.state)
        print(key in self.memory.archive)
        if key in self.memory.archive:
            cell = self.memory.archive[key]
            print(f"step        : {cell['step']}")
            print(f"total_reward: {cell['total_reward']}")
            print(f"score       : {cell['score']}")
            print(f"visit       : {cell['visit']}")
            print(f"select      : {cell['select']}")
            print(f"total_select: {cell['total_select']}")

        # --- q
        q = self.parameter.q_online(worker.state[np.newaxis, ...])[0]
        maxa = np.argmax(q)

        def _render_sub(a: int) -> str:
            return f"{q[a]:7.5f}"

        funcs.render_discrete_action(int(maxa), self.config.action_space, worker.env, _render_sub)


def train(name):

    env_config = srl.EnvConfig(
        "ALE/Pong-v5",
        kwargs=dict(frameskip=7, repeat_action_probability=0, full_action_space=False),
        processors=[AtariPongProcessor()],
    )

    if name == "DQN":
        rl_config = dqn.Config(
            target_model_update_interval=2_000,
            epsilon=0.1,
            discount=0.99,
            lr=0.001,
            enable_reward_clip=False,
            enable_double_dqn=True,
            enable_rescale=False,
            memory_warmup_size=1000,
            memory_capacity=10_000,
            memory_compress=False,
            window_length=8,
        )
        rl_config.input_image_block.set_dqn_block()
        rl_config.hidden_block.set((512,))
    elif name == "Go-Explore":
        rl_config = Config(
            target_model_update_interval=2_000,
            epsilon=0.1,
            discount=0.99,
            lr=0.001,
            memory_warmup_size=1000,
            memory_capacity=10_000,
            window_length=8,
        )

    runner = srl.Runner(env_config, rl_config)
    runner.model_summary()

    # history setting
    his_path = os.path.join(os.path.dirname(__file__), f"Pong_{name}")
    runner.set_history_on_file(his_path, interval_mode="time", interval=5, enable_eval=True)

    if name == "Go-Explore":
        # phase1
        runner.rollout(max_steps=2_000_000)
        # phase2
        runner.train(max_train_count=50_000)
    else:
        runner.train(max_train_count=200_000)

    # animation
    runner.animation_save_gif(os.path.join(os.path.dirname(__file__), f"Pong_{name}.gif"))


def history_plot():
    base_dir = os.path.dirname(__file__)
    his = srl.Runner.load_histories(
        [
            os.path.join(base_dir, "Pong_DQN"),
            os.path.join(base_dir, "Pong_Go-Explore"),
        ]
    )
    his.plot("train", "eval_reward0", aggregation_num=20, _no_plot=True)
    plt.savefig(os.path.join(os.path.dirname(__file__), "Pong.png"))


if __name__ == "__main__":
    common.logger_print()

    train("DQN")
    train("Go-Explore")
    history_plot()

個人的まとめ

・かなり実践的な内容
・Go-Exploreは確率的な遷移を無視することで効率的なRLアルゴリズムを提案
・確率的な遷移にも対処できる余地を残していると個人的には感じている

4
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
4
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?