5
3

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

【強化学習】DeepMind製Experience ReplayライブラリReverbの使い方調査【クライアント編】

Posted at

英語で投稿したブログ記事の日本語焼き直し

1. はじめに

前回に引き続き、DeepMind製のExperience ReplayライブラリのReberbについて。
今回は、データの入出力操作を指示するクライアント側について、ソースコードを読みながら、READMEにか書かれていない部分まで調査した。

2. ClientTFClient

Reverbはサーバー・クライアントモデルを採用しているが、クライアント側のクラスとしては、reverb.Clientreverb.TFClientの2つがある。

Client が開発初期向けで、TFClientが実際の学習プログラムの中で利用するものという位置づけとのことである。大きな違いは、TFClientはその名のとおり、TensorFlowの計算グラフの中で利用することが意図されている。

両者のAPIや使い方は一貫しておらずややこしかったのが、今回整理記事を書こうと思ったモチベーションの1つである。

以下の記事では、サーバーとクライアントプログラムが、以下のように初期化されていることを前提とする。

import numpy as np
import tensorflow as tf
import reverb

table_name = "ReplayBuffer"
alpha = 0.8
buffer_size = 1000
batch_size = 32

server = reverb.Server(tables=[reverb.Table(name=table_name,
                                            sampler=reverb.selectors.Prioritized(alpha),
                                            remover=reverb.selectors.Fifo(),
                                            max_size=buffer_size,
                                            rate_limiter=reverb.rate_limiters.MinSize(1))])

client = reverb.Client(f"localhost:{server.port}")
tf_client = reverb.TFClient(f"localhost:{server.port}")


obs = np.zeros((4,))
act = np.ones((1,))
rew = np.ones((1,))
next_obs = np.zeros_like(obs)
done = np.zeros((1,))
priority = 1.0

dtypes = [tf.float64,tf.float64,tf.float64,tf.float64,tf.float64]
shapes = [4,1,1,4,1]

3. 遷移(transition)または軌道(trajectory)の保存

3.1 Client.insert

client.insert([obs,act,rew,next_obs,done],priorities={table_name: priority})

priorities引数が、 dict になっているのは、同じデータを複数のテーブル (リプレイ・バッファ) にそれぞれ異なるpriorityで同時に登録可能な仕様だからである。 (そんなニーズがあるなんて知らなかった)

たとえ、Prioritizedではない普通のリプレイ・バッファであっても、priorityを指定する必要がある。

Client.insert は呼び出されるたびに、データをサーバーに送信する

3.2 Client.writer

with client.writer(max_sequence_length=3) as writer:
    writer.append([obs,act,rew,next_obs,done])
    writer.create_item(table_name,num_timesteps=1,priority=priority)

    writer.append([obs,act,rew,next_obs,done])
    writer.append([obs,act,rew,next_obs,done])
    writer.append([obs,act,rew,next_obs,done])
    writer.create_item(table_name,num_timesteps=3,priority=priority) # 3ステップを1項目として登録。
# withブロックから出る際に、サーバーへ送信する

Client.writer メソッドが返す、 reverb.Writer をコンテキストマネージャーとして利用することで、
より柔軟に保存する内容を設定することができる。

例えば、上のサンプルコードの前半は、3.1と同じ内容を保存しているが、後半は3ステップをまとめて1つの項目として保存している。つまり、サンプルする際にも3つをまとめて1つとしてサンプルできる。例えば、エピソード毎サンプリングしたい際に利用することが想定できる。

Writer.flush() または、Writer.close() が呼ばれる際 (withブロックから抜ける際にも自動で呼ばれる) に、サーバーへデータが送信される。

3.3 TFClient.insert

tf_client.inser([tf.constant(obs),
                 tf.constant(act),
                 tf.constant(rew),
                 tf.constant(next_obs),
                 tf.constant(done)],
                 tablea=tf.constant([table_name]),
                 priorities=tf.constant([priority],dtype=tf.float64))

tables引数は、strのランク1のtf.Tensorで、prioritiesは、float64(明記しないとfloat32になる)のランク1のtf.Tensorであり、両者のshapeは一致しないといけない。

後で、サンプルするときのために、各データのtf.Tensorもランク1以上にしておくことがおそらく必要。

3.4 まとめ

サーバーへの送信 複数ステップを1項目に TF計算グラフ内での利用 データ
Client.insert 都度 X X 何でも良い
Client.writer Writer.close(), Writer.flush() (含 with の退出時) O X 何でも良い
TFClient.insert 都度 X O tf.Tensor

4. 遷移(transition)または軌道(trajectory)の読出し

いずれの手法もPrioritized Experience Replayの重み補正のβパラメータに対応していない。(そもそも重点サンプリングの重みを計算してくれない)

4.1 Client.sample

client.sample(table_name,num_samples=batch_size)

戻り値は、reverb.replay_sample.ReplaySamplegeneratorである。 ReplaySampleは名前付きTupleで、infodataを所持している。dataには、保存したデータが、infoにはkeyやpriorityなどの情報が含まれている。

4.2 TFClient.sample

tf_client.sample(tf.constant([table_name]),data_dtypes=dtypes)

この手法は、残念ながらバッチサンプリングに対応していない。戻り値は、ReplaySampleになる。

4.3 TFClient.dataset

tf_client.dataset(tf.constant([table_name]),dtypes=dtypes,shapes=shapes)

他の手法とは全く異なる方式を採用しており、大規模な本番学習で主にはこの手法を採用することが意図されていると思われる。

この関数の戻り値は、tf.data.Datasetを継承したreverb.ReplayDatasetである。このReplayDatasetgeneratorのようにReplaySampleを引き出すことができ、適切なタイミングでサーバーからデータを自動でフェッチしてきてくれる。つまり、毎回 sampleするのではなく、一度 ReplayDatasetを設定するとあとは、自動で保存されたデータを排出し続けてくれる仕組みになっている。

shapes は、0を要素に指定するとエラーを発生させるので、保存するデータはランク1以上にしておくことが必要であると思われる。

その他パフォーマンス調整のための各種パラメータはここに書くには細かいので、ソースコードのコメントを確認してほしい。

4.4 まとめ

バッチ出力 戻り値型 型の指定 形状の指定
Client.sample O replay_sample.ReplaySamplegenerator 不要 不要
TFClient.sample X replay_sample.ReplaySample 必要 不要
TFClient.dataset O (内部で自動で実施) ReplayDataset 必要 必要

5. 重要度 (priority) の更新

他のリプレイ・バッファの実装と異なり、要素を指定するIDは0始まりの連番ではなく、一見ランダムに見えるハッシュである。ReplaySample.info.keyでアクセス可能である。
(書きにくいので、サンプルコードを一部省略した書き方にします。すみません。)

5.1 Client.mutate_priorities

client.mutate_priorities(table_name,updates={key:new_priority},deletes=[...])

こちらは、更新だけでなく、削除も可能である。

5.2 TFClient.update_priorities

tf_client.update_priorities(tf.constant([table_name]),
                            keys=tf.constant([...]),
                            priorities=tf.constant([...],dtype=tf.float64))

6. 性能比較

せっかくなので、拙作のcpprbも含めてベンチマークをとった。

注意: 強化学習には、深層学習の学習や環境の更新など他にも重たい処理があるので、リプレイ・バッファの速度だけで決まるものではない。(一方、リプレイ・バッファの実装と条件次第では、リプレイ・バッファの処理時間と深層学習の処理時間が同じぐらいになることもあるらしい。)

以下のDockerfileによる環境で、ベンチマークを実行した。

Dockerfile
FROM python:3.7

RUN apt update \
	&& apt install -y --no-install-recommends libopenmpi-dev zlib1g-dev \
	&& apt clean \
	&& rm -rf /var/lib/apt/lists/* \
	&& pip install tf-nightly==2.3.0.dev20200604 dm-reverb-nightly perfplot


# Reverb requires development version TensorFlow

CMD ["bash"]

(cpprbのレポジトリのCI上で、実施しているので、追加でcpprbもインストールされている。 pip install cpprb とほぼ同義)

そして、以下のベンチマークスクリプトを実行して実行時間のグラフを描いた。

benchmark.py
import gc
import itertools


import numpy as np
import perfplot
import tensorflow as tf

# DeepMind/Reverb: https://github.com/deepmind/reverb
import reverb

from cpprb import (ReplayBuffer as RB,
                   PrioritizedReplayBuffer as PRB)


# Configulation
buffer_size = 2**12

obs_shape = 15
act_shape = 3

alpha = 0.4
beta  = 0.4

env_dict = {"obs": {"shape": obs_shape},
            "act": {"shape": act_shape},
            "next_obs": {"shape": obs_shape},
            "rew": {},
            "done": {}}


# Initialize Replay Buffer
rb  =  RB(buffer_size,env_dict)


# Initialize Prioritized Replay Buffer
prb  =  PRB(buffer_size,env_dict,alpha=alpha)


# Initalize Reverb Server
server = reverb.Server(tables =[
    reverb.Table(name='ReplayBuffer',
                 sampler=reverb.selectors.Uniform(),
                 remover=reverb.selectors.Fifo(),
                 max_size=buffer_size,
                 rate_limiter=reverb.rate_limiters.MinSize(1)),
    reverb.Table(name='PrioritizedReplayBuffer',
                 sampler=reverb.selectors.Prioritized(alpha),
                 remover=reverb.selectors.Fifo(),
                 max_size=buffer_size,
                 rate_limiter=reverb.rate_limiters.MinSize(1))
])

client = reverb.Client(f"localhost:{server.port}")
tf_client = reverb.TFClient(f"localhost:{server.port}")


# Helper Function
def env(n):
    e = {"obs": np.ones((n,obs_shape)),
         "act": np.zeros((n,act_shape)),
         "next_obs": np.ones((n,obs_shape)),
         "rew": np.zeros(n),
         "done": np.zeros(n)}
    return e

def add_client(_rb,table):
    """ Add for Reverb Client
    """
    def add(e):
        n = e["obs"].shape[0]
        with _rb.writer(max_sequence_length=1) as _w:
            for i in range(n):
                _w.append([e["obs"][i],
                           e["act"][i],
                           e["rew"][i],
                           e["next_obs"][i],
                           e["done"][i]])
                _w.create_item(table,1,1.0)
    return add

def add_client_insert(_rb,table):
    """ Add for Reverb Client
    """
    def add(e):
        n = e["obs"].shape[0]
        for i in range(n):
            _rb.insert([e["obs"][i],
                        e["act"][i],
                        e["rew"][i],
                        e["next_obs"][i],
                        e["done"][i]],priorities={table: 1.0})
    return add

def add_tf_client(_rb,table):
    """ Add for Reverb TFClient
    """
    def add(e):
        n = e["obs"].shape[0]
        for i in range(n):
            _rb.insert([tf.constant(e["obs"][i]),
                        tf.constant(e["act"][i]),
                        tf.constant(e["rew"][i]),
                        tf.constant(e["next_obs"][i]),
                        tf.constant(e["done"])],
                       tf.constant([table]),
                       tf.constant([1.0],dtype=tf.float64))
    return add

def sample_client(_rb,table):
    """ Sample from Reverb Client
    """
    def sample(n):
        return [i for i in _rb.sample(table,num_samples=n)]

    return sample

def sample_tf_client(_rb,table):
    """ Sample from Reverb TFClient
    """
    def sample(n):
        return [_rb.sample(table,
                           [tf.float64,tf.float64,tf.float64,tf.float64,tf.float64])
                for _ in range(n)]

    return sample

def sample_tf_client_dataset(_rb,table):
    """ Sample from Reverb TFClient using dataset
    """
    def sample(n):
        dataset=_rb.dataset(table,
                            [tf.float64,tf.float64,tf.float64,tf.float64,tf.float64],
                            [4,1,1,4,1])
        return itertools.islice(dataset,n)
    return sample


# ReplayBuffer.add
perfplot.save(filename="ReplayBuffer_add2.png",
              setup = env,
              time_unit="ms",
              kernels = [add_client_insert(client,"ReplayBuffer"),
                         add_client(client,"ReplayBuffer"),
                         add_tf_client(tf_client,"ReplayBuffer"),
                         lambda e: rb.add(**e)],
              labels = ["DeepMind/Reverb: Client.insert",
                        "DeepMind/Reverb: Client.writer",
                        "DeepMind/Reverb: TFClient.insert",
                        "cpprb"],
              n_range = [n for n in range(1,102,10)],
              xlabel = "Step size added at once",
              title = "Replay Buffer Add Speed",
              logx = False,
              logy = False,
              equality_check = None)


# Fill Buffers
for _ in range(buffer_size):
    o = np.random.rand(obs_shape) # [0,1)
    a = np.random.rand(act_shape)
    r = np.random.rand(1)
    d = np.random.randint(2) # [0,2) == 0 or 1
    client.insert([o,a,r,o,d],priorities={"ReplayBuffer": 1.0})
    rb.add(obs=o,act=a,rew=r,next_obs=o,done=d)


# ReplayBuffer.sample
perfplot.save(filename="ReplayBuffer_sample2.png",
              setup = lambda n: n,
              time_unit="ms",
              kernels = [sample_client(client,"ReplayBuffer"),
                         sample_tf_client(tf_client,"ReplayBuffer"),
                         sample_tf_client_dataset(tf_client,"ReplayBuffer"),
                         rb.sample],
              labels = ["DeepMind/Reverb: Client.sample",
                        "DeepMind/Reverb: TFClient.sample",
                        "DeepMind/Reverb: TFClient.dataset",
                        "cpprb"],
              n_range = [2**n for n in range(1,8)],
              xlabel = "Batch size",
              title = "Replay Buffer Sample Speed",
              logx = False,
              logy = False,
              equality_check=None)


# PrioritizedReplayBuffer.add
perfplot.save(filename="PrioritizedReplayBuffer_add2.png",
              time_unit="ms",
              setup = env,
              kernels = [add_client_insert(client,"PrioritizedReplayBuffer"),
                         add_client(client,"PrioritizedReplayBuffer"),
                         add_tf_client(tf_client,"PrioritizedReplayBuffer"),
                         lambda e: prb.add(**e)],
              labels = ["DeepMind/Reverb: Client.insert",
                        "DeepMind/Reverb: Client.writer",
                        "DeepMind/Reverb: TFClient.insert",
                        "cpprb"],
              n_range = [n for n in range(1,102,10)],
              xlabel = "Step size added at once",
              title = "Prioritized Replay Buffer Add Speed",
              logx = False,
              logy = False,
              equality_check=None)


# Fill Buffers
for _ in range(buffer_size):
    o = np.random.rand(obs_shape) # [0,1)
    a = np.random.rand(act_shape)
    r = np.random.rand(1)
    d = np.random.randint(2) # [0,2) == 0 or 1
    p = np.random.rand(1)

    client.insert([o,a,r,o,d],priorities={"PrioritizedReplayBuffer": p})

    prb.add(obs=o,act=a,rew=r,next_obs=o,done=d,priority=p)


perfplot.save(filename="PrioritizedReplayBuffer_sample2.png",
              time_unit="ms",
              setup = lambda n: n,
              kernels = [sample_client(client,"PrioritizedReplayBuffer"),
                         sample_tf_client(tf_client,"PrioritizedReplayBuffer"),
                         sample_tf_client_dataset(tf_client,"PrioritizedReplayBuffer"),
                         lambda n: prb.sample(n,beta=beta)],
              labels = ["DeepMind/Reverb: Client.sample",
                        "DeepMind/Reverb: TFClient.sample",
                        "DeepMind/Reverb: TFClient.dataset",
                        "cpprb"],
              n_range = [2**n for n in range(1,9)],
              xlabel = "Batch size",
              title = "Prioritized Replay Buffer Sample Speed",
              logx=False,
              logy=False,
              equality_check=None)

結果は、以下の様になった。
(結果が古くなっているかもしれないので、最新版はcpprbのプロジェクトサイトで確認できる。)

ReplayBuffer_add2.png

ReplayBuffer_sample2.png

PrioritizedReplayBuffer_add2.png

PrioritizedReplayBuffer_sample2.png

7. おわりに

DeepMind製Experience ReplayフレームワークのReverbのクライアントの使い方を調査し整理した。
OpenAI/Baselinesに代表されるような他のリプレイ・バッファの実装と比べると、APIや利用方法が異なる部分も多くわかりにくいと感じた。
(安定版の公開時までに、もう少しわかりやすくなっていると良いですね。)

少なくとも、大規模分散学習を行ったり、強化学習のすべてをTensorFlowの計算グラフ内で完結させたりということをしなければ、性能面でも優れているわけではなさそうであった。

もちろん、大規模分散学習や計算グラフ内での強化学習はパフォーマンスを大幅に向上させる可能性があるので、引き続き検討を続ける必要があると思っている。

5
3
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
5
3

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?