TensorFlowのThreadingとQueueを使うと,学習データの読み込み処理を並列に実行させることができる.ランダムな2次元マトリックスを入力データとし,それをバッチにまとめて返す処理を作ってみた.
queue.py
import tensorflow as tf
import numpy as np
import threading
# Script-wide settings.
data_shape = [2, 3]  # shape of one input matrix
batch_size = 10  # samples per dequeued batch; also used as the queue capacity
num_threads = 1  # number of loader threads started
epoch = 10  # number of batches the main loop evaluates
def get_np_data(shape=None):
    """Return a random float32 array to be used as one input sample.

    Args:
        shape: Sequence of ints giving the shape of the sample. When
            omitted, the module-level ``data_shape`` is used, which is the
            behaviour of the original zero-argument call.

    Returns:
        A ``numpy.ndarray`` of dtype float32 with the requested shape,
        filled with uniform random values drawn by ``np.random.rand``.
    """
    if shape is None:
        shape = data_shape
    # np.random.rand() returns a plain Python float when called without
    # dimensions; np.asarray keeps the result an ndarray in that case too.
    return np.asarray(np.random.rand(*shape), dtype=np.float32)
def load_and_enqueue(sess, enqueue_op, coord):
    """Feed random samples into the queue until the coordinator stops.

    Intended to be used as the target of a loader thread. Each iteration
    generates one sample with ``get_np_data`` and runs ``enqueue_op``,
    feeding the sample through the module-level placeholder ``plf``.

    Args:
        sess: Session used to run the enqueue op.
        enqueue_op: Op that enqueues the value fed to ``plf``.
        coord: Coordinator whose ``should_stop()`` ends the loop.
    """
    while not coord.should_stop():
        np_data = get_np_data()
        try:
            sess.run(enqueue_op, feed_dict={plf: np_data})
        except tf.errors.CancelledError:
            # The enqueue was cancelled because the queue was closed
            # (normal during shutdown); leave quietly instead of letting
            # the exception escape the thread.
            break
# Placeholder through which a loader thread feeds a single sample.
plf = tf.placeholder(dtype=tf.float32, shape=data_shape)

# FIFO queue of float32 samples; capacity equals one batch.
queue = tf.FIFOQueue(
    capacity=batch_size,
    dtypes=[tf.float32],
    shapes=[data_shape],
)
enqueue_op = queue.enqueue([plf])
# Takes batch_size samples off the queue in a single op.
batch = queue.dequeue_many(batch_size)

# Computation evaluated by the main loop: the batch multiplied by 10.
y = batch * 10
# Session shared by the main loop and the loader threads.
sess = tf.InteractiveSession()

# Coordinator used to signal the loader threads to stop and to join them.
coord = tf.train.Coordinator()

# Launch the loader threads and hand each one over to the coordinator.
for _ in range(num_threads):
    thread = threading.Thread(
        target=load_and_enqueue,
        kwargs={"sess": sess, "enqueue_op": enqueue_op, "coord": coord},
    )
    thread.start()
    coord.register_thread(thread)
# Op that closes the queue and cancels enqueues still waiting for free space.
close_op = queue.close(cancel_pending_enqueues=True)

# Evaluate y once per epoch and print the shape of each result.
# range() and print(...) behave the same on Python 2 and Python 3 here.
try:
    for _ in range(epoch):
        print(sess.run(y).shape)
finally:
    # Ask the loader threads to stop.
    coord.request_stop()
    # A loader blocked on a full queue never re-checks should_stop();
    # closing the queue cancels that pending enqueue so the thread can end.
    sess.run(close_op)
    coord.join()
    # Release the resources held by the session.
    sess.close()
実行結果は
$ python queue.py
(10, 2, 3)
(10, 2, 3)
(10, 2, 3)
(10, 2, 3)
(10, 2, 3)
(10, 2, 3)
(10, 2, 3)
(10, 2, 3)
(10, 2, 3)
(10, 2, 3)
のようになる.