強化学習21まで終了していることが前提です。
A3Cは、
Asynchronous Advantage Actor-Critic
の略です。
詳しい説明は、こちらをどうぞ。
【強化学習】実装しながら学ぶA3C【CartPoleで棒立て:1ファイルで完結】https://qiita.com/sugulu/items/acbc909dd9b74b043e45
21と同じように、ChainerRLのサンプルスクリプトをそのままnotebookにしました。
学習にはそれなりに時間がかかり、Colabの90分ルールにひっかかるので、ステップ数を減らしたスモールサイズで実行しました。
# Google Drive mount: results are written under the Drive folder linked below.
import google.colab.drive
google.colab.drive.mount('gdrive')
# Expose "My Drive" as ./mydrive; the default --outdir (mydrive/OpenAI/...) lives here.
!ln -s gdrive/My\ Drive mydrive
# Program install
# xvfb/python-opengl: off-screen rendering for gym via pyvirtualdisplay;
# ffmpeg: writer used when saving the test animation as mp4.
!apt-get install -y xvfb python-opengl ffmpeg > /dev/null 2>&1
!pip install pyvirtualdisplay > /dev/null 2>&1
!pip -q install JSAnimation
!pip -q install chainerrl
#Main program
An example of training A3C against OpenAI Gym envs.
This script is an example of training an A3C agent against OpenAI Gym envs.
Both discrete and continuous action spaces are supported.
Import modules
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
from __future__ import absolute_import
from builtins import * # NOQA
from future import standard_library
standard_library.install_aliases() # NOQA
import argparse
import os
import sys
import chainer
from chainer import functions as F
from chainer import links as L
import gym
import numpy as np
import chainerrl
from chainerrl.agents import a3c
from chainerrl import experiments
from chainerrl import links
from chainerrl import misc
from chainerrl.optimizers.nonbias_weight_decay import NonbiasWeightDecay
from chainerrl.optimizers import rmsprop_async
from chainerrl import policies
from chainerrl.recurrent import RecurrentChainMixin
from chainerrl import v_function
##Class A3CFFSoftmax
An example of A3C feedforward softmax policy.
class A3CFFSoftmax(chainer.ChainList, a3c.A3CModel):
    """Feed-forward A3C model with a softmax policy over discrete actions.

    The policy and the value function are two separate MLPs that share no
    weights.

    Args:
        ndim_obs: Length of the (flat) observation vector.
        n_actions: Number of discrete actions.
        hidden_sizes: Hidden-layer widths used by both MLPs.
    """

    def __init__(self, ndim_obs, n_actions, hidden_sizes=(200, 200)):
        logits_net = links.MLP(ndim_obs, n_actions, hidden_sizes)
        self.pi = policies.SoftmaxPolicy(model=logits_net)
        self.v = links.MLP(ndim_obs, 1, hidden_sizes=hidden_sizes)
        # Register both heads as children of the ChainList.
        super().__init__(self.pi, self.v)

    def pi_and_v(self, state):
        """Return ``(policy output, value output)`` for ``state``."""
        policy_out = self.pi(state)
        value_out = self.v(state)
        return policy_out, value_out
##Class A3CFFMellowmax
An example of A3C feedforward mellowmax policy.
class A3CFFMellowmax(chainer.ChainList, a3c.A3CModel):
    """Feed-forward A3C model with a mellowmax policy over discrete actions.

    The policy and the value function are two separate MLPs that share no
    weights.

    Args:
        ndim_obs: Length of the (flat) observation vector.
        n_actions: Number of discrete actions.
        hidden_sizes: Hidden-layer widths used by both MLPs.
    """

    def __init__(self, ndim_obs, n_actions, hidden_sizes=(200, 200)):
        preference_net = links.MLP(ndim_obs, n_actions, hidden_sizes)
        self.pi = policies.MellowmaxPolicy(model=preference_net)
        self.v = links.MLP(ndim_obs, 1, hidden_sizes=hidden_sizes)
        # Register both heads as children of the ChainList.
        super().__init__(self.pi, self.v)

    def pi_and_v(self, state):
        """Return ``(policy output, value output)`` for ``state``."""
        policy_out = self.pi(state)
        value_out = self.v(state)
        return policy_out, value_out
##Class A3CLSTMGaussian
An example of A3C recurrent Gaussian policy.
class A3CLSTMGaussian(chainer.ChainList, a3c.A3CModel, RecurrentChainMixin):
    """Recurrent A3C model with a Gaussian policy for continuous actions.

    Policy and value function each have their own, unshared pipeline:
    Linear -> ReLU -> LSTM -> head.

    Args:
        obs_size: Length of the (flat) observation vector.
        action_size: Dimensionality of the action vector.
        hidden_size: Output width of each branch's input Linear layer.
        lstm_size: Output width of each branch's LSTM.
    """

    def __init__(self, obs_size, action_size, hidden_size=200, lstm_size=128):
        # Construction order is kept fixed so parameter initialisation
        # consumes random numbers in the same sequence.
        self.pi_head = L.Linear(obs_size, hidden_size)
        self.v_head = L.Linear(obs_size, hidden_size)
        self.pi_lstm = L.LSTM(hidden_size, lstm_size)
        self.v_lstm = L.LSTM(hidden_size, lstm_size)
        self.pi = policies.FCGaussianPolicy(lstm_size, action_size)
        self.v = v_function.FCVFunction(lstm_size)
        super().__init__(
            self.pi_head, self.v_head,
            self.pi_lstm, self.v_lstm,
            self.pi, self.v,
        )

    def pi_and_v(self, state):
        """Return ``(policy output, value output)`` for ``state``.

        The policy branch is evaluated before the value branch.
        """
        branches = (
            (self.pi_head, self.pi_lstm, self.pi),
            (self.v_head, self.v_lstm, self.v),
        )
        outputs = []
        for head, lstm, tail in branches:
            recurrent_features = lstm(F.relu(head(state)))
            outputs.append(tail(recurrent_features))
        return tuple(outputs)
##Main
###args
import logging


def _build_a3c_arg_parser():
    """Build the command-line parser for this A3C example.

    Each table entry is ``(flag, add_argument keyword arguments)``; entries are
    registered in table order, which is also the order shown by ``--help``.
    """
    option_table = (
        ('--processes', dict(type=int, default=8)),
        ('--env', dict(type=str, default='CartPole-v0')),
        ('--arch', dict(type=str, default='FFSoftmax',
                        choices=('FFSoftmax', 'FFMellowmax', 'LSTMGaussian'))),
        ('--seed', dict(type=int, default=0)),
        ('--outdir', dict(type=str,
                          default='mydrive/OpenAI/CartPole/result-a3c')),
        ('--t-max', dict(type=int, default=5)),
        ('--beta', dict(type=float, default=1e-2)),
        ('--profile', dict(action='store_true')),
        ('--steps', dict(type=int, default=8 * 10 ** 7)),
        ('--eval-interval', dict(type=int, default=10 ** 5)),
        ('--eval-n-runs', dict(type=int, default=10)),
        ('--reward-scale-factor', dict(type=float, default=1e-2)),
        ('--rmsprop-epsilon', dict(type=float, default=1e-1)),
        ('--render', dict(action='store_true', default=False)),
        ('--lr', dict(type=float, default=7e-4)),
        ('--weight-decay', dict(type=float, default=0.0)),
        ('--demo', dict(action='store_true', default=False)),
        ('--load', dict(type=str, default='')),
        ('--logger-level', dict(type=int, default=logging.INFO)),
        ('--monitor', dict(action='store_true')),
    )
    built = argparse.ArgumentParser()
    for flag, options in option_table:
        built.add_argument(flag, **options)
    return built


parser = _build_a3c_arg_parser()
デフォルト値から変更したい引数は、
args = parser.parse_args(['--env', 'CartPole-v0'])
のように、フラグと値を文字列のリストで渡します。
# Only the options that differ from the defaults are overridden; the run is
# shortened so that it finishes within a Colab session.
_cli_overrides = ['--steps', '300000', '--eval-interval', '10000']
args = parser.parse_args(_cli_overrides)
logging.basicConfig(format='', level=args.logger_level, stream=sys.stdout)
Set a random seed used in ChainerRL.
If you use more than one process, the results will no longer be
deterministic, even with the same random seed.
# Seed the random number generators used by ChainerRL from --seed.
misc.set_random_seed(seed=args.seed)
Set different random seeds for different subprocesses.
If seed=0 and processes=4, subprocess seeds are [0, 1, 2, 3].
If seed=1 and processes=4, subprocess seeds are [4, 5, 6, 7].
# One seed per worker: worker i gets seed * processes + i, so different --seed
# values give disjoint seed ranges.
process_seeds = np.arange(args.processes) + args.seed * args.processes
# make_env() passes these seeds (and 2**32 - 1 - seed for test envs) to
# env.seed(), so they must fit in 32 bits. Checked explicitly because `assert`
# is stripped when Python runs with -O.
if process_seeds.max() >= 2 ** 32:
    raise ValueError(
        'process seeds must be < 2**32, got max {}'.format(process_seeds.max()))
# exist_ok avoids the race between an os.path.exists() check and makedirs().
os.makedirs(args.outdir, exist_ok=True)
###function
def make_env(process_idx, test):
    """Create and wrap the gym environment for one worker.

    Args:
        process_idx: Worker index; selects the seed from ``process_seeds``.
            Worker 0 additionally gets the Monitor / Render wrappers when
            ``--monitor`` / ``--render`` are set.
        test: True for an evaluation env, False for a training env.

    Returns:
        The wrapped environment.
    """
    env = gym.make(args.env)

    # Train and test envs use different seeds: test seeds are mirrored from
    # the top of the 32-bit range.
    base_seed = int(process_seeds[process_idx])
    if test:
        env.seed(2 ** 32 - 1 - base_seed)
    else:
        env.seed(base_seed)

    # The model computes in float32, so observations are cast up front.
    env = chainerrl.wrappers.CastObservationToFloat32(env)

    is_first_worker = process_idx == 0
    if is_first_worker and args.monitor:
        env = chainerrl.wrappers.Monitor(env, args.outdir)

    if test:
        return env

    # Training only: rescale rewards (and thus returns) into a range that is
    # easier to learn from.
    env = chainerrl.wrappers.ScaleReward(env, args.reward_scale_factor)
    if is_first_worker and args.render:
        env = chainerrl.wrappers.Render(env)
    return env
--arch 引数でモデルを選択します（FFSoftmax / FFMellowmax は離散行動空間、LSTMGaussian は連続行動空間用です）。
# Probe a throw-away env once to read its spaces and episode-length limit.
sample_env = gym.make(args.env)
# Newer gym releases removed EnvSpec.tags (AttributeError on the old lookup);
# `max_episode_steps` exists on both old and new specs, so prefer it and fall
# back to the legacy tag only if it is missing.
_spec = sample_env.spec
timestep_limit = getattr(_spec, 'max_episode_steps', None)
if timestep_limit is None:
    timestep_limit = (getattr(_spec, 'tags', None) or {}).get(
        'wrapper_config.TimeLimit.max_episode_steps')
obs_space = sample_env.observation_space
action_space = sample_env.action_space
sample_env.close()

# The architecture comes from --arch and must match the action space:
# FFSoftmax / FFMellowmax read action_space.n (discrete), LSTMGaussian reads
# action_space.low (continuous Box).
if args.arch == 'LSTMGaussian':
    model = A3CLSTMGaussian(obs_space.low.size, action_space.low.size)
elif args.arch == 'FFSoftmax':
    model = A3CFFSoftmax(obs_space.low.size, action_space.n)
elif args.arch == 'FFMellowmax':
    model = A3CFFMellowmax(obs_space.low.size, action_space.n)
else:
    # Fail here rather than with a NameError on `model` further down.
    raise ValueError('Unknown --arch: {}'.format(args.arch))
###optimizer
# ChainerRL's RMSprop variant for asynchronous training.
opt = rmsprop_async.RMSpropAsync(
    alpha=0.99, eps=args.rmsprop_epsilon, lr=args.lr)
opt.setup(model)
# Always clip the gradient norm at 40; add weight decay only when requested
# (NonbiasWeightDecay — by its name, presumably leaving biases undecayed).
optimizer_hooks = [chainer.optimizer.GradientClipping(40)]
if args.weight_decay > 0:
    optimizer_hooks.append(NonbiasWeightDecay(args.weight_decay))
for optimizer_hook in optimizer_hooks:
    opt.add_hook(optimizer_hook)
###Agent
# A3C agent: updates after every t_max local steps, discount factor 0.99,
# entropy-regularisation weight beta.
agent = a3c.A3C(model, opt, beta=args.beta, gamma=0.99, t_max=args.t_max)

# Optionally resume from a previously saved agent.
if args.load:
    agent.load(args.load)
###train
# Asynchronous training with args.processes workers. Evaluation is
# episode-based (eval_n_steps=None): eval_n_runs episodes every eval_interval
# steps.
train_kwargs = dict(
    agent=agent,
    outdir=args.outdir,
    processes=args.processes,
    make_env=make_env,
    profile=args.profile,
    steps=args.steps,
    eval_n_steps=None,
    eval_n_episodes=args.eval_n_runs,
    eval_interval=args.eval_interval,
    max_episode_len=timestep_limit,
)
experiments.train_agent_async(**train_kwargs)

# Also keep a copy of the final agent at <outdir>/agent.
agent.save(args.outdir + '/agent')
import pandas as pd
import glob
import os

# Training logs its evaluation scores to <outdir>/scores.txt (tab-separated).
# The path is built directly instead of globbed: a glob on a literal path
# matches at most one file (so sorting by mtime was a no-op), would treat
# "[", "*" or "?" in outdir as wildcards, and an empty result surfaced only
# as an opaque IndexError.
score_file = os.path.join(args.outdir, 'scores.txt')
if not os.path.isfile(score_file):
    raise FileNotFoundError(
        'No scores.txt found in {!r}; has training been run?'.format(
            args.outdir))
df = pd.read_csv(score_file, delimiter='\t')
df
df.plot(x='steps', y='average_value')
from pyvirtualdisplay import Display
# Start an invisible 1024x768 virtual display so env.render() has a display
# to draw into on the headless VM; it must be running before frames are rendered.
display = Display(visible=0, size=(1024, 768))
display.start()
# Plotting / animation utilities for showing the recorded frames.
# NOTE(review): display_animation is not used in the cells shown here.
from JSAnimation.IPython_display import display_animation
from matplotlib import animation
import matplotlib.pyplot as plt
# IPython magic: show matplotlib figures inline in the notebook.
%matplotlib inline
# Roll out the trained agent for a few episodes, collecting RGB frames for the
# animation below. gym's Monitor writes its output to outdir; force=True
# clears earlier Monitor files there.
frames = []
env = gym.make(args.env)
# Same seed as training worker 0 (process_seeds[0] == seed * processes).
env_seed = int(args.seed * args.processes)
env.seed(env_seed)
env = chainerrl.wrappers.CastObservationToFloat32(env)
# No ScaleReward wrapper here: like the evaluation envs from
# make_env(test=True), rewards stay unscaled, so the printed R is the real
# episode return rather than one multiplied by reward_scale_factor (1e-2).
envw = gym.wrappers.Monitor(env, args.outdir, force=True)
n_test_episodes = 3
max_test_steps = 200
try:
    for i in range(n_test_episodes):
        obs = envw.reset()
        done = False
        R = 0
        t = 0
        while not done and t < max_test_steps:
            frames.append(envw.render(mode='rgb_array'))
            action = agent.act(obs)
            obs, r, done, _ = envw.step(action)
            R += r
            t += 1
        print('test episode:', i, 'R:', R)
        # Signal the end of the episode to the agent (ChainerRL agents reset
        # per-episode state, e.g. recurrent memory, here).
        agent.stop_episode()
finally:
    # Always flush/close the Monitor, even if the rollout raises.
    envw.close()
from IPython.display import HTML

# Size the figure so one figure pixel maps to one frame pixel: figsize is in
# inches, so divide the frame's width/height by the 72 dpi used below.
frame_height, frame_width = frames[0].shape[0], frames[0].shape[1]
plt.figure(figsize=(frame_width / 72.0, frame_height / 72.0), dpi=72)
patch = plt.imshow(frames[0])
plt.axis('off')


def animate(i):
    """Display frame ``i`` in the existing image artist."""
    patch.set_data(frames[i])


anim = animation.FuncAnimation(
    plt.gcf(), animate, frames=len(frames), interval=50)
# Save an mp4 copy (uses the ffmpeg installed above), then embed an
# interactive player as the cell's output.
anim.save(args.outdir + '/test.mp4')
HTML(anim.to_jshtml())