TensorFlowのtf.Variable
型は主にネットワークの重みやバイアスを保持しておくために利用されますが、これを使ってデータ構造のキュー(Queue, 待ち行列)を作ることもできます。このような使い方があることは、PixelCNNの生成過程を高速化したRamachandran氏のFast PixelCNN++のコードを読んでいて知りました。変数は初期化かOptimizerによる更新以外では値を変えるものではないと思っていた僕には目から鱗でした。
以下が、自分で試してみたコードになります。TensorFlowはバージョン1.5を使いました。delay
関数で長さ3のキューと、それの先頭から要素をポップするオペレータを定義しています。
import tensorflow as tf
import numpy as np
def delay(x, n=3):
queue = tf.Variable(initial_value=np.zeros(n, dtype=np.float32), name='queue')
head, tail = queue[0], queue[1:]
pushed = tf.concat([tail, x], 0)
assigned = queue.assign(pushed)
with tf.control_dependencies([assigned]):
return assigned, tf.identity(head)
x = tf.placeholder(name='input', shape=(1), dtype=tf.float32)
queue, pop = delay(x)
sess = tf.Session()
sess.run(tf.global_variables_initializer())
sess.run(queue, feed_dict={x: np.array([1.0], dtype=np.float32)})
array([ 0., 0., 1.], dtype=float32)
sess.run(queue, feed_dict={x: np.array([2.0], dtype=np.float32)})
array([ 0., 1., 2.], dtype=float32)
sess.run(queue, feed_dict={x: np.array([3.0], dtype=np.float32)})
array([ 1., 2., 3.], dtype=float32)
sess.run(pop, feed_dict={x:np.array([0.0], dtype=np.float32)})
1.0
sess.run(pop, feed_dict={x:np.array([0.0], dtype=np.float32)})
2.0
sess.run(pop, feed_dict={x:np.array([0.0], dtype=np.float32)})
3.0
tf.control_dependencies
を使って、ポップしたときにキューが更新されるようにしている点が少しトリッキーかもしれません。同じ入力に対して結果が異なることがあり扱いには注意が必要ですが、使い所によっては便利だと思います。tf.FIFOQueue
というのもあり、そちらは並列処理での用途が念頭に置かれているようです。