0
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

[深層強化学習] PrioritizedExperienceReplayの実装とcartpoleにおける性能比較

Posted at

はじめに

PrioritizedExperienceReplayとは

本記事ではPrioritizedExperienceReplay(PER)を実装し、性能比較していきます。

Github

参考

実装

PERの計算及びサンプリングに必要なSumTreeと実際のPERを実装していきます。
詳細のコードはGithubを確認してください。

SumTreeの実装

SumTreeの実装はこちらのコードを参考にしました。
https://github.com/jaara/AI-blog/blob/master/SumTree.py

途中で出てくるmax()関数以外は全て同じコードとなります。


class SumTree:
    write = 0

    def __init__(self, capacity):
        self.capacity = capacity
        self.tree = np.zeros( 2*capacity - 1 )
        self.data = np.zeros( capacity, dtype=object )
        self.index_leaf_start = capacity - 1

    def _propagate(self, idx, change):
        parent = (idx - 1) // 2

        self.tree[parent] += change

        if parent != 0:
            self._propagate(parent, change)

    def max(self):
        return self.tree[self.index_leaf_start:].max()

    def _retrieve(self, idx, s):
        left = 2 * idx + 1
        right = left + 1

        if left >= len(self.tree):
            return idx

        if s <= self.tree[left]:
            return self._retrieve(left, s)
        else:
            return self._retrieve(right, s-self.tree[left])

    def total(self):
        return self.tree[0]

    def add(self, p, data):
        idx = self.write + self.capacity - 1

        self.data[self.write] = data
        self.update(idx, p)

        self.write += 1
        if self.write >= self.capacity:
            self.write = 0

    def update(self, idx, p):
        change = p - self.tree[idx]

        self.tree[idx] = p
        self._propagate(idx, change)

    def get(self, s):
        idx = self._retrieve(0, s)
        dataIdx = idx - self.capacity + 1

        return (idx, self.tree[idx], self.data[dataIdx])

PrioritizedExperienceReplayの実装

PERの実装をしていきます。

実装上のポイントはtransitionをpush()する際に初期優先度として、sumtree.max()関数からpriorityを取得し、その値もしくは1を利用する点です。論文では初期優先度はreplayに溜まっているpriorityの最大値もしくは1にすると記述されているため、このような実装になっています。

ポイントの二つ目は、後で優先度を更新(update_priority())する際に必要となるindicesを保持しておくため、sample()するたびにindicesに追加しています。

class PrioritizedReplayMemory():

    def __init__(self, CAPACITY, epsilon = 0.0001, alpha = 0.6):
        self.capacity = CAPACITY  # メモリの最大長さ
        self.tree = SumTree(CAPACITY)
        self.epsilon = epsilon
        self.alpha = alpha
        self.size = 0
        self.indices = []

    def _getPriority(self, td_error):
        return (np.abs(td_error) + self.epsilon) ** self.alpha
 
    def push(self, transition: Transition):
        """state, action, state_next, rewardをメモリに保存します"""
        self.size += 1
        if self.size > self.capacity:
            self.size = self.capacity

        priority = self.tree.max()
        if priority <= 0:
            priority = 1

        self.tree.add(priority, transition)
 
    def sample(self, batch_size):
        """batch_size分だけ、ランダムに保存内容を取り出します"""
        list = []
        self.indices = []
        for rand in np.random.uniform(0, self.tree.total(), batch_size):
            (idx, _, data) = self.tree.get(rand)
            list.append(data)
            self.indices.append(idx)

        return list

    def update(self, td_errors):
        if self.indices != None:
            for i, td_error in enumerate(td_errors):
                priority = self._getPriority(td_error)
                self.tree.update(self.indices[i], priority)

    def update_priority(self, idx, td_error):
        priority = self._getPriority(td_error)
        self.tree.update(idx, priority)
 
    def __len__(self):
        return self.size

性能検証結果

ランダムサンプリングをするReplayと優先度付きReplayの二つをDuelingDQNで実行します。今回は下記のハイパーパラメータを利用し3000エピソード実行しました。

Parameter Value Description
episode 3000 エピソード数
capacity 10000 experience replayの容量
batch_size 32 バッチサイズ
gamma 0.97 時間割引率
target_update_iter 20 target networkの更新間隔
eps_start 0.9 ε-greedyのパラメータ
eps_end 0.5 ε-greedyのパラメータ
eps_decay 200 ε-greedyのパラメータ
epsilon 0.0001 PERでのパラメータ
alpha 0.6 PERでのパラメータ

通常のExperiencedReplayを用いたDuelingDQNの学習結果

スクリーンショット 2021-09-12 14.51.20.png

PrioritizedExperienceReplayを用いたDuelingDQNの学習結果

スクリーンショット 2021-09-12 19.24.59.png

通常のExperiencedReplayをを使用した場合は1200エピソード超えたあたりから成績が向上していますが、PrioritizedExperienceReplayを使用した場合は500エピソードあたりから向上しており、学習効率の向上が見て取れます。

通常のExperiencedReplayを用いたDuelingDQNの学習結果

ezgif.com-gif-maker.gif

こちらは3000エピソード経過した際のモデルを用いて可視化してみました。
3000エピソードでは学習が足りないのか、不安定なのがわかると思います。

PrioritizedExperienceReplayを用いたDuelingDQNの学習結果

ezgif.com-gif-maker (2).gif

こちらがPrioritizedExperienceReplayを用いた学習結果になります。もっと中心で耐えてくれると学習が安定していることがわかりやすかったのですが、そうはなりませんでした。(ただ、何回かは中心で保とうとする様子もみれました。)

まとめ

今回は PrioritizedExperiencedReplayの実装を行い、性能を検証しました。

エピソード数が少ないため可視化結果では対して変化がないように見えますが、step推移をみるとPrioritizedExperiencedReplayを行うと学習の安定が早いことがわかると思います。

次回はMulti-step-bootstrapの実装を行なっていきます。

0
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?