1
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

強化学習22 Colaboratory+CartPole+ChainerRL+A3C

Last updated at Posted at 2019-11-28

強化学習21まで終了していることが前提です。
A3Cは、
Asynchronous Advantage Actor-Critic
の略です。
詳しい説明は、こちらをどうぞ。
【強化学習】実装しながら学ぶA3C【CartPoleで棒立て:1ファイルで完結】https://qiita.com/sugulu/items/acbc909dd9b74b043e45

21と同じように、chainerRLをそのままnotebookにしました。
それなりに時間がかかり、90分ルールにひっかかるので、スモールサイズでやりました。

#Google drive mount

import google.colab.drive
google.colab.drive.mount('gdrive')
!ln -s gdrive/My\ Drive mydrive

#program install

# System packages for off-screen rendering and video output; stdout and
# stderr of the install are discarded.
!apt-get install -y xvfb python-opengl ffmpeg > /dev/null 2>&1
# pyvirtualdisplay is used further down to start the virtual display;
# its install output is discarded as well.
!pip install pyvirtualdisplay > /dev/null 2>&1
# -q keeps pip output short for the remaining packages.
!pip -q install JSAnimation
!pip -q install chainerrl

#Main program
An example of training A3C against OpenAI Gym Envs.

This script is an example of training an A3C agent against OpenAI Gym envs.
Both discrete and continuous action spaces are supported.

modules import


from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
from __future__ import absolute_import
from builtins import *  # NOQA
from future import standard_library
standard_library.install_aliases()  # NOQA
import argparse
import os
import sys


import chainer
from chainer import functions as F
from chainer import links as L
import gym
import numpy as np

import chainerrl
from chainerrl.agents import a3c
from chainerrl import experiments
from chainerrl import links
from chainerrl import misc
from chainerrl.optimizers.nonbias_weight_decay import NonbiasWeightDecay
from chainerrl.optimizers import rmsprop_async
from chainerrl import policies
from chainerrl.recurrent import RecurrentChainMixin
from chainerrl import v_function

##Class A3CFFSoftmax

An example of A3C feedforward softmax policy.


class A3CFFSoftmax(chainer.ChainList, a3c.A3CModel):
    """Feedforward A3C model: a softmax policy MLP plus a separate value MLP.

    Args:
        ndim_obs: Input size of both MLPs.
        n_actions: Output size of the policy MLP.
        hidden_sizes: Hidden layer sizes shared by both MLPs.
    """

    def __init__(self, ndim_obs, n_actions, hidden_sizes=(200, 200)):
        policy_net = links.MLP(ndim_obs, n_actions, hidden_sizes)
        self.pi = policies.SoftmaxPolicy(model=policy_net)
        self.v = links.MLP(ndim_obs, 1, hidden_sizes=hidden_sizes)
        # Register both sub-links with the ChainList.
        super().__init__(self.pi, self.v)

    def pi_and_v(self, state):
        """Return the policy output and the value output for ``state``."""
        policy_out = self.pi(state)
        value_out = self.v(state)
        return policy_out, value_out

##Class A3CFFMellowmax
An example of A3C feedforward mellowmax policy.


class A3CFFMellowmax(chainer.ChainList, a3c.A3CModel):
    """Feedforward A3C model: a mellowmax policy MLP plus a separate value MLP.

    Args:
        ndim_obs: Input size of both MLPs.
        n_actions: Output size of the policy MLP.
        hidden_sizes: Hidden layer sizes shared by both MLPs.
    """

    def __init__(self, ndim_obs, n_actions, hidden_sizes=(200, 200)):
        policy_net = links.MLP(ndim_obs, n_actions, hidden_sizes)
        self.pi = policies.MellowmaxPolicy(model=policy_net)
        self.v = links.MLP(ndim_obs, 1, hidden_sizes=hidden_sizes)
        # Register both sub-links with the ChainList.
        super().__init__(self.pi, self.v)

    def pi_and_v(self, state):
        """Return the policy output and the value output for ``state``."""
        policy_out = self.pi(state)
        value_out = self.v(state)
        return policy_out, value_out

##Class A3CLSTMGaussian
An example of A3C recurrent Gaussian policy.


class A3CLSTMGaussian(chainer.ChainList, a3c.A3CModel, RecurrentChainMixin):
    """Recurrent A3C model with a Gaussian policy.

    The policy and the value function each have their own
    Linear -> ReLU -> LSTM stack followed by an output link.

    Args:
        obs_size: Input size of the two Linear heads.
        action_size: Size passed to the Gaussian policy output link.
        hidden_size: Output size of the Linear heads / input of the LSTMs.
        lstm_size: Output size of the LSTMs.
    """

    def __init__(self, obs_size, action_size, hidden_size=200, lstm_size=128):
        self.pi_head = L.Linear(obs_size, hidden_size)
        self.v_head = L.Linear(obs_size, hidden_size)
        self.pi_lstm = L.LSTM(hidden_size, lstm_size)
        self.v_lstm = L.LSTM(hidden_size, lstm_size)
        self.pi = policies.FCGaussianPolicy(lstm_size, action_size)
        self.v = v_function.FCVFunction(lstm_size)
        # Register every sub-link with the ChainList.
        super().__init__(
            self.pi_head, self.v_head,
            self.pi_lstm, self.v_lstm,
            self.pi, self.v,
        )

    def pi_and_v(self, state):
        """Return the policy output and the value output for ``state``."""
        # Policy branch first, then value branch.
        pi_hidden = self.pi_lstm(F.relu(self.pi_head(state)))
        pout = self.pi(pi_hidden)

        v_hidden = self.v_lstm(F.relu(self.v_head(state)))
        vout = self.v(v_hidden)

        return pout, vout

##Main

###args


import logging


def _build_parser():
    """Create the argument parser holding every option of the A3C run."""
    p = argparse.ArgumentParser()
    p.add_argument('--processes', type=int, default=8)
    p.add_argument('--env', type=str, default='CartPole-v0')
    p.add_argument(
        '--arch', type=str, default='FFSoftmax',
        choices=('FFSoftmax', 'FFMellowmax', 'LSTMGaussian'))
    p.add_argument('--seed', type=int, default=0)
    p.add_argument(
        '--outdir', type=str, default='mydrive/OpenAI/CartPole/result-a3c')
    p.add_argument('--t-max', type=int, default=5)
    p.add_argument('--beta', type=float, default=1e-2)
    p.add_argument('--profile', action='store_true')
    p.add_argument('--steps', type=int, default=8 * 10 ** 7)
    p.add_argument('--eval-interval', type=int, default=10 ** 5)
    p.add_argument('--eval-n-runs', type=int, default=10)
    p.add_argument('--reward-scale-factor', type=float, default=1e-2)
    p.add_argument('--rmsprop-epsilon', type=float, default=1e-1)
    p.add_argument('--render', action='store_true', default=False)
    p.add_argument('--lr', type=float, default=7e-4)
    p.add_argument('--weight-decay', type=float, default=0.0)
    p.add_argument('--demo', action='store_true', default=False)
    p.add_argument('--load', type=str, default='')
    p.add_argument('--logger-level', type=int, default=logging.INFO)
    p.add_argument('--monitor', action='store_true')
    return p


parser = _build_parser()

変更したいところは、

args = parser.parse_args(['--env', 'CartPole-v0'])

のようにする。


# Options are passed as an explicit list; everything not listed keeps the
# parser default.
args = parser.parse_args([
    '--steps', '300000',
    '--eval-interval', '10000',
])
# Send log records to stdout at the level chosen by --logger-level.
logging.basicConfig(
    stream=sys.stdout, level=args.logger_level, format='')

Set a random seed used in ChainerRL.

If you use more than one process, the results will no longer be

deterministic even with the same random seed.


# Seed ChainerRL with the value given by --seed.
misc.set_random_seed(args.seed)

Set different random seeds for different subprocesses.

If seed=0 and processes=4, subprocess seeds are [0, 1, 2, 3].

If seed=1 and processes=4, subprocess seeds are [4, 5, 6, 7].


# One seed per worker process; the offset seed * processes gives different
# --seed values disjoint seed ranges.
process_seeds = np.arange(args.processes) + args.seed * args.processes
assert process_seeds.max() < 2 ** 32
# exist_ok avoids the check-then-create race and makes re-running the cell
# harmless when the output directory is already there.
os.makedirs(args.outdir, exist_ok=True)

###function


def make_env(process_idx, test):
    """Create and wrap the Gym environment for one worker process.

    Args:
        process_idx: Index into ``process_seeds``. Process 0 additionally
            gets the optional monitor and render wrappers.
        test: True for an evaluation env. Evaluation envs use a different
            seed and are neither reward-scaled nor rendered.

    Returns:
        The wrapped environment.
    """
    env = gym.make(args.env)
    is_first_process = process_idx == 0

    # Train and test envs must not share a seed: test envs take theirs from
    # the top of the 32-bit range instead.
    process_seed = int(process_seeds[process_idx])
    if test:
        env_seed = 2 ** 32 - 1 - process_seed
    else:
        env_seed = process_seed
    env.seed(env_seed)

    # The model works in float32, so observations are cast to match.
    env = chainerrl.wrappers.CastObservationToFloat32(env)
    if is_first_process and args.monitor:
        env = chainerrl.wrappers.Monitor(env, args.outdir)

    if not test:
        # Scaling rewards (and thus returns) to a reasonable range makes
        # training easier.
        env = chainerrl.wrappers.ScaleReward(env, args.reward_scale_factor)
        if args.render and is_first_process:
            env = chainerrl.wrappers.Render(env)
    return env

actionのタイプでモデルを選択します。


# A sample env instance is created only to read its spec and its spaces.
sample_env = gym.make(args.env)
obs_space = sample_env.observation_space
action_space = sample_env.action_space
timestep_limit = sample_env.spec.tags.get(
    'wrapper_config.TimeLimit.max_episode_steps')

# Pick the model for the requested architecture. The feedforward models are
# sized from action_space.n, the Gaussian one from action_space.low.
if args.arch == 'FFSoftmax':
    model = A3CFFSoftmax(obs_space.low.size, action_space.n)
elif args.arch == 'FFMellowmax':
    model = A3CFFMellowmax(obs_space.low.size, action_space.n)
elif args.arch == 'LSTMGaussian':
    model = A3CLSTMGaussian(obs_space.low.size, action_space.low.size)

###optimizer


# RMSpropAsync configured from --lr and --rmsprop-epsilon.
opt = rmsprop_async.RMSpropAsync(
    lr=args.lr,
    eps=args.rmsprop_epsilon,
    alpha=0.99,
)
opt.setup(model)

# Gradient clipping is always installed; weight decay only when requested.
opt.add_hook(chainer.optimizer.GradientClipping(40))
if args.weight_decay > 0:
    opt.add_hook(NonbiasWeightDecay(args.weight_decay))

###Agent


# Build the A3C agent around the model and optimizer.
agent = a3c.A3C(
    model,
    opt,
    t_max=args.t_max,
    gamma=0.99,
    beta=args.beta,
)
# --load defaults to '', in which case nothing is loaded.
if args.load:
    agent.load(args.load)

###train


# Asynchronous training across args.processes workers. Evaluation is
# configured by episode count (eval_n_steps stays None).
experiments.train_agent_async(
    agent=agent,
    make_env=make_env,
    outdir=args.outdir,
    processes=args.processes,
    steps=args.steps,
    max_episode_len=timestep_limit,
    eval_interval=args.eval_interval,
    eval_n_episodes=args.eval_n_runs,
    eval_n_steps=None,
    profile=args.profile,
)

# Save the trained agent under <outdir>/agent.
agent.save(args.outdir + '/agent')

import glob
import os

import pandas as pd

# Load the tab-separated scores.txt from the output directory
# (presumably written during training -- confirm against ChainerRL).
# If several paths match, the most recently modified one is used.
score_files = sorted(
    glob.glob(args.outdir + '/scores.txt'), key=os.path.getmtime)
score_file = score_files[-1]
df = pd.read_csv(score_file, sep='\t')
df

df.plot(x='steps', y='average_value')

# Start an invisible 1024x768 virtual display.
# NOTE(review): presumably required so that env.render works on a machine
# without a screen -- confirm.
from pyvirtualdisplay import Display
display = Display(visible=0, size=(1024, 768))
display.start()

# Plotting / animation imports used by the cells below.
from JSAnimation.IPython_display import display_animation
from matplotlib import animation
import matplotlib.pyplot as plt
%matplotlib inline

# Roll out three episodes with the trained agent, collecting every rendered
# frame for the animation built afterwards.
frames = []
env = gym.make(args.env)

process_seeds = np.arange(args.processes) + args.seed * args.processes
assert process_seeds.max() < 2 ** 32
env_seed = int(process_seeds[0])
env.seed(env_seed)
env = chainerrl.wrappers.CastObservationToFloat32(env)
# Rewards are deliberately NOT wrapped in ScaleReward here: scaling is a
# training aid (make_env applies it only when test is False), and with it
# the printed return R and the Monitor statistics would be multiplied by
# reward_scale_factor instead of showing the real episode return.

envw = gym.wrappers.Monitor(env, args.outdir, force=True)

try:
    for i in range(3):
        obs = envw.reset()
        done = False
        R = 0
        t = 0
        # Stop at episode end or after 200 steps, whichever comes first.
        while not done and t < 200:
            frames.append(envw.render(mode='rgb_array'))
            action = agent.act(obs)
            obs, r, done, _ = envw.step(action)
            R += r
            t += 1
        print('test episode:', i, 'R:', R)
        agent.stop_episode()
finally:
    # Close the monitor even if an episode raises.
    envw.close()

from IPython.display import HTML

# Size the figure so that one frame pixel maps to one figure pixel at 72 dpi.
frame_height, frame_width = frames[0].shape[:2]
plt.figure(figsize=(frame_width / 72.0, frame_height / 72.0), dpi=72)
patch = plt.imshow(frames[0])
plt.axis('off')


def animate(i):
    """Show frame number ``i`` in the image artist."""
    patch.set_data(frames[i])


anim = animation.FuncAnimation(
    plt.gcf(), animate, frames=len(frames), interval=50)
# Save the animation as a video file and also embed it in the notebook.
anim.save(args.outdir + '/test.mp4')
HTML(anim.to_jshtml())
1
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?