Reinforcement learning with Keras

Overview

I tried reinforcement learning with Keras.
The task is CartPole from openai-gym; the approach is a minimal DQN-style setup (a small Q-network, an experience-replay buffer, and an annealed ε-greedy policy).
keras-rl is not used.

Figure

gym31.png (steps survived per episode)

Sample code

import gym
import numpy as np
import time
from collections import deque
from gym import wrappers
from tensorflow.contrib.keras.python.keras.models import Sequential
from tensorflow.contrib.keras.python.keras.layers import Dense
from tensorflow.contrib.keras.python.keras.optimizers import Adam, SGD
from tensorflow.contrib.keras.python.keras import backend as K
import tensorflow as tf
import matplotlib.pyplot as plt
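# Note: the tensorflow.contrib.keras.python.keras path above only exists in
# older TF 1.x releases. On TF 2.x (assumption: if you run this today) the
# equivalent imports would be:
#   from tensorflow.keras.models import Sequential
#   from tensorflow.keras.layers import Dense
#   from tensorflow.keras.optimizers import Adam, SGD
#   from tensorflow.keras import backend as K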

class Network:
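	# Q-network: maps the 4-dim CartPole state to a Q-value for each of the 2 actions.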
	def __init__(self, learning_rate = 0.01, state_size = 4, action_size = 2, hidden_size = 10):
		self.model = Sequential()
		self.model.add(Dense(hidden_size, activation = 'tanh', input_dim = state_size))
		self.model.add(Dense(hidden_size, activation = 'tanh'))
		self.model.add(Dense(action_size, activation = 'linear'))
		self.optimizer = Adam(lr = learning_rate)
		self.model.compile(loss = self.loss, optimizer = self.optimizer)
	def replay(self, memory, batch_size, gamma):
		# Train on a random mini-batch of stored transitions (experience replay).
		inputs = np.zeros((batch_size, 4))
		targets = np.zeros((batch_size, 2))
		mini_batch = memory.sample(batch_size)
		for i, (state_b, action_b, reward_b, next_state_b) in enumerate(mini_batch):
			inputs[i : i + 1] = state_b
			target = reward_b
			# An all-zero next_state marks a terminal transition; otherwise
			# bootstrap from the current network's Q-value estimate.
			if not (next_state_b == np.zeros(state_b.shape)).all(axis = 1):
				retmainQs = self.model.predict(next_state_b)[0]
				next_action = np.argmax(retmainQs)
				target = reward_b + gamma * retmainQs[next_action]
			targets[i] = self.model.predict(state_b)
			targets[i][action_b] = target
		# Fit once on the full mini-batch after every target has been built.
		self.model.fit(inputs, targets, epochs = 1, verbose = 0)
	def loss(self, y_true, y_pred):
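		# Huber loss: quadratic for |err| < 1, linear beyond, so large TD errors
		# do not produce exploding gradients.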
		err = y_true - y_pred
		cond = K.abs(err) < 1.0
		L2 = 0.5 * K.square(err)
		L1 = (K.abs(err) - 0.5)
		loss = tf.where(cond, L2, L1)
		return K.mean(loss)

class Memory:
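	# Fixed-size replay buffer of (state, action, reward, next_state) transitions.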
	def __init__(self, max_size = 1000):
		self.buffer = deque(maxlen = max_size)
	def add(self, experience):
		self.buffer.append(experience)
	def sample(self, batch_size):
		idx = np.random.choice(np.arange(len(self.buffer)), size = batch_size, replace = False)
		return [self.buffer[ii] for ii in idx]
	def len(self):
		return len(self.buffer)

env = gym.make('CartPole-v0')
gamma = 0.99
memory_size = 5000
mainN = Network(hidden_size = 16, learning_rate = 0.00001)
memory = Memory(max_size = memory_size)
reward_trend = []
for episode in range(299):
	env.reset()
	state, reward, done, _ = env.step(env.action_space.sample())
	state = np.reshape(state, [1, 4])
	for t in range(200):
		#env.render()
		action = 0
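		# Annealed ε-greedy: exploration probability decays as episodes accumulate.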
		epsilon = 0.001 + 0.9 / (1.0 + episode)
		if epsilon <= np.random.uniform(0, 1):
			retTargetQs = mainN.model.predict(state)[0]
			action = np.argmax(retTargetQs)
		else:
			action = np.random.choice([0, 1])
		next_state, reward, done, info = env.step(action)
		next_state = np.reshape(next_state, [1, 4])
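		# Reward shaping: 0 on every step; at episode end, -1 if the pole fell
		# before step 195 and +1 otherwise.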
		if done:
			next_state = np.zeros(state.shape)
			if t < 195:
				reward = -1
			else:
				reward = 1
		else:
			reward = 0
		memory.add((state, action, reward, next_state))
		state = next_state
		if (memory.len() > 32):
			mainN.replay(memory, 32, gamma)
		if done:
			reward_trend.append(t + 1)
			print('episode %d  steps %d  memory %d' % (episode, t + 1, memory.len()))
			break
plt.plot(reward_trend)
plt.savefig("gym31.png")
plt.show()
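
To see what the trained policy does, here is a minimal sketch that runs one greedy episode with rendering. It assumes the script above has just finished, so env and mainN are still in scope, and it uses the same 4-tuple step API as the sample code; the variable names here are mine, not part of the original.

state = np.reshape(env.reset(), [1, 4])
steps = 0
for t in range(200):
	env.render()
	# Always pick the action with the highest predicted Q-value (no exploration).
	action = np.argmax(mainN.model.predict(state)[0])
	next_state, reward, done, info = env.step(action)
	state = np.reshape(next_state, [1, 4])
	steps = t + 1
	if done:
		break
print('survived %d steps' % steps)
env.close()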

That's all.
