8
13

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

TensorFlowでNumpy配列をバッチサイズずつ取り出す

Last updated at Posted at 2017-11-13

Numpy配列をバッチサイズずつ取り出す

機械学習を行うときなどにNumpy配列をバッチサイズずつ取り出したいシーンはよくあると思います。
Numpyだけで処理することも可能ですが、Data AugmentationにTensorFlowの機能を使うことを考えると、TensorFlowでバッチ化したほうが便利そうなので調べました。
以下のStackoverflowに載っていたのですが、もう少し簡単に書けないかと思って調べた結果を載せます。
https://stackoverflow.com/questions/39068259/passing-a-numpy-array-to-a-tensorflow-queue

Python3.5.2、TensorFlow 1.4で動作確認しました。

データの順序が固定のケース

tf.train.input_placeholderを使ってテンソルをqueueに出力した後、tf.train.batchでバッチ化します。
tf.local_variables_initializer()input_producerが使用するepochのカウンタの初期化に必要です。
またtf.train.start_queue_runners()を読んでqueueを動作させる必要があります。

import numpy as np
import tensorflow as tf

data = np.arange(20).reshape((10, 2))
batch_size = 3
step_size = int(np.ceil(float(len(data)) / batch_size))

with tf.Session() as sess:
    coord = tf.train.Coordinator()
    queue = tf.train.input_producer(data, shuffle=False, num_epochs=1)
    dequeue = queue.dequeue()
    batch = tf.train.batch([dequeue], batch_size=batch_size, allow_smaller_final_batch=True)
    sess.run(tf.local_variables_initializer())
    threads = tf.train.start_queue_runners(sess, coord)
    i = 0
    while i < step_size and not coord.should_stop():
        y = sess.run(batch)
        print(y)
        i += 1
    coord.request_stop()
    coord.join(threads)

実行結果(TensorFlowのメッセージは省略):

[[0 1]
 [2 3]
 [4 5]]
[[ 6  7]
 [ 8  9]
 [10 11]]
[[12 13]
 [14 15]
 [16 17]]
[[18 19]]

データの順序をランダムにするケース

データの順序をランダムにしたい場合はtf.train.shuffle_batchを使います。

epoch_num = 2
step_size = int(np.ceil(float(len(data) * epoch_num) / batch_size))
with tf.Session() as sess:
    coord = tf.train.Coordinator()
    queue = tf.train.input_producer(data, shuffle=False, num_epochs=epoch_num)
    dequeue = queue.dequeue()
    batch = tf.train.shuffle_batch([dequeue], capacity=20, min_after_dequeue=10,
                batch_size=batch_size, allow_smaller_final_batch=True)
    sess.run(tf.local_variables_initializer())
    threads = tf.train.start_queue_runners(sess, coord)
    i = 0
    while i < step_size and not coord.should_stop():
        y = sess.run(batch)
        print(y)
        i += 1
    coord.request_stop()
    coord.join(threads)

実行結果(TensorFlowのメッセージは省略):

[[0 1]
 [2 3]
 [0 1]]
[[18 19]
 [ 8  9]
 [16 17]]
[[16 17]
 [ 4  5]
 [ 2  3]]
[[14 15]
 [12 13]
 [ 4  5]]
[[ 6  7]
 [14 15]
 [ 6  7]]
[[12 13]
 [18 19]
 [10 11]]
[[10 11]
 [ 8  9]]

データの終端を知る

データを最後まで取り出したことを知るには、以下の2つの方法があるようです。

  • データ終端に達してからバッチを取得しようとするとtf.errors.OutOfRangeErrorが発生するので、これをハンドリングする
  • 自分でバッチ数を数える

tf.errors.OutOfRangeErrorをハンドリングする場合は以下のようになります。

with tf.device('/cpu:0'), tf.Session() as sess:
    coord = tf.train.Coordinator()
    queue = tf.train.input_producer(data, shuffle=False, num_epochs=1)
    dequeue = queue.dequeue()
    batch = tf.train.batch([dequeue], batch_size=batch_size, allow_smaller_final_batch=True)
    sess.run(tf.local_variables_initializer())
    threads = tf.train.start_queue_runners(sess, coord)
    try:
        while not coord.should_stop():
            y = sess.run(batch)
            print(y)
    except tf.errors.OutOfRangeError:
        pass
    finally:
        coord.request_stop()
        coord.join(threads)
8
13
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
8
13

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?