0
0

Delete article

Deleted articles cannot be recovered.

The draft of this article will also be deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

windowsでTensorFlow その20

Last updated at Posted at 2017-05-21

概要

windowsでTensorFlowやってみた。
生tensorflowで強化学習のデモを作ってみた。
環境は、フルーツバスケット。
学習用のサンプルコード、載せる。

写真

br.jpg

環境

windows 7 sp1 64bit
anaconda3
tensorflow 1.0
pyqt5

学習用のサンプルコード

import sys
import tensorflow as tf
import numpy as np
import random
import os
from collections import deque

class CatchEnvironment():
	"""Falling-fruit "catch" game played on a square grid.

	The state is the triple ``(fruit_row, fruit_col, basket)`` in 1-based
	grid coordinates.  The basket is three cells wide, is drawn on the bottom
	row, and ``basket`` is the column of its centre cell, so it is kept in
	the range ``2 .. gridSize - 1``.
	"""

	def __init__(self, gridSize):
		"""Create an environment for a ``gridSize`` x ``gridSize`` board.

		The state starts at a fixed, valid position (fruit in the top-left
		cell, basket at its leftmost position); call ``reset()`` to start a
		randomised episode.
		"""
		self.gridSize = gridSize
		self.nbStates = self.gridSize * self.gridSize
		# Use a defined placeholder instead of np.empty(): uninitialised
		# memory could hold out-of-range coordinates, which made observe()
		# before reset() draw arbitrary cells or raise IndexError.
		self.state = np.array([1, 1, 2])

	def getState(self):
		"""Return the current state as ``(fruit_row, fruit_col, basket)``."""
		fruit_row, fruit_col, basket = self.state
		return fruit_row, fruit_col, basket

	def reset(self):
		"""Start a new episode and return the new state triple.

		The fruit is placed on row 1 in a random column ``1 .. gridSize`` and
		the basket centre in a random column ``2 .. gridSize - 1``.
		"""
		initialFruitColumn = random.randrange(1, self.gridSize + 1)
		initialBucketPosition = random.randrange(2, self.gridSize + 1 - 1)
		self.state = np.array([1, initialFruitColumn, initialBucketPosition])
		return self.getState()

	def isGameOver(self):
		"""Return True once the fruit has reached row ``gridSize - 1``."""
		return bool(self.state[0] == self.gridSize - 1)

	def drawState(self):
		"""Render the state as a ``gridSize`` x ``gridSize`` array of 0/1.

		One cell is set for the fruit and three adjacent cells on the bottom
		row are set for the basket.
		"""
		canvas = np.zeros((self.gridSize, self.gridSize))
		fruit_row, fruit_col, basket = self.getState()
		# State coordinates are 1-based; canvas indices are 0-based.
		canvas[fruit_row - 1, fruit_col - 1] = 1
		bottom_row = self.gridSize - 1
		for offset in (-1, 0, 1):
			canvas[bottom_row, basket - 1 + offset] = 1
		return canvas

	def getReward(self):
		"""Return the reward for the current state.

		0 while the fruit is still falling; once it reaches row
		``gridSize - 1`` the reward is 1 if the fruit column is within one
		column of the basket centre, otherwise -1.
		"""
		fruitRow, fruitColumn, basket = self.getState()
		if fruitRow != self.gridSize - 1:
			return 0
		if abs(fruitColumn - basket) <= 1:
			return 1
		return -1

	def updateState(self, action):
		"""Advance the game by one step.

		``action`` 1 moves the basket one column toward lower indices, 2
		keeps it in place, and any other value (e.g. 0) moves it one column
		toward higher indices.  The fruit always drops by one row.
		"""
		if action == 1:
			move = -1
		elif action == 2:
			move = 0
		else:
			move = 1
		fruitRow, fruitColumn, basket = self.getState()
		# Clamp so that the three-cell basket never leaves the grid.
		newBasket = min(max(2, basket + move), self.gridSize - 1)
		self.state = np.array([fruitRow + 1, fruitColumn, newBasket])

	def observe(self):
		"""Return the rendered grid flattened to shape ``(1, nbStates)``."""
		return np.reshape(self.drawState(), (-1, self.nbStates))

	def act(self, action):
		"""Apply ``action`` and report the outcome.

		Returns ``(observation, reward, gameOver, state)`` where
		``observation`` comes from ``observe()`` and ``state`` from
		``getState()``.
		"""
		self.updateState(action)
		reward = self.getReward()
		gameOver = self.isGameOver()
		return self.observe(), reward, gameOver, self.getState()

class Brain:
	"""Q-learning agent backed by a small fully connected network.

	The network input is a stack of four flattened frames and its output is
	one Q-value per action.  Transitions are kept in a replay memory and
	sampled in mini-batches for training.
	"""

	# Exploration rate at start / floor used when annealing.
	INITIAL_EPSILON = 1.0
	FINAL_EPSILON = 0.01
	# Epsilon is lowered by (INITIAL - FINAL) / EXPLORE per get_action() call.
	EXPLORE = 1000.
	# Training and epsilon annealing only happen once time_step exceeds this.
	OBSERVE = 300
	# Maximum number of transitions kept in the replay memory.
	REPLAY_MEMORY = 50000
	# Number of transitions sampled per training step.
	BATCH_SIZE = 50
	# Discount applied to the best next-state Q-value in the training target.
	GAMMA = 0.99

	def __init__(self, n_action, n_width, n_height, state):
		"""Build the graph, open a session and seed the frame stack.

		``state`` is the first observation; it is repeated four times to
		form the initial frame stack ``state_t``.
		"""
		self.n_action = n_action
		self.n_width = n_width
		self.n_height = n_height
		self.time_step = 0
		self.epsilon = self.INITIAL_EPSILON
		self.state_t = np.stack([state] * 4, axis = 1)[0]
		self.memory = deque()
		frame_size = self.n_width * self.n_height
		self.input_state = tf.placeholder(tf.float32, [None, len(self.state_t), frame_size])
		self.input_action = tf.placeholder(tf.float32, [None, self.n_action])
		self.input_Y = tf.placeholder(tf.float32, [None])
		# Not used by any method of this class.
		self.rewards = tf.placeholder(tf.float32, [None])
		self.global_step = tf.Variable(0, trainable = False)
		self.Q_value, self.train_op = self.build_model()
		self.saver, self.session = self.init_session()

	def init_session(self):
		"""Create the saver and an interactive session with variables initialised.

		Returns ``(saver, session)``.
		"""
		saver = tf.train.Saver()
		session = tf.InteractiveSession()
		init_op = tf.global_variables_initializer()
		session.run(init_op)
		return saver, session

	def build_model(self):
		"""Construct the Q-network and its training op.

		Returns ``(Q_value, train_op)`` where ``Q_value`` has one column per
		action and ``train_op`` minimises the mean squared difference between
		``input_Y`` and the Q-value of the action selected by ``input_action``.
		"""
		n_input = len(self.state_t) * self.n_width * self.n_height
		flat_input = tf.reshape(self.input_state, [-1, n_input])
		# NOTE(review): the variables are unnamed, so their creation order
		# presumably determines the names stored in checkpoints - keep the
		# order below unchanged and confirm before reordering.
		weights_1 = tf.Variable(tf.truncated_normal([n_input, 128], stddev = 0.01))
		bias_1 = tf.Variable(tf.constant(0.01, shape = [128]))
		weights_2 = tf.Variable(tf.truncated_normal([128, 256], stddev = 0.01))
		bias_2 = tf.Variable(tf.constant(0.01, shape = [256]))
		weights_out = tf.Variable(tf.truncated_normal([256, self.n_action], stddev = 0.01))
		bias_out = tf.Variable(tf.constant(0.01, shape = [self.n_action]))
		hidden_1 = tf.nn.relu(tf.matmul(flat_input, weights_1) + bias_1)
		hidden_2 = tf.nn.relu(tf.matmul(hidden_1, weights_2) + bias_2)
		Q_value = tf.matmul(hidden_2, weights_out) + bias_out
		# input_action is one-hot, so this picks the Q-value of the taken action.
		selected_q = tf.reduce_sum(tf.multiply(Q_value, self.input_action), axis = 1)
		cost = tf.reduce_mean(tf.square(self.input_Y - selected_q))
		optimizer = tf.train.AdamOptimizer(1e-6)
		train_op = optimizer.minimize(cost, global_step = self.global_step)
		return Q_value, train_op

	def train(self):
		"""Run one optimisation step on a random mini-batch from the memory."""
		minibatch = random.sample(self.memory, self.BATCH_SIZE)
		# Each stored transition is (state, action, reward, next_state, terminal).
		states, actions, rewards, next_states, terminals = (
			list(column) for column in zip(*minibatch)
		)
		next_q = self.Q_value.eval(feed_dict = {self.input_state: next_states})
		# Terminal transitions use the bare reward; others add the discounted
		# best Q-value of the next state.
		Y = [
			r if done else r + self.GAMMA * np.max(q)
			for r, done, q in zip(rewards, terminals, next_q)
		]
		feed = {
			self.input_Y: Y,
			self.input_action: actions,
			self.input_state: states,
		}
		self.train_op.run(feed_dict = feed)

	def step(self, state, action, reward, terminal):
		"""Record a transition, train if past the observe phase, advance time.

		``state`` is the newest frame; it replaces the oldest frame of the
		stack to form the next state.
		"""
		upcoming = np.concatenate((self.state_t[1:, :], state), axis = 0)
		transition = (self.state_t, action, reward, upcoming, terminal)
		self.memory.append(transition)
		# Drop the oldest transition once the memory is over capacity.
		if len(self.memory) > self.REPLAY_MEMORY:
			self.memory.popleft()
		if self.time_step > self.OBSERVE:
			self.train()
		self.state_t = upcoming
		self.time_step += 1

	def get_action(self, train = False):
		"""Choose an action for the current frame stack.

		With ``train`` true a random action is taken with probability
		``epsilon``; otherwise the action with the highest Q-value is used.
		Returns ``(action, index)`` where ``action`` is a one-hot array of
		length ``n_action`` and ``index`` is the chosen position.
		"""
		explore = train and random.random() <= self.epsilon
		if explore:
			index = random.randrange(self.n_action)
		else:
			q_values = self.Q_value.eval(feed_dict = {self.input_state: [self.state_t]})
			index = np.argmax(q_values[0])
		action = np.zeros(self.n_action)
		action[index] = 1
		annealing = self.epsilon > self.FINAL_EPSILON and self.time_step > self.OBSERVE
		if annealing:
			self.epsilon -= (self.INITIAL_EPSILON - self.FINAL_EPSILON) / self.EXPLORE
		return action, index

	def save(self):
		"""Write a checkpoint named ``br1.ckpt`` into the current directory."""
		checkpoint = os.getcwd() + "/br1.ckpt"
		save_path = self.saver.save(self.session, checkpoint)
		print ("Model saved in file: %s" % save_path)
def main(_):
	"""Train the agent on the catch game and save the resulting model.

	Runs a fixed number of episodes on a 10 x 10 board, printing the running
	win/loss totals after each episode, then writes a checkpoint through
	``Brain.save()``.  The single argument is supplied by ``tf.app.run`` and
	is ignored.
	"""
	epoch = 10001
	print ("Training new model")
	env = CatchEnvironment(10)
	env.reset()
	brain = Brain(3, 10, 10, env.observe())
	winCount = 0
	loseCount = 0
	for i in range(epoch):
		# Re-seed the agent's frame stack with the first observation of this
		# episode (same construction as in Brain.__init__).  Without this the
		# stack still holds the previous episode's final frames, so the first
		# action of every new episode was chosen without seeing the new
		# fruit, and the stored transition mixed frames from two episodes.
		firstFrame = env.observe()
		brain.state_t = np.stack((firstFrame, firstFrame, firstFrame, firstFrame), axis = 1)[0]
		gameOver = False
		while not gameOver:
			action, index = brain.get_action(True)
			state, reward, gameOver, _stateInfo = env.act(index)
			brain.step(state, action, reward, gameOver)
			if reward == 1:
				winCount += 1
			elif reward == -1:
				loseCount += 1
		# Prepare the board for the next episode.
		env.reset()
		print (i, " win: ", winCount, " loss: ", loseCount)
	brain.save()

# Script entry point: let TensorFlow's app runner invoke main().
if __name__ == "__main__":
	tf.app.run(main = main)

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?