Python
TensorFlow

TensorFlowでNumpy配列をバッチサイズずつ取り出す

More than 1 year has passed since last update.

Numpy配列をバッチサイズずつ取り出す

機械学習を行うときなどにNumpy配列をバッチサイズずつ取り出したいシーンはよくあると思います。
Numpyだけで処理することも可能ですが、Data AugmentationにTensorFlowの機能を使うことを考えると、TensorFlowでバッチ化したほうが便利そうなので調べました。
以下のStackoverflowに載っていたのですが、もう少し簡単に書けないかと思って調べた結果を載せます。
https://stackoverflow.com/questions/39068259/passing-a-numpy-array-to-a-tensorflow-queue

Python3.5.2、TensorFlow 1.4で動作確認しました。

データの順序が固定のケース

tf.train.input_placeholderを使ってテンソルをqueueに出力した後、tf.train.batchでバッチ化します。
tf.local_variables_initializer()input_producerが使用するepochのカウンタの初期化に必要です。
またtf.train.start_queue_runners()を読んでqueueを動作させる必要があります。

import numpy as np
import tensorflow as tf

data = np.arange(20).reshape((10, 2))
batch_size = 3
step_size = int(np.ceil(float(len(data)) / batch_size))

with tf.Session() as sess:
    coord = tf.train.Coordinator()
    queue = tf.train.input_producer(data, shuffle=False, num_epochs=1)
    dequeue = queue.dequeue()
    batch = tf.train.batch([dequeue], batch_size=batch_size, allow_smaller_final_batch=True)
    sess.run(tf.local_variables_initializer())
    threads = tf.train.start_queue_runners(sess, coord)
    i = 0
    while i < step_size and not coord.should_stop():
        y = sess.run(batch)
        print(y)
        i += 1
    coord.request_stop()
    coord.join(threads)

実行結果(TensorFlowのメッセージは省略):

[[0 1]
 [2 3]
 [4 5]]
[[ 6  7]
 [ 8  9]
 [10 11]]
[[12 13]
 [14 15]
 [16 17]]
[[18 19]]

データの順序をランダムにするケース

データの順序をランダムにしたい場合はtf.train.shuffle_batchを使います。

epoch_num = 2
step_size = int(np.ceil(float(len(data) * epoch_num) / batch_size))
with tf.Session() as sess:
    coord = tf.train.Coordinator()
    queue = tf.train.input_producer(data, shuffle=False, num_epochs=epoch_num)
    dequeue = queue.dequeue()
    batch = tf.train.shuffle_batch([dequeue], capacity=20, min_after_dequeue=10,
                batch_size=batch_size, allow_smaller_final_batch=True)
    sess.run(tf.local_variables_initializer())
    threads = tf.train.start_queue_runners(sess, coord)
    i = 0
    while i < step_size and not coord.should_stop():
        y = sess.run(batch)
        print(y)
        i += 1
    coord.request_stop()
    coord.join(threads)

実行結果(TensorFlowのメッセージは省略):

[[0 1]
 [2 3]
 [0 1]]
[[18 19]
 [ 8  9]
 [16 17]]
[[16 17]
 [ 4  5]
 [ 2  3]]
[[14 15]
 [12 13]
 [ 4  5]]
[[ 6  7]
 [14 15]
 [ 6  7]]
[[12 13]
 [18 19]
 [10 11]]
[[10 11]
 [ 8  9]]

データの終端を知る

データを最後まで取り出したことを知るには、以下の2つの方法があるようです。
* データ終端に達してからバッチを取得しようとするとtf.errors.OutOfRangeErrorが発生するので、これをハンドリングする
* 自分でバッチ数を数える

tf.errors.OutOfRangeErrorをハンドリングする場合は以下のようになります。

with tf.device('/cpu:0'), tf.Session() as sess:
    coord = tf.train.Coordinator()
    queue = tf.train.input_producer(data, shuffle=False, num_epochs=1)
    dequeue = queue.dequeue()
    batch = tf.train.batch([dequeue], batch_size=batch_size, allow_smaller_final_batch=True)
    sess.run(tf.local_variables_initializer())
    threads = tf.train.start_queue_runners(sess, coord)
    try:
        while not coord.should_stop():
            y = sess.run(batch)
            print(y)
    except tf.errors.OutOfRangeError:
        pass
    finally:
        coord.request_stop()
        coord.join(threads)