LoginSignup
2
2

More than 5 years have passed since last update.

TensorFlowのStatefulな変数でキューを作る

Posted at

TensorFlowのtf.Variable型は主にネットワークの重みやバイアスを保持しておくために利用されますが、これを使ってデータ構造のキュー(Queue, 待ち行列)を作ることもできます。このような使い方があることは、PixelCNNの生成過程を高速化したRamachandran氏のFast PixelCNN++のコードを読んでいて知りました。変数は初期化かOptimizerによる更新以外では値を変えるものではないと思っていた僕には目から鱗でした。

以下が、自分で試してみたコードになります。TensorFlowはバージョン1.5を使いました。delay関数で長さ3のキューと、それの先頭から要素をポップするオペレータを定義しています。



import tensorflow as tf
import numpy as np

def delay(x, n=3):
    queue = tf.Variable(initial_value=np.zeros(n, dtype=np.float32), name='queue')
    head, tail = queue[0], queue[1:]
    pushed = tf.concat([tail, x], 0)
    assigned =  queue.assign(pushed)
    with tf.control_dependencies([assigned]):
        return assigned, tf.identity(head)

x = tf.placeholder(name='input', shape=(1), dtype=tf.float32)
queue, pop = delay(x)

sess = tf.Session()
sess.run(tf.global_variables_initializer())
sess.run(queue, feed_dict={x: np.array([1.0], dtype=np.float32)})

array([ 0., 0., 1.], dtype=float32)

sess.run(queue, feed_dict={x: np.array([2.0], dtype=np.float32)})

array([ 0., 1., 2.], dtype=float32)

sess.run(queue, feed_dict={x: np.array([3.0], dtype=np.float32)})

array([ 1., 2., 3.], dtype=float32)

sess.run(pop, feed_dict={x:np.array([0.0], dtype=np.float32)})

1.0

sess.run(pop, feed_dict={x:np.array([0.0], dtype=np.float32)})

2.0

sess.run(pop, feed_dict={x:np.array([0.0], dtype=np.float32)})

3.0

tf.control_dependenciesを使って、ポップしたときにキューが更新されるようにしている点が少しトリッキーかもしれません。同じ入力に対して結果が異なることがあり扱いには注意が必要ですが、使い所によっては便利だと思います。tf.FIFOQueueというのもあり、そちらは並列処理での用途が念頭に置かれているようです。

2
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
2