25
24

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

chainerRLを使ってみる

Last updated at Posted at 2017-03-02

##はじめに
先日(2017/2/16)にchainerから深層強化学習ライブラリChainerRLがリリースされた。
https://github.com/pfnet/chainerrl
これまで頑張ってコード書いてた身としては悲しい出来事であるが、時代の流れには逆らえない。むしろ早々と使いこなすことで時代に喰らい着いていきたい。

今回はQickStartに従ってsampleを動かす。
##環境
OS:Ubuntu14.04
GPU:GTX1070
CUDA:8.0 RC
cuDNN:5.1
python:2.7.6
chainer:1.20.0.1
など

gym等、必要なライブラリはインストール済み
##インストール
上記GitHubのREAME.mdに従ってインストールする。

sudo pip install chainerrl

一瞬で完了。
##Quick Startを試す
次にここ
https://github.com/pfnet/chainerrl/blob/master/examples/quickstart/quickstart.ipynb
に従って、quickstartを試す。以下、train.pyに必要なコードを入力していく。まずはimportするパッケージ。

train.py
import chainer
import chainer.functions as F
import chainer.links as L
import chainerrl
import gym
import numpy as np

今回は'CartPole-v0'というgymでよくテストに使われるゲームを試す。

train.py
env = gym.make('CartPole-v0')
print('observation space:', env.observation_space)
print('action space:', env.action_space)

obs = env.reset()
env.render()
print('initial observation:', obs)

action = env.action_space.sample()
obs, r, done, info = env.step(action)
print('next observation:', obs)
print('reward:', r)
print('done:', done)
print('info:', info)

これを実行すると

python train.py
[2017-03-03 03:33:00,118] Making new env: CartPole-v0
('observation space:', Box(4,))
('action space:', Discrete(2))
('initial observation:', array([ 0.0428945 , -0.00220352, -0.04834014, -0.04557467]))
('next observation:', array([ 0.04285043, -0.19660017, -0.04925163,  0.23147321]))
('reward:', 1.0)
('done:', False)
('info:', {})

こんな感じで表示された。observationは4つの要素の配列。actionのspaceがDiscrete(2)となってるので、行動は2種類っぽい。報酬はスカラー。

env.resetで環境を初期化して初期状態のobservationを得る。env.stepは行動を与えて次の状態のobservation、報酬、などなどを返す。強化学習の1Stepね。
##agentなどを設定する
次に学習させる。まず通常のchainerと同じくクラスでアーキテクチャを定義する。

train.py
class QFunction(chainer.Chain):

    def __init__(self, obs_size, n_actions, n_hidden_channels=50):
        super(QFunctin, self).__init__(
            l0=L.Linear(obs_size, n_hidden_channels),
            l1=L.Linear(n_hidden_channels, n_hidden_channels),
            l2=L.Linear(n_hidden_channels, n_actions))

    def __call__(self, x, test=False):
        """
        Args:
            x (ndarray or chainer.Variable): An observation
            test (bool): a flag indicating whether it is in test mode
        """
        h = F.tanh(self.l0(x))
        h = F.tanh(self.l1(h))
        return chainerrl.action_value.DiscreteActionValue(self.l2(h))

obs_size = env.observation_space.shape[0]
n_actions = env.action_space.n
q_func = QFunction(obs_size, n_actions)
q_func.to_qpu()

親クラスの初期化の部分だけpython2系ように書き換えた。

クラス内ではcall関数でchainerrl.action_value.DiscreteActionValue(self.l2(h))を返すところだけが通常のDLと違う。observationのサイズや行動の数を取得した後、クラスをオブジェクト化。最後にGPUに送る。

次にoptimizerの設定。

train.py
optimizer = chainer.optimizers.Adam(eps=1e-2)
optimizer.setup(q_func)

なぜかAdam。DQN論文ではRMSPropだったような・・・。次に他のパラメータを設定する。

train.py
# Set the discount factor that discounts future rewards.
gamma = 0.95

# Use epsilon-greedy for exploration
explorer = chainerrl.explorers.ConstantEpsilonGreedy(
    epsilon=0.3, random_action_func=env.action_space.sample)

# DQN uses Experience Replay.
# Specify a replay buffer and its capacity.
replay_buffer = chainerrl.replay_buffer.ReplayBuffer(capacity=10 ** 6)

# Since observations from CartPole-v0 is numpy.float64 while
# Chainer only accepts numpy.float32 by default, specify
# a converter as a feature extractor function phi.
phi = lambda x: x.astype(np.float32, copy=False)

# Now create an agent that will interact with the environment.
agent = chainerrl.agents.DoubleDQN(
    q_func, optimizer, replay_buffer, gamma, explorer,
    replay_start_size=500, update_frequency=1,
    target_update_frequency=100, phi=phi)

gammaは報酬の割引率。イプシロンは今回0.3に固定。replay memoryは10の6乗に設定。phiはゲームから来るobservationがfloat32でない場合に備えた関数。アーキテクチャは今回DoubleDQN。ちなみにchainerrl/agents/を見ると、doubleDQNの他にA3CやACERなど主要なモデルの幾つかが実装されている。

replay_start_sizeはreplay memoryにobservationの組が500個溜まってから学習を始めるという意味か?target_update_frequencyはtargetの方のパラメータにコピーする頻度だろう。

これでagentと環境の設定は完了。
##学習させる
次にループを組んで学習させる。

train.py
n_episodes = 200
max_episode_len = 200
for i in range(1, n_episodes + 1):
    obs = env.reset()
    reward = 0
    done = False
    R = 0  # return (sum of rewards)
    t = 0  # time step
    while not done and t < max_episode_len:
        # Uncomment to watch the behaviour
        # env.render()
        action = agent.act_and_train(obs, reward)
        obs, reward, done, _ = env.step(action)
        R += reward
        t += 1
    if i % 10 == 0:
        print('episode:', i,
              'R:', R,
              'statistics:', agent.get_statistics())
    agent.stop_episode_and_train(obs, reward, done)
print('Finished.')

200ステップのゲームを200エピソード行う。

python train.py
[2017-03-03 04:26:10,092] Making new env: CartPole-v0
('observation space:', Box(4,))
('action space:', Discrete(2))
('initial observation:', array([ 0.00686075, -0.00872222,  0.0262407 , -0.00072111]))
('next observation:', array([ 0.0066863 , -0.20421048,  0.02622628,  0.30012421]))
('reward:', 1.0)
('done:', False)
('info:', {})
('episode:', 10, 'R:', 101.0, 'statistics:', [(u'average_q', 3.839410603590915), (u'average_loss', 0.03739733160654156)])
('episode:', 20, 'R:', 126.0, 'statistics:', [(u'average_q', 8.168648753284264), (u'average_loss', 0.07854945771670627)])
('episode:', 30, 'R:', 143.0, 'statistics:', [(u'average_q', 12.791942632038957), (u'average_loss', 0.08622470318985226)])
......
......
......
('episode:', 190, 'R:', 200.0, 'statistics:', [(u'average_q', 19.98455506145719), (u'average_loss', 0.0627547477493958)])
('episode:', 200, 'R:', 105.0, 'statistics:', [(u'average_q', 20.00981249193506), (u'average_loss', 0.04760965162900658)])
Finished.

~~なんか収益増えてない気がするけど・・・。~~コメント欄の開発者様の御指摘通り、収益のmaxは200なので、190episodeなどはmaxにいっている。

##テストする
次に、トレーニングしたモデルをテストする。トレーニング部分のコードに続いて以下のコードを足す。

train.py
for i in range(10):
    obs = env.reset()
    done = False
    R = 0
    t = 0
    while not done and t < 200:
        env.render()
        action = agent.act(obs)
        obs, r, done, _ = env.step(action)
        R += r
        t += 1
    print('test episode:', i, 'R:', R)
    agent.stop_episode()

これを走らせると以下のような結果となった。

python train.py
.....
.....
.....
.....
('episode:', 190, 'R:', 11.0, 'statistics:', [(u'average_q', 20.311043770645103), (u'average_loss', 0.09018823270227477)])
('episode:', 200, 'R:', 196.0, 'statistics:', [(u'average_q', 20.352799336256385), (u'average_loss', 0.05998000704222223)])
training Finished.
('test episode:', 0, 'R:', 200.0)
('test episode:', 1, 'R:', 200.0)
('test episode:', 2, 'R:', 200.0)
('test episode:', 3, 'R:', 200.0)
('test episode:', 4, 'R:', 200.0)
('test episode:', 5, 'R:', 200.0)
('test episode:', 6, 'R:', 200.0)
('test episode:', 7, 'R:', 200.0)
('test episode:', 8, 'R:', 200.0)
('test episode:', 9, 'R:', 200.0)
test Finished.

ことごとくmax収益をたたき出している。ちなみに、trainingさせずにtestだけさせてみた結果が以下。

python test_1.py
.....
.....
.....
.....
('reward:', 1.0)
('done:', False)
('info:', {})
('test episode:', 0, 'R:', 14.0)
('test episode:', 1, 'R:', 15.0)
('test episode:', 2, 'R:', 16.0)
('test episode:', 3, 'R:', 13.0)
('test episode:', 4, 'R:', 15.0)
('test episode:', 5, 'R:', 16.0)
('test episode:', 6, 'R:', 16.0)
('test episode:', 7, 'R:', 15.0)
('test episode:', 8, 'R:', 15.0)
('test episode:', 9, 'R:', 17.0)
test Finished.

学習済みと全然違うね。ちなみに、test.pyという名前のファイルだと何故かエラーとなるので、使わない方がいい。
##セーブとロード
以下のコードでセーブができる。

# Save an agent to the 'agent' directory
agent.save('agent')

ロードはこれ。

# Uncomment to load an agent from the 'agent' directory
agent.load('agent')

しかし、これはagentをセーブするだけでは?例えば、パソコンが固まった時のために定期的に学習をセーブする場合、replay memoryもセーブする必要があると思うが、そちらはセーブされないのね?

25
24
2

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
25
24

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?