LoginSignup
0
1

More than 5 years have passed since last update.

windowsでTensorFlow その21

Last updated at Posted at 2017-05-21

概要

windowsでTensorFlowやってみた。
生のTensorFlowで強化学習のデモを作ってみた。
環境は、フルーツバスケット。
確認用のサンプルコード、載せる。

写真

br.jpg

環境

windows 7 sp1 64bit
anaconda3
tensorflow 1.0
pyqt5

確認用のサンプルコード

import sys
from PyQt5.QtCore import *
from PyQt5.QtGui import *
from PyQt5.QtWidgets import *
import tensorflow as tf
import numpy as np
import random
import os
from collections import deque

class CatchEnvironment():
    """Falling-fruit "catch" game played on a square grid.

    The game state is a length-3 array ``[fruit_row, fruit_col, basket]``
    using 1-based coordinates.  The basket is three cells wide and lives on
    the bottom row; ``basket`` is the column of its centre cell.
    """

    def __init__(self, gridSize):
        """Create an environment for a ``gridSize`` x ``gridSize`` board."""
        self.gridSize = gridSize
        self.nbStates = self.gridSize * self.gridSize
        # Placeholder only: the contents are uninitialised until reset().
        self.state = np.empty(3, dtype = np.uint8)

    def getState(self):
        """Return ``(fruit_row, fruit_col, basket)`` from the current state."""
        current = self.state
        return current[0], current[1], current[2]

    def reset(self):
        """Start a new episode and return the fresh state tuple.

        The fruit starts on row 1 in a random column; the basket centre is
        placed so that the whole 3-cell basket fits on the board.
        """
        fruit_col = random.randrange(1, self.gridSize + 1)
        basket_col = random.randrange(2, self.gridSize)
        self.state = np.array([1, fruit_col, basket_col])
        return self.getState()

    def isGameOver(self):
        """Return True once the fruit has reached row ``gridSize - 1``."""
        return bool(self.state[0] == self.gridSize - 1)

    def drawState(self):
        """Render the state as a ``gridSize`` x ``gridSize`` array of 0/1."""
        board = np.zeros((self.gridSize, self.gridSize))
        # Fruit cell (state is 1-based, the array is 0-based).
        board[self.state[0] - 1, self.state[1] - 1] = 1
        # Basket: centre cell plus one neighbour on each side, bottom row.
        bottom = self.gridSize - 1
        centre = self.state[2] - 1
        for column in (centre - 1, centre, centre + 1):
            board[bottom, column] = 1
        return board

    def getReward(self):
        """Return 1 for a catch, -1 for a miss, 0 while the fruit is falling."""
        row, column, basket = self.getState()
        if row != self.gridSize - 1:
            return 0
        return 1 if abs(column - basket) <= 1 else -1

    def updateState(self, action):
        """Move the basket according to ``action`` and drop the fruit one row.

        ``action == 1`` moves left, ``action == 2`` stays put, and any other
        value moves right.  The basket centre is clamped to
        ``[2, gridSize - 1]``.
        """
        shift = -1 if action == 1 else (0 if action == 2 else 1)
        row, column, basket = self.getState()
        clamped = min(max(2, basket + shift), self.gridSize - 1)
        self.state = np.array([row + 1, column, clamped])

    def observe(self):
        """Return the rendered board flattened to shape ``(1, nbStates)``."""
        return np.reshape(self.drawState(), (-1, self.nbStates))

    def act(self, action):
        """Apply ``action`` and return ``(observation, reward, gameOver, state)``."""
        self.updateState(action)
        reward = self.getReward()
        finished = self.isGameOver()
        return self.observe(), reward, finished, self.getState()

class Brain:
    """Q-value network plus replay memory, built on the TensorFlow 1.x graph API.

    The constructor builds the graph and immediately restores all variables
    from ``br1.ckpt`` in the current working directory.

    NOTE(review): none of the variables are given explicit names, so matching
    against the checkpoint presumably relies on TensorFlow's auto-generated
    names, which follow creation order -- avoid reordering the statements that
    create variables unless the checkpoint is regenerated.
    """
    # Starting value of epsilon (probability of a random action in get_action).
    INITIAL_EPSILON = 1.0
    # Epsilon is no longer decreased once it is at or below this value.
    FINAL_EPSILON = 0.01
    # Divisor for the per-call epsilon decrement in get_action.
    EXPLORE = 1000.
    # step() only calls train() once time_step exceeds this value.
    OBSERVE = 100.
    # Maximum number of transitions kept in self.memory.
    REPLAY_MEMORY = 50000
    # Number of transitions sampled per train() call.
    BATCH_SIZE = 50
    # Multiplier applied to the max next-state Q-value when building targets.
    GAMMA = 0.99
    def __init__(self, n_action, n_width, n_height, state):
        """Build the graph, restore weights and set up the initial frame stack.

        Args:
            n_action: number of discrete actions (size of the network output).
            n_width: board width; n_width * n_height is the per-frame input size.
            n_height: board height.
            state: initial observation; in this file Test passes the
                ``(1, n_width * n_height)`` array returned by
                ``CatchEnvironment.observe()``.
        """
        self.n_action = n_action
        self.n_width = n_width
        self.n_height = n_height
        self.time_step = 0
        self.epsilon = self.INITIAL_EPSILON
        # Stack four copies of the first observation along a new axis 1 and
        # take element 0; with a (1, N) observation this gives a (4, N) array.
        self.state_t = np.stack((state, state, state, state), axis = 1)[0]
        self.memory = deque()
        self.input_state = tf.placeholder(tf.float32, [None, len(self.state_t), self.n_width * self.n_height])
        self.input_action = tf.placeholder(tf.float32, [None, self.n_action])
        self.input_Y = tf.placeholder(tf.float32, [None])
        # Not referenced anywhere else in this class.
        self.rewards = tf.placeholder(tf.float32, [None])
        self.global_step = tf.Variable(0, trainable = False)
        self.Q_value, self.train_op = self.build_model()
        self.saver, self.session = self.init_session()
    def init_session(self):
        """Create a Saver and an InteractiveSession and restore ``br1.ckpt``.

        No variable initializer is run here, so every variable value comes
        from the checkpoint.
        NOTE(review): presumably this raises if the checkpoint file is
        missing -- confirm and decide whether a fallback is wanted.

        Returns:
            Tuple ``(saver, session)``.
        """
        saver = tf.train.Saver()
        session = tf.InteractiveSession()
        saver.restore(session, os.getcwd() + "/br1.ckpt")
        return saver, session
    def build_model(self):
        """Build the fully connected Q-network and its training op.

        Layout: flattened frame stack -> 128 (ReLU) -> 256 (ReLU) -> n_action.
        The cost is the mean squared difference between ``input_Y`` and the
        Q-value of the action selected by the one-hot ``input_action``.

        Returns:
            Tuple ``(Q_value, train_op)``.
        """
        n_input = len(self.state_t) * self.n_width * self.n_height
        state = tf.reshape(self.input_state, [-1, n_input])
        w1 = tf.Variable(tf.truncated_normal([n_input, 128], stddev = 0.01))
        b1 = tf.Variable(tf.constant(0.01, shape = [128]))
        w2 = tf.Variable(tf.truncated_normal([128, 256], stddev = 0.01))
        b2 = tf.Variable(tf.constant(0.01, shape = [256]))
        w3 = tf.Variable(tf.truncated_normal([256, self.n_action], stddev = 0.01))
        b3 = tf.Variable(tf.constant(0.01, shape = [self.n_action]))
        l1 = tf.nn.relu(tf.matmul(state, w1) + b1)
        l2 = tf.nn.relu(tf.matmul(l1, w2) + b2)
        Q_value = tf.matmul(l2, w3) + b3
        # Multiplying by the one-hot action and summing picks out Q(s, a).
        Q_action = tf.reduce_sum(tf.multiply(Q_value, self.input_action), axis = 1)
        cost = tf.reduce_mean(tf.square(self.input_Y - Q_action))
        train_op = tf.train.AdamOptimizer(1e-6).minimize(cost, global_step = self.global_step)
        return Q_value, train_op
    def train(self):
        """Run one optimisation step on a random minibatch from the replay memory.

        Each stored transition is ``(state, action, reward, next_state,
        terminal)``.  The target is ``reward`` for terminal transitions and
        ``reward + GAMMA * max(Q(next_state))`` otherwise.
        """
        minibatch = random.sample(self.memory, self.BATCH_SIZE)
        state = [data[0] for data in minibatch]
        action = [data[1] for data in minibatch]
        reward = [data[2] for data in minibatch]
        next_state = [data[3] for data in minibatch]
        Y = []
        # Q-values of the next states, evaluated with the current network.
        Q_value = self.Q_value.eval(feed_dict = {
            self.input_state: next_state
        })
        for i in range(0, self.BATCH_SIZE):
            if minibatch[i][4]:
                Y.append(reward[i])
            else:
                Y.append(reward[i] + self.GAMMA * np.max(Q_value[i]))
        self.train_op.run(feed_dict = {
            self.input_Y: Y,
            self.input_action: action,
            self.input_state: state
        })
    def step(self, state, action, reward, terminal):
        """Record a transition, train if past the OBSERVE phase, advance the stack.

        Args:
            state: newest observation, appended along axis 0 of the frame stack.
            action: action vector that was taken (as returned by get_action).
            reward: reward received for that action.
            terminal: truthy when the episode ended with this transition.
        """
        # Drop the oldest frame and append the newest one.
        next_state = np.append(self.state_t[1:, :], state, axis = 0)
        self.memory.append((self.state_t, action, reward, next_state, terminal))
        if len(self.memory) > self.REPLAY_MEMORY:
            self.memory.popleft()
        if self.time_step > self.OBSERVE:
            self.train()
        self.state_t = next_state
        self.time_step += 1
    def get_action(self, train = False):
        """Choose an action for the current frame stack.

        When ``train`` is true, a uniformly random action is taken with
        probability ``epsilon``; otherwise the action with the largest
        Q-value is chosen.  Epsilon is decreased on every call once
        ``time_step`` exceeds OBSERVE, regardless of ``train``.

        Returns:
            Tuple ``(action, index)``: a one-hot array of length ``n_action``
            and the index of the chosen action.
        """
        action = np.zeros(self.n_action)
        if train and random.random() <= self.epsilon:
            index = random.randrange(self.n_action)
            #print ("rnd", index)
        else:
            Q_value = self.Q_value.eval(feed_dict = {
                self.input_state: [self.state_t]
            })[0]
            index = np.argmax(Q_value)
            #print ("brain", index)
        action[index] = 1
        if self.epsilon > self.FINAL_EPSILON and self.time_step > self.OBSERVE:
            self.epsilon -= (self.INITIAL_EPSILON - self.FINAL_EPSILON) / self.EXPLORE
        return action, index


class Test(QWidget):
    """Window that runs the restored agent on the module-level ``env`` and draws it.

    A 200 ms timer drives the simulation: each tick asks the brain for an
    action, applies it to the environment and schedules a repaint.
    ``paintEvent`` only draws the most recent state, so repaints triggered by
    the window system (resize, expose, ...) do not advance the game.

    Constructing the widget creates the QApplication and blocks in its event
    loop until the window is closed.
    """

    def __init__(self):
        app = QApplication(sys.argv)
        super().__init__()
        # Initialise everything paintEvent and the timer slot read *before*
        # the window is shown and the timer is started.
        self.winCount = 0
        self.loseCount = 0
        self.stateInfo = env.getState()
        state = env.observe()
        self.brain = Brain(3, 10, 10, state)
        self.init_ui()
        self.show()
        self.timer = QTimer(self)
        self.timer.timeout.connect(self._advance)
        self.timer.start(200)
        app.exec_()

    def init_ui(self):
        """Set the window title and size."""
        self.setWindowTitle("PyQt5")
        self.resize(400, 400)
        self.angle = 0

    def _advance(self):
        """Timer slot: play one environment step and request a repaint."""
        action, index = self.brain.get_action(False)
        state, reward, gameOver, stateInfo = env.act(index)
        self.brain.step(state, action, reward, gameOver)
        if reward == 1:
            self.winCount += 1
        elif reward == -1:
            self.loseCount += 1
        # Keep the state returned by act() so the final frame of an episode
        # is still drawn even though the environment is reset right away.
        self.stateInfo = stateInfo
        if gameOver:
            env.reset()
        self.update()

    def paintEvent(self, QPaintEvent):
        """Draw the board frame, the score, the fruit and the basket.

        The parameter name is kept as in the original signature; it is not
        used inside the method.
        """
        # Convert to plain ints so the integer drawRect overload is selected.
        fruitRow, fruitColumn, basket = (int(v) for v in self.stateInfo)
        painter = QPainter(self)
        painter.setPen(Qt.black)
        # Board outline.
        painter.drawLine(QPoint(20, 20), QPoint(20, 220))
        painter.drawLine(QPoint(20, 20), QPoint(220, 20))
        painter.drawLine(QPoint(220, 20), QPoint(220, 220))
        painter.drawLine(QPoint(20, 220), QPoint(220, 220))
        # Score.
        painter.setFont(QFont('Consolas', 20))
        painter.drawText(QPoint(250, 50), "win: " + str(self.winCount))
        painter.drawText(QPoint(250, 80), "miss: " + str(self.loseCount))
        # Fruit: one 20x20 cell.
        painter.setBrush(Qt.yellow)
        painter.drawRect(fruitColumn * 20, fruitRow * 20, 20, 20)
        # Basket: three cells wide, centred on the basket column.
        painter.setBrush(Qt.green)
        painter.drawRect(basket * 20 - 20, 10 * 20, 60, 20)

# Shared 10x10 game environment; the Test widget reads and steps this global.
env = CatchEnvironment(gridSize=10)
# Begin the first episode so the environment holds a valid state.
(fruitRow, fruitColumn, basket) = env.reset()

# Launch the GUI only when executed as a script.
if __name__ == "__main__":
    Test()


0
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
1