Numpy配列をバッチサイズずつ取り出す
機械学習を行うときなどにNumpy配列をバッチサイズずつ取り出したいシーンはよくあると思います。
Numpyだけで処理することも可能ですが、Data AugmentationにTensorFlowの機能を使うことを考えると、TensorFlowでバッチ化したほうが便利そうなので調べました。
以下のStackoverflowに載っていたのですが、もう少し簡単に書けないかと思って調べた結果を載せます。
https://stackoverflow.com/questions/39068259/passing-a-numpy-array-to-a-tensorflow-queue
Python3.5.2、TensorFlow 1.4で動作確認しました。
データの順序が固定のケース
tf.train.input_placeholder
を使ってテンソルをqueueに出力した後、tf.train.batch
でバッチ化します。
tf.local_variables_initializer()
はinput_producer
が使用するepochのカウンタの初期化に必要です。
またtf.train.start_queue_runners()
を読んでqueueを動作させる必要があります。
import numpy as np
import tensorflow as tf
data = np.arange(20).reshape((10, 2))
batch_size = 3
step_size = int(np.ceil(float(len(data)) / batch_size))
with tf.Session() as sess:
coord = tf.train.Coordinator()
queue = tf.train.input_producer(data, shuffle=False, num_epochs=1)
dequeue = queue.dequeue()
batch = tf.train.batch([dequeue], batch_size=batch_size, allow_smaller_final_batch=True)
sess.run(tf.local_variables_initializer())
threads = tf.train.start_queue_runners(sess, coord)
i = 0
while i < step_size and not coord.should_stop():
y = sess.run(batch)
print(y)
i += 1
coord.request_stop()
coord.join(threads)
実行結果(TensorFlowのメッセージは省略):
[[0 1]
[2 3]
[4 5]]
[[ 6 7]
[ 8 9]
[10 11]]
[[12 13]
[14 15]
[16 17]]
[[18 19]]
データの順序をランダムにするケース
データの順序をランダムにしたい場合はtf.train.shuffle_batch
を使います。
epoch_num = 2
step_size = int(np.ceil(float(len(data) * epoch_num) / batch_size))
with tf.Session() as sess:
coord = tf.train.Coordinator()
queue = tf.train.input_producer(data, shuffle=False, num_epochs=epoch_num)
dequeue = queue.dequeue()
batch = tf.train.shuffle_batch([dequeue], capacity=20, min_after_dequeue=10,
batch_size=batch_size, allow_smaller_final_batch=True)
sess.run(tf.local_variables_initializer())
threads = tf.train.start_queue_runners(sess, coord)
i = 0
while i < step_size and not coord.should_stop():
y = sess.run(batch)
print(y)
i += 1
coord.request_stop()
coord.join(threads)
実行結果(TensorFlowのメッセージは省略):
[[0 1]
[2 3]
[0 1]]
[[18 19]
[ 8 9]
[16 17]]
[[16 17]
[ 4 5]
[ 2 3]]
[[14 15]
[12 13]
[ 4 5]]
[[ 6 7]
[14 15]
[ 6 7]]
[[12 13]
[18 19]
[10 11]]
[[10 11]
[ 8 9]]
データの終端を知る
データを最後まで取り出したことを知るには、以下の2つの方法があるようです。
- データ終端に達してからバッチを取得しようとすると
tf.errors.OutOfRangeError
が発生するので、これをハンドリングする - 自分でバッチ数を数える
tf.errors.OutOfRangeError
をハンドリングする場合は以下のようになります。
with tf.device('/cpu:0'), tf.Session() as sess:
coord = tf.train.Coordinator()
queue = tf.train.input_producer(data, shuffle=False, num_epochs=1)
dequeue = queue.dequeue()
batch = tf.train.batch([dequeue], batch_size=batch_size, allow_smaller_final_batch=True)
sess.run(tf.local_variables_initializer())
threads = tf.train.start_queue_runners(sess, coord)
try:
while not coord.should_stop():
y = sess.run(batch)
print(y)
except tf.errors.OutOfRangeError:
pass
finally:
coord.request_stop()
coord.join(threads)