LoginSignup
4
2

More than 5 years have passed since last update.

Tensorflow + Q学習 で秘書問題を解いた

Posted at

はじめに

少し前に JavaScript + Q学習 で秘書問題を解いた(Qiita記事へのリンク: 「機械学習で秘書問題を解いた」)。今後 JavaScript のままだと機械学習の実装がつらくなりそうなので、python + Tensorflow で同じことを実装したのが本稿。Tensorflow は Lowlevel API を使った。

秘書問題ならではの考慮点

前回の記事参照。その時点での相対順位を作る必要があり、toAccumRelativeRank()という関数で実装した。

TensorFlow ならではの考慮点

特にないが、初めて触ったので良い実装ではない気がする。ただ、placeholder を使った Tensor の fancy indexing がうまく動作しなかったので、回避したため冗長な書き方になっている。どうすればいいか知っている人がいたら教えてほしい。

このエラーを回避する方法がわからない.py
g = tf.Graph()
with g.as_default():
  idx = tf.placeholder(tf.int32)
  q = tf.Variable([[0,1,2],[3,4,5]])
  val = q[idx].assign(q[idx] + 1)  
with tf.Session(graph=g) as session:
  session.run(tf.global_variables_initializer())
  print(session.run(val, {idx:(0,1)}))
# InvalidArgumentError (see above for traceback): Expected begin, end, and strides to be 1D equal size tensors, but got shapes [1,2], [1,2], and [1] instead.

Q-value の考察

前回の記事と全く同じ考察ができたので、上手くいったんじゃないかな。

コード

  • 意思決定は ε-greedy
  • 報酬は成功時のみ 1.0 でそれ以外は 0.0
  • Q-table の shape は (N+1, N, 2). なぜかは前回の記事を参考
コード.py
import numpy as np
import tensorflow as tf
import random

def toAccumRelativeRank(arr):
  return [len(list(filter(lambda x: x < arr[i], arr[:i]))) for i, x in enumerate(arr)]

class Env(object):
  def __init__(self, n):
    self.n = n
    self.reset()
  def reset(self):
    self.curr = 0
    self.seq = list(range(self.n))
    random.shuffle(self.seq)
    self.arank = toAccumRelativeRank(self.seq)
    return { "nextAction": [0, 1], "state": (0, 0) }
  def step(self, action):
    self.curr += 1
    if (action == 1):
      return {
        "state": (len(self.seq), 0), # final state
        "finished": True,
        "reward": 1.0 if self.seq[self.curr - 1] == 0 else 0.0,
      }
    if (action == 0):
      return {
        "state": (self.curr, self.arank[self.curr]),
        "finished": False,
        "reward": 0.0,
        "nextAction": [1] if self.curr == self.n - 1 else [0, 1],
      }
    raise Error('cannot reach here')

class Q(object):
  def __init__(self, n):
    self.n = n
    self.qshape = (n+1, n, 2)
    self._init_tf()
  def _init_tf(self):  
    g = tf.Graph()
    with g.as_default():
      self.qtable = tf.get_variable("qtable", self.qshape, initializer=tf.constant_initializer(0.0))
      self.idx = tf.placeholder(tf.int32)
      self.val = tf.gather_nd(self.qtable, self.idx) # ops get values
      self.delta = tf.placeholder(tf.float32)
      self.qupdate_op = self.qtable.assign(self.qtable + self.delta)
      init = tf.global_variables_initializer() # op:init
    self.session = tf.Session(graph=g)
    with self.session.as_default():
      self.session.run(init)
  def _get_values(self, locs):
    return self.session.run(self.val, {self.idx: locs})
  def dump(self):
    return self.session.run(self.qtable)
  def update(self, before, action, after, reward):
    alpha = 0.9
    gamma = 0.99
    delta = np.zeros(self.qshape)
    delta[before][action] += - alpha * self._get_values(before + (action,))
    delta[before][action] += alpha * (reward + gamma * max(self._get_values(after)))
    self.session.run(self.qupdate_op, {self.delta: delta})
  def choose_action(self, state, actions):
    epsilon = 0.5 # eplison-greedy
    actions = info["nextAction"]
    if len(actions) == 1:
      return actions[0]
    if random.random() < epsilon:
      return random.choice(actions)
    values = self._get_values(state)
    return actions[np.argmax(values)]    
  def destroy(self):
    self.session.close()

# main
N = 10
nTrain = 10000;

env = Env(N)
q = Q(N)
for i in range(nTrain):
  info = env.reset()
  while True:
    state0 = info["state"]
    action = q.choose_action(state0, info["nextAction"])
    info = env.step(action)
    q.update(state0, action, info["state"], info["reward"])    
    if info['finished']: break
print(q.dump())
q.destroy()

次のステップ

DQN

4
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
4
2