Overview
I looked into the basics of writing Chainer code.
I also gave OpenAI Gym a try, using a DQN agent on CartPole-v0.
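For reference, the core OpenAI Gym interaction loop that the sample below builds on boils down to the following. This is only a minimal sketch with a random policy, using the same CartPole-v0 environment and the same old 4-tuple step API as the sample code:

import gym

env = gym.make('CartPole-v0')   # create the environment
obs = env.reset()               # start an episode, get the initial observation
done = False
total_reward = 0.0
while not done:
    action = env.action_space.sample()           # random action for illustration
    obs, reward, done, info = env.step(action)   # advance the simulation one step
    total_reward += reward
print('episode reward:', total_reward)
env.close()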
Screenshot
Sample code
import collections
import copy
import random
import gym
import numpy as np
import chainer
from chainer import functions as F
from chainer import links as L
from chainer import optimizers, Chain, no_backprop_mode
import matplotlib.pyplot as plt

class QFunction(Chain):
    """3-layer MLP that maps an observation to one Q-value per action."""

    def __init__(self, obs_size, n_actions, n_units=100):
        super(QFunction, self).__init__(
            l0=L.Linear(obs_size, n_units),
            l1=L.Linear(n_units, n_units),
            l2=L.Linear(n_units, n_actions))

    def __call__(self, x):
        h = F.relu(self.l0(x))
        h = F.relu(self.l1(h))
        return self.l2(h)

def get_greedy_action(Q, obs):
    """Pick the action with the highest predicted Q-value for a single observation."""
    obs = Q.xp.asarray(obs[None], dtype=np.float32)
    with no_backprop_mode():
        q = Q(obs).data[0]
    return int(q.argmax())

def mean_clipped_loss(y, t):
    # Huber loss (delta=1.0) averaged over the batch, i.e. the TD error is clipped.
    return F.mean(F.huber_loss(y, t, delta=1.0, reduce='no'))

def update(Q, target_Q, opt, samples, gamma=0.99, target_type='double_dqn'):
    """Perform one gradient step of (Double) DQN on a minibatch of transitions."""
    xp = Q.xp
    obs = xp.asarray([sample[0] for sample in samples], dtype=np.float32)
    action = xp.asarray([sample[1] for sample in samples], dtype=np.int32)
    reward = xp.asarray([sample[2] for sample in samples], dtype=np.float32)
    done = xp.asarray([sample[3] for sample in samples], dtype=np.float32)
    obs_next = xp.asarray([sample[4] for sample in samples], dtype=np.float32)
    # Predicted values: Q(s, a) for the actions that were actually taken.
    y = F.select_item(Q(obs), action)
    # Target values are computed without building a computation graph.
    with no_backprop_mode():
        if target_type == 'dqn':
            # Standard DQN: the target network both selects and evaluates the next action.
            next_q = F.max(target_Q(obs_next), axis=1)
        elif target_type == 'double_dqn':
            # Double DQN: the online network selects the action, the target network evaluates it.
            next_q = F.select_item(target_Q(obs_next), F.argmax(Q(obs_next), axis=1))
        else:
            raise ValueError('Unsupported target_type: {}'.format(target_type))
        target = reward + gamma * (1 - done) * next_q
    loss = mean_clipped_loss(y, target)
    Q.cleargrads()
    loss.backward()
    opt.update()

def main():
    env = gym.make('CartPole-v0')
    assert isinstance(env.observation_space, gym.spaces.Box)
    assert isinstance(env.action_space, gym.spaces.Discrete)

    obs_size = env.observation_space.low.size
    n_actions = env.action_space.n
    reward_threshold = env.spec.reward_threshold
    print(obs_size, n_actions, reward_threshold)
    if reward_threshold is not None:
        print('{} defines "solving" as getting average reward of {} over 100 '
              'consecutive trials.'.format('CartPole-v0', reward_threshold))
    else:
        print('{} is an unsolved environment, which means it does not have a '
              'specified reward threshold at which it\'s considered '
              'solved.'.format('CartPole-v0'))

    D = collections.deque(maxlen=10 ** 6)  # replay buffer of transitions
    Rs = collections.deque(maxlen=100)     # returns of the last 100 episodes
    iteration = 0

    Q = QFunction(obs_size, n_actions, n_units=100)
    target_Q = copy.deepcopy(Q)
    opt = optimizers.Adam(eps=1e-2)
    opt.setup(Q)

    reward_trend = []
    for episode in range(200):
        obs = env.reset()
        done = False
        R = 0.0
        timestep = 0
        while not done and timestep < env.spec.max_episode_steps:
            env.render()
            # Epsilon-greedy: act randomly until the replay buffer has 500
            # transitions, then anneal epsilon from 1.0 to 0.01 over 5000 iterations.
            epsilon = 1.0 if len(D) < 500 else max(
                0.01, np.interp(iteration, [0, 5000], [1.0, 0.01]))
            if np.random.rand() < epsilon:
                action = env.action_space.sample()
            else:
                action = get_greedy_action(Q, obs)
            new_obs, reward, done, _ = env.step(action)
            R += reward
            # Store the transition; the reward is scaled down by 1e-2.
            D.append((obs, action, reward * 1e-2, done, new_obs))
            obs = new_obs
            # Train on a random minibatch once the buffer is warmed up.
            if len(D) >= 500:
                sample_indices = random.sample(range(len(D)), 64)
                samples = [D[i] for i in sample_indices]
                update(Q, target_Q, opt, samples, target_type='dqn')
            # Sync the target network every 100 iterations.
            if iteration % 100 == 0:
                target_Q = copy.deepcopy(Q)
            iteration += 1
            timestep += 1
        Rs.append(R)
        reward_trend.append(R)
        average_R = np.mean(Rs)
        print('episode: {} iteration: {} reward: {} average_R: {}'.format(
            episode, iteration, R, average_R))
        if reward_threshold is not None and average_R >= reward_threshold:
            print('Solved {} by getting average reward of {} >= {} over 100 '
                  'consecutive episodes.'.format(
                      'CartPole-v0', average_R, reward_threshold))
            break

    plt.plot(reward_trend)
    plt.savefig('cart0.png')
    plt.show()

if __name__ == '__main__':
    main()
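To keep the trained network around, Chainer's serializers can dump and reload a Chain. The snippet below is a minimal sketch and not part of the sample above: the file name q_cartpole.npz is arbitrary, it assumes the QFunction and get_greedy_action defined earlier, and it relies on CartPole-v0 having a 4-dimensional observation and 2 actions. It then plays one greedy episode without exploration.

from chainer import serializers

# Save the trained Q-network (e.g. at the end of main(), where Q is in scope).
serializers.save_npz('q_cartpole.npz', Q)

# Later: rebuild a network of the same shape and load the weights back.
Q = QFunction(obs_size=4, n_actions=2, n_units=100)
serializers.load_npz('q_cartpole.npz', Q)

# Play one purely greedy episode (no epsilon-greedy exploration).
env = gym.make('CartPole-v0')
obs = env.reset()
done = False
R = 0.0
while not done:
    env.render()
    action = get_greedy_action(Q, obs)
    obs, reward, done, _ = env.step(action)
    R += reward
print('greedy episode reward:', R)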
That's all.