TensorFlowにおける、ファイルからのデータ読み取り機構を、一連の流れとして分解してみました。
CIFAR-10のバイナリ画像データをどのように取り込み、Queueを活用し、Sessionでテンソルとしてグラフに流しているのかの参考になるかと思います。
分かったこととしては
・公式のReading dataにある通り、ファイルからの読み取りは7つのステップに沿って行われる
・FilenameQueueをかませるのは、データのシャッフルや複数スレッドでの処理の実行のため
・次の構造が使われている
tf.Graph().as_default()
sess=tf.Session()
tf.train.start_queue_runners(sess=sess)
for i in range():
sess.run([ .. ])
```
などです。
また、[TensorFlowのReaderクラスを使ってみる](http://qiita.com/knok/items/2dd15189cbca5f9890c5)はjpeg画像の扱い方の最も重要な部分を解説してくれているので、合わせてご参照ください。
/tmp/cifar10_data/..としてcifar10のデータが保存されているとしています。次のコードを走らせると、画像データがテンソルとして出力されます。
このスクリプトはcifar10 tutorialの大量の関数から、データの読み込みと前処理の基本的な部分を抜き出してきたものです。より多くの処理を施す場合はcifar10_input.pyを参照してください。
```tensorflow_experiment3.py
#coding:utf-8
# Cifar10の image file を読み込んでテンソルに変換するまで.
import tensorflow as tf
FLAGS = tf.app.flags.FLAGS
tf.app.flags.DEFINE_integer('max_steps', 1,
"""Number of batches to run.""")
tf.app.flags.DEFINE_integer('batch_size', 128,
"""Number of images to process in a batch.""")
NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN = 50000
with tf.Graph().as_default():
# 1. ファイル名のリスト
filenames = ['/tmp/cifar10_data/cifar-10-batches-bin/data_batch_1.bin',
'/tmp/cifar10_data/cifar-10-batches-bin/data_batch_2.bin',
'/tmp/cifar10_data/cifar-10-batches-bin/data_batch_3.bin',
'/tmp/cifar10_data/cifar-10-batches-bin/data_batch_4.bin',
'/tmp/cifar10_data/cifar-10-batches-bin/data_batch_5.bin']
# 2. ファイル名のシャッフルはなし
# 3. epoch limitの設定もなし
# 4. 「ファイル名リスト」のqueueの作成
filename_queue = tf.train.string_input_producer(filenames)
# 5. データのフォーマットにあったreaderの作成
class CIFAR10Record(object):
pass
result = CIFAR10Record()
label_bytes = 1
result.height = 32
result.width = 32
result.depth = 3
image_bytes = result.height * result.width * result.depth
record_bytes = label_bytes + image_bytes
reader = tf.FixedLengthRecordReader(record_bytes=record_bytes)
##readerにqueueを渡してファイルを開く
result.key, value = reader.read(filename_queue)
# 6. readの結果からデータをdecode
record_bytes = tf.decode_raw(value, tf.uint8)
# 7. データの整形
# 7-1. 基本的な整形
result.label = tf.cast(tf.slice(record_bytes, [0], [label_bytes]), tf.int32)
depth_major = tf.reshape(tf.slice(record_bytes, [label_bytes], [image_bytes]),
[result.depth, result.height, result.width])
result.uint8image = tf.transpose(depth_major, [1, 2, 0])
read_input = result
reshaped_image = tf.cast(read_input.uint8image, tf.float32)
float_image = reshaped_image
# 7-2. データのシャッフルの準備
min_fraction_of_examples_in_queue = 0.4
min_queue_examples = int(NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN *
min_fraction_of_examples_in_queue)
print ('Filling queue with %d CIFAR images before starting to train. '
'This will take a few minutes.' % min_queue_examples)
# 7-3. バッチの作成(シャッフル有り)
batch_size = FLAGS.batch_size
num_preprocess_threads = 16
images, label_batch = tf.train.shuffle_batch(
[float_image, read_input.label],
batch_size=batch_size,
num_threads=num_preprocess_threads,
capacity=min_queue_examples + 3 * batch_size,
min_after_dequeue=min_queue_examples)
images=images
labels = tf.reshape(label_batch, [batch_size])
# 8. 実行
sess = tf.Session()
tf.train.start_queue_runners(sess=sess)
for step in xrange(FLAGS.max_steps):
img_label = sess.run([images, labels])
print(img_label)
print("FIN.")
```