
Chainer conventions, part 7

Posted at 2018-05-03

Overview

I looked into Chainer conventions.
I tried OpenAI Gym: a DQN agent on CartPole-v0.
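
Before the DQN code, the bare Gym interaction loop is worth seeing on its own. The sketch below is a minimal random agent on the same CartPole-v0 environment; it is not part of the sample code, and the number of episodes is arbitrary.

import gym

env = gym.make('CartPole-v0')
for episode in range(3):	# arbitrary number of episodes
	obs = env.reset()	# 4-dimensional observation (old Gym API)
	done = False
	R = 0.0
	while not done:
		action = env.action_space.sample()	# random action: 0 or 1
		obs, reward, done, info = env.step(action)	# old Gym API returns a 4-tuple
		R += reward
	print(episode, R)
env.close()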

Images

cart.jpg (CartPole-v0 rendering)

cart0.png (reward trend plot produced by the script)

Sample code

import collections
import copy
import random
import gym
import numpy as np
import chainer
from chainer import functions as F
from chainer import links as L
from chainer import optimizers, Chain, no_backprop_mode
import matplotlib.pyplot as plt


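# Q-function: a 3-layer MLP that maps an observation to one Q-value per action.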
class QFunction(Chain):
	def __init__(self, obs_size, n_actions, n_units = 100):
		super(QFunction, self).__init__(l0 = L.Linear(obs_size, n_units), l1 = L.Linear(n_units, n_units), l2 = L.Linear(n_units, n_actions))
	def __call__(self, x):
		h = F.relu(self.l0(x))
		h = F.relu(self.l1(h))
		return self.l2(h)

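# Greedy action selection: argmax over Q-values, computed without building a graph.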
def get_greedy_action(Q, obs):
	obs = Q.xp.asarray(obs[None], dtype = np.float32)
	with no_backprop_mode():
		q = Q(obs).data[0]
	return int(q.argmax())

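# Elementwise Huber loss (delta = 1) averaged over the minibatch.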
def mean_clipped_loss(y, t):
	return F.mean(F.huber_loss(y, t, delta = 1.0, reduce = 'no'))

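# One gradient step on a minibatch of (obs, action, reward, done, next_obs) transitions.
# target_type selects the plain DQN target or the Double DQN target.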
def update(Q, target_Q, opt, samples, gamma = 0.99, target_type = 'double_dqn'):
	xp = Q.xp
	obs = xp.asarray([sample[0] for sample in samples], dtype = np.float32)
	action = xp.asarray([sample[1] for sample in samples], dtype = np.int32)
	reward = xp.asarray([sample[2] for sample in samples], dtype = np.float32)
	done = xp.asarray([sample[3] for sample in samples], dtype = np.float32)
	obs_next = xp.asarray([sample[4] for sample in samples], dtype = np.float32)
	y = F.select_item(Q(obs), action)
	with no_backprop_mode():
		if target_type == 'dqn':
			next_q = F.max(target_Q(obs_next), axis = 1)
		elif target_type == 'double_dqn':
			next_q = F.select_item(target_Q(obs_next), F.argmax(Q(obs_next), axis = 1))
		else:
			raise ValueError('Unsupported target_type: {}'.format(target_type))
		target = reward + gamma * (1 - done) * next_q
	loss = mean_clipped_loss(y, target)
	Q.cleargrads()
	loss.backward()
	opt.update()

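# Training loop: epsilon-greedy exploration on CartPole-v0 with experience replay
# and a target network that is re-synchronized every 100 iterations.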
def main():
	env = gym.make('CartPole-v0')
	assert isinstance(env.observation_space, gym.spaces.Box)
	assert isinstance(env.action_space, gym.spaces.Discrete)
	obs_size = env.observation_space.low.size
	n_actions = env.action_space.n
	reward_threshold = env.spec.reward_threshold
	print (obs_size, n_actions, reward_threshold)
	if reward_threshold is not None:
		print ('{} defines "solving" as getting average reward of {} over 100 ' 'consecutive trials.'.format('CartPole-v0', reward_threshold))
	else:
		print ('{} is an unsolved environment, which means it does not have a ' 'specified reward threshold at which it\'s considered ' 'solved.'.format('CartPole-v0'))
	D = collections.deque(maxlen = 10 ** 6)
	Rs = collections.deque(maxlen = 100)
	iteration = 0
	Q = QFunction(obs_size, n_actions, n_units = 100)
	target_Q = copy.deepcopy(Q)
	opt = optimizers.Adam(eps = 1e-2)
	opt.setup(Q)
	reward_trend = []
	for episode in range(200):
		obs = env.reset()
		done = False
		R = 0.0
		timestep = 0
		while not done and timestep < env.spec.timestep_limit:  # note: newer Gym releases expose this limit as env.spec.max_episode_steps
			env.render()
			epsilon = 1.0 if len(D) < 500 else max(0.01, np.interp(iteration, [0, 5000], [1.0, 0.01]))
			if np.random.rand() < epsilon:
				action = env.action_space.sample()
			else:
				action = get_greedy_action(Q, obs)
			new_obs, reward, done, _ = env.step(action)
			R += reward
			D.append((obs, action, reward * 1e-2, done, new_obs))
			obs = new_obs
			if len(D) >= 500:
				sample_indices = random.sample(range(len(D)), 64)
				samples = [D[i] for i in sample_indices]
				update(Q, target_Q, opt, samples, target_type = 'dqn')
			if iteration % 100 == 0:
				target_Q = copy.deepcopy(Q)
			iteration += 1
			timestep += 1
		Rs.append(R)
		reward_trend.append(R)
		average_R = np.mean(Rs)
		print ('episode: {} iteration: {} reward: {} average_R: {}'.format(episode, iteration, R, average_R))
		if reward_threshold is not None and average_R >= reward_threshold:
			print ('Solved {} by getting average reward of ' '{} >= {} over 100 consecutive episodes.'.format('CartPole-v0', average_R, reward_threshold))
			break
	plt.plot(reward_trend)
	plt.savefig("cart0.png")
	plt.show()

if __name__ == '__main__':
	main()
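
If you want to reuse the trained network afterwards, Chainer's serializers can write the parameters to an NPZ file. This is a minimal sketch, not part of the original script; the file name q_function.npz is an assumption, and the calls would go inside main() after the training loop, where Q, obs_size and n_actions are defined.

from chainer import serializers

serializers.save_npz('q_function.npz', Q)	# save trained parameters (hypothetical file name)

Q_restored = QFunction(obs_size, n_actions, n_units = 100)
serializers.load_npz('q_function.npz', Q_restored)	# load into a network with the same architecture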



That's all.
