Reinforcement Learning 6: First Steps with ChainerRL

Posted at 2019-11-12

ChainerRL quick reference
https://chainer-colab-notebook.readthedocs.io/ja/latest/notebook/hands_on/chainerrl/quickstart.html

This article assumes you have completed Reinforcement Learning 5.
Following the quick reference, create the file below.

train.py
import chainer
import chainer.functions as F
import chainer.links as L
import chainerrl
import gym
import numpy as np
env = gym.make('CartPole-v0')
print('observation space:', env.observation_space)
print('action space:', env.action_space)

obs = env.reset()
env.render()
print('initial observation:', obs)

action = env.action_space.sample()
obs, r, done, info = env.step(action)
print('next observation:', obs)
print('reward:', r)
print('done:', done)
print('info:', info)

class QFunction(chainer.Chain):

    def __init__(self, obs_size, n_actions, n_hidden_channels=50):
        super().__init__()
        with self.init_scope():
            self.l0 = L.Linear(obs_size, n_hidden_channels)
            self.l1 = L.Linear(n_hidden_channels, n_hidden_channels)
            self.l2 = L.Linear(n_hidden_channels, n_actions)

    def __call__(self, x, test=False):
        """
        Args:
            x (ndarray or chainer.Variable): An observation
            test (bool): a flag indicating whether it is in test mode
        """
        h = F.tanh(self.l0(x))
        h = F.tanh(self.l1(h))
        return chainerrl.action_value.DiscreteActionValue(self.l2(h))

obs_size = env.observation_space.shape[0]
n_actions = env.action_space.n
q_func = QFunction(obs_size, n_actions)

optimizer = chainer.optimizers.Adam(eps=1e-2)
optimizer.setup(q_func)

# Set the discount factor that discounts future rewards.
gamma = 0.95

# Use epsilon-greedy for exploration
explorer = chainerrl.explorers.ConstantEpsilonGreedy(
    epsilon=0.3, random_action_func=env.action_space.sample)

# DQN uses Experience Replay.
# Specify a replay buffer and its capacity.
replay_buffer = chainerrl.replay_buffer.ReplayBuffer(capacity=10 ** 6)

# Observations from CartPole-v0 are numpy.float64, while
# Chainer only accepts numpy.float32 by default, so specify
# a converter as the feature extractor function phi.
phi = lambda x: x.astype(np.float32, copy=False)

# Now create an agent that will interact with the environment.
agent = chainerrl.agents.DoubleDQN(
    q_func, optimizer, replay_buffer, gamma, explorer,
    replay_start_size=500, update_interval=1,
    target_update_interval=100, phi=phi)

n_episodes = 200
max_episode_len = 200
for i in range(1, n_episodes + 1):
    obs = env.reset()
    reward = 0
    done = False
    R = 0  # return (sum of rewards)
    t = 0  # time step
    while not done and t < max_episode_len:
        # Uncomment to watch the behaviour
        # env.render()
        action = agent.act_and_train(obs, reward)
        obs, reward, done, _ = env.step(action)
        R += reward
        t += 1
    if i % 10 == 0:
        print('episode:', i,
              'R:', R,
              'statistics:', agent.get_statistics())
    agent.stop_episode_and_train(obs, reward, done)
print('Finished.')
agent.save('agent')
env.close()
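
The explorer in train.py is ConstantEpsilonGreedy with epsilon=0.3. The idea behind it can be sketched in plain Python without ChainerRL. This is only an illustration of the technique, not the library's implementation: the function name epsilon_greedy and the Q-values below are made up for the example.

```python
import random


def epsilon_greedy(q_values, epsilon, random_action_func, rng=random):
    """Return a random action with probability epsilon, else the greedy one."""
    if rng.random() < epsilon:
        return random_action_func()
    # Greedy choice: index of the largest Q-value.
    return max(range(len(q_values)), key=lambda a: q_values[a])


rng = random.Random(0)
q_values = [0.2, 1.5]  # hypothetical Q-values for the two CartPole actions

actions = [
    epsilon_greedy(q_values, 0.3, lambda: rng.randrange(2), rng)
    for _ in range(10_000)
]
greedy_share = actions.count(1) / len(actions)
print('share of greedy action:', greedy_share)
```

With epsilon=0.3 and two actions, the greedy action should be chosen about 85% of the time: 70% by the greedy branch plus half of the 30% random picks.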

Note: if the env.close() call at the end of train.py is left out, an error message is displayed.
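
One way to make sure close() always runs, even when training stops with an exception, is a try/finally block. The sketch below uses a hypothetical StubEnv class in place of a real gym environment so that it runs on its own. It shows the pattern only and is not taken from the quick reference.

```python
class StubEnv:
    """Hypothetical stand-in for a gym environment (not part of gym)."""

    def __init__(self):
        self.closed = False

    def reset(self):
        return [0.0, 0.0, 0.0, 0.0]

    def close(self):
        self.closed = True


def run(env, fail=False):
    try:
        obs = env.reset()
        if fail:
            raise RuntimeError('simulated failure during training')
        return obs
    finally:
        # Executed on normal exit and when an exception is raised.
        env.close()


env_ok = StubEnv()
run(env_ok)
print('closed after normal run:', env_ok.closed)

env_failed = StubEnv()
try:
    run(env_failed, fail=True)
except RuntimeError:
    pass
print('closed after failure:', env_failed.closed)
```

In train.py the same pattern would mean wrapping the training loop in try and moving env.close() into the finally clause.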

If it runs correctly, the output looks like this.

observation space: Box(4,)
action space: Discrete(2)
initial observation: [-0.04736688 -0.01970095 -0.00356997  0.01937746]
next observation: [-0.0477609   0.17547202 -0.00318242 -0.27442969]
reward: 1.0
done: False
info: {}
episode: 10 R: 54.0 statistics: [('average_q', 0.02431227437087797), ('average_loss', 0), ('n_updates', 0)]
episode: 20 R: 38.0 statistics: [('average_q', 0.6243046798922441), ('average_loss', 0.08867046155807262), ('n_updates', 378)]
episode: 30 R: 46.0 statistics: [('average_q', 2.2610338271644586), ('average_loss', 0.09550022600040467), ('n_updates', 784)]
episode: 40 R: 84.0 statistics: [('average_q', 5.323362298387771), ('average_loss', 0.16771472656243475), ('n_updates', 1399)]
episode: 50 R: 91.0 statistics: [('average_q', 9.851513830734694), ('average_loss', 0.19145745620246343), ('n_updates', 2351)]
episode: 60 R: 99.0 statistics: [('average_q', 14.207080180752635), ('average_loss', 0.22097823899753388), ('n_updates', 3584)]
episode: 70 R: 200.0 statistics: [('average_q', 17.49337381852232), ('average_loss', 0.18525375351216344), ('n_updates', 5285)]
episode: 80 R: 124.0 statistics: [('average_q', 18.933387631649587), ('average_loss', 0.1511605453710412), ('n_updates', 7063)]
episode: 90 R: 200.0 statistics: [('average_q', 19.55727598346719), ('average_loss', 0.167370220872378), ('n_updates', 8496)]
episode: 100 R: 200.0 statistics: [('average_q', 19.92113421424675), ('average_loss', 0.15092426599174535), ('n_updates', 10351)]
episode: 110 R: 161.0 statistics: [('average_q', 19.870179660112395), ('average_loss', 0.1369066775700466), ('n_updates', 12169)]
episode: 120 R: 200.0 statistics: [('average_q', 19.985680296882315), ('average_loss', 0.13667809001004586), ('n_updates', 13991)]
episode: 130 R: 200.0 statistics: [('average_q', 20.016279858512945), ('average_loss', 0.14053696154447365), ('n_updates', 15938)]
episode: 140 R: 180.0 statistics: [('average_q', 19.870299413261478), ('average_loss', 0.1270716956269478), ('n_updates', 17593)]
episode: 150 R: 200.0 statistics: [('average_q', 19.990808581945565), ('average_loss', 0.1228807602095278), ('n_updates', 19442)]
episode: 160 R: 130.0 statistics: [('average_q', 19.954955203815164), ('average_loss', 0.14701205384726732), ('n_updates', 21169)]
episode: 170 R: 133.0 statistics: [('average_q', 19.994069560095422), ('average_loss', 0.12502104946859763), ('n_updates', 22709)]
episode: 180 R: 200.0 statistics: [('average_q', 19.973195015705674), ('average_loss', 0.1227321977377075), ('n_updates', 24522)]
episode: 190 R: 200.0 statistics: [('average_q', 20.050942533128573), ('average_loss', 0.09264820379188309), ('n_updates', 26335)]
episode: 200 R: 191.0 statistics: [('average_q', 19.81062306392066), ('average_loss', 0.11778217212419012), ('n_updates', 28248)]
Finished.
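
In the log above, average_q settles near 20. One plausible reading (an interpretation, not something the quick reference states here) is that this matches the discounted return of a well-balanced pole: the log shows a reward of 1.0 per step, and with gamma = 0.95 the sum of discounted rewards approaches 1 / (1 - gamma). A small check of that arithmetic, using only the values from train.py:

```python
gamma = 0.95           # discount factor from train.py
max_episode_len = 200  # step cap from train.py

# Discounted return when each of the 200 steps gives a reward of 1.
finite_return = sum(gamma ** k for k in range(max_episode_len))

# Limit of the geometric series for an infinitely long episode.
infinite_return = 1.0 / (1.0 - gamma)

print('200-step discounted return:', finite_return)
print('infinite-horizon limit:', infinite_return)
```

Both values come out at roughly 20, in line with the average_q figures once the agent regularly reaches R: 200.0.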