LoginSignup
9
12

More than 1 year has passed since last update.

【強化学習ライブラリ】RLlibの使い方

Posted at

RLlibとは

RLlibはPythonの分散実行ライブラリ「Ray」の1つのサブパッケージであり、強化学習用のオープンソースライブラリです。

RLlibではかなり多くのアルゴリズムを自由度高く用いることができます。
Available Algorithms - Overview にあるように、現在 DQN, PPO, R2D2, AlphaZero, Dreamerなどなど多くが実装されており、日々アップデートされています。

深層学習フレームワークとして、TensorFlow(1.xと2.xの両方)と PyTorch が用いられており、ほとんどのアルゴリズムではどちらのフレームワークでも実行することができます。

多くのことが実装されている一方、コードはかなり複雑でドキュメントもそれほど親切ではないので、強化学習を利用していても使ったことがない人は多いのではないかと思います。

そこで今回は、自前の環境やモデルを用いて学習を実行する方法を簡単に説明します。
この記事ではPyTorchを使用する場合を想定していますがTensorFlowを用いる場合でもほとんど同じ要領でできます。

インストール

pip install -U "ray[rllib]" で簡単にインストールすることができますが、コードを自分でいじることが多くなると思うので、ソースコードからbuildすることをおすすめします。公式のドキュメントに書いてありますがここでも簡単に説明します。

まず、以下にあるversionから適したものをダウンロードしてきます。

次にダウンロードした .whl を用いてインストールします。MacOS, Python3.8の場合以下のようになります。

$ pip install -U ray-2.0.0.dev0-cp38-cp38-macosx_10_15_x86_64.whl

rayリポジトリをforkし、cloneします。

$ git clone https://github.com/[your username]/ray.git
$ cd ray
$ git remote add upstream https://github.com/ray-project/ray.git
# Make sure you are up-to-date on master.

最後にインストールしたrayパッケージをcloneしてきたコードで置き換えます。

# This replaces `<package path>/site-packages/ray/<package>`
# with your local `ray/python/ray/<package>`.
$ python python/ray/setup-dev.py

学習実行

$ rllib train --run DQN --env CartPole-v0

といったようにコマンドで簡単に実行することができます。

ただ、カスタマイズしやすいようにPythonスクリプトで実行します。
以下がDQNでAtariのBreakoutを学習させる場合のコードです。

dqn.py
# from ray.rllib.agents.dqn import R2D2Trainer
from ray.rllib.agents.dqn import DQNTrainer

config = {
        'framework': 'torch',   # 'torch' or 'tf' or 'tf2'
        'env': 'BreakoutNoFrameskip-v4',
        'dueling': True,
        'lr': 0.0005,
        'exploration_config':
            {
                'epsilon_timesteps': 50000,
            },
        'model':
            {
                'fcnet_hiddens': [64],
                'fcnet_activation': 'linear',
             },
}

trainer = DQNTrainer(config=config)

# モデル構造確認
policy = trainer.get_policy()
print(policy.model)

# Run it for n training iterations. A training iteration includes
# parallel sample collection by the environment workers as well as
# loss calculation on the collected batch and a model update.
for _ in range(3):
    trainer.train()

# Evaluate the trained Trainer (and render each timestep to the shell's
# output).
trainer.evaluate()

使いたいアルゴリズムのTrainerimportしてきて、configを渡し、任意の数だけ学習イテレーションを回しています。

カスタマイズできるconfigは

Common Parameters

で確認できます。

アルゴリズム固有のconfigは

Available Algorithms - Overview

でアルゴリズムを選択すると確認できます。

結果確認

学習のログや結果は ~/ray_results/ 以下にtensorboard のファイルとして保存されます。以下を実行し http://localhost:6006/ (デフォルト) にアクセスすることで確認できます。

$ tensorboard --logdir=~/ray_results

モデルをカスタマイズ

configでカスタマイズ

モデル構造を簡単にカスタマイズしたい場合には、configで値を指定します。
指定できるconfigは

Default Model Config Settings

にまとまっています。具体的には以下のようなconfigを書き、Trainerに渡します。

.py
config = {
    'model':
        {
            'fcnet_hiddens': [64],
            'fcnet_activation': 'linear',
            'use_lstm': True,
            'dim': 6,  # input image size
            "conv_filters": [[32, [3, 3], 2], [64, [2, 2], 1], [64, [2, 2], 1]],  # [output_channels, [kernel_w, kernel_h], stride]
        },
}

classを定義してカスタマイズ

自前のモデルclassを定義して利用する際は、TorchModelV2 のサブクラスとして定義し __init__(), forward() を実装する必要があります。
また、価値関数を表す value_functionも自前で上書きすることができます。

動かす場合には以下のようにconfigを定義してTrainerに渡します。

class CustomTorchModel(TorchModelV2):
    def __init__(self, obs_space, action_space, num_outputs, model_config, name): ...
    def forward(self, input_dict, state, seq_lens): ...
    def value_function(self): ...


config={
    "framework": "torch",
    "model": {
        "custom_model": CustomTorchModel,
        # Extra kwargs to be passed to your model's c'tor.
        "custom_model_config": {},
    },
}

画像のような3次元データを入力とした場合のデフォルトのネットワークが

で定義されているので、これを参考にカスタマイズするのが良いと思います。

参考
Custom PyTorch Models

RNNの利用

configのmodelのところでuse_lstm: Trueuse_attension: TrueとすることでLSTMやAttensionネットワークを使用することができます。

ただ、R2D2など特定のアルゴリズムではclassで定義した自前のモデルと組み合わせた場合に動かないがあります。そういう場合には根気強くデバッグするか他のリポジトリを利用するのが良いと思います。

参考
Wrapping a Custom Model (TF and PyTorch) with an LSTM- or Attention Net
Implementing custom Recurrent Networks

環境をカスタマイズ

オリジナルの環境を使う場合、以下のような環境のclassを定義します。

import gym

class MyEnv(gym.Env):
    def __init__(self, env_config):
        self.action_space = <gym.Space>
        self.observation_space = <gym.Space>
    def reset(self):
        return <obs>
    def step(self, action):
        return <obs>, <reward: float>, <done: bool>, <info: dict>

行動と状態の型は、OpenAI Gymの型で定義する必要があり、obsはself.observation_spaceに含まれる必要があります。self.observation_space.contains(obs) で確認できます。

作成した環境のclassをconfigで指定することで使用できます。簡単なカスタマイズされた環境を用いた例が以下になります。

custom_env.py
import gym
from ray.rllib.agents.dqn import DQNTrainer

class MyEnv(gym.Env):
    def __init__(self, env_config):
        self.img_size = 6
        self.i = 0

        self.action_space = gym.spaces.Discrete(10)
        self.observation_space = gym.spaces.Box(low=0, high=1, shape=(5, self.img_size, self.img_size))

    def reset(self):
        self.i = 0
        obs = np.zeros((5, self.img_size, self.img_size))
        return obs

    def step(self, action):
        self.i += 1
        obs = np.ones((5, self.img_size, self.img_size))
        reward = 0
        done = self.i == 10
        info = {}
        return obs, reward, done, info


config = {
        'framework': 'torch',
        'env': MyEnv,
        'lr': 0.0005,
        'exploration_config':
            {
                'epsilon_timesteps': 50000,
            },
        'model':
            {
                'fcnet_hiddens': [64],
                'fcnet_activation': 'linear',
                'dim': 6,
                "conv_filters": [[32, [3, 3], 2], [64, [3, 3], 2], [64, [2, 2], 1]],
             },
}

trainer = DQNTrainer(config=config)

for _ in range(3):
    trainer.train()

trainer.evaluate()

参考
Configuring Environments

おわりに

他にrllibでできることとして

  • 状態によって取りうる行動が変わる場合のモデル定義
  • Lossのカスタマイズ
  • Policyのカスタマイズ
  • マルチエージェントでの学習
  • オフラインRL
  • 模倣学習

などなど、かなり多くのことが実装されています。
ただ、これらを利用するにはかなりのコード理解が必要になる上に組み合わせによっては動かないことがあるので注意が必要です。

個人的には、いくつかのアルゴリズムで比較検証したい場合にはRLlibを用いるのが良くて、特定のアルゴリズムのみ用いる場合や既存手法をカスタマイズしたい場合には別のリポジトリを使うか自分で実装するのが良いと感じています。

9
12
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
9
12