TensorFlowによるデータの読み込み

  • 17
    いいね
  • 0
    コメント
この記事は最終更新日から1年以上が経過しています。

TensorFlowにおける、ファイルからのデータ読み取り機構を、一連の流れとして分解してみました。
CIFAR-10のバイナリ画像データをどのように取り込み、Queueを活用し、Sessionでテンソルとしてグラフに流しているのかの参考になるかと思います。

分かったこととしては
・公式のReading dataにある通り、ファイルからの読み取りは7つのステップに沿って行われる
・FilenameQueueをかませるのは、データのシャッフルや複数スレッドでの処理の実行のため
・次の構造が使われている

tf.Graph().as_default()

    sess=tf.Session()
    tf.train.start_queue_runners(sess=sess)

    for i in range():
        sess.run([ .. ])

などです。

また、TensorFlowのReaderクラスを使ってみるはjpeg画像の扱い方の最も重要な部分を解説してくれているので、合わせてご参照ください。

/tmp/cifar10_data/..としてcifar10のデータが保存されているとしています。次のコードを走らせると、画像データがテンソルとして出力されます。
このスクリプトはcifar10 tutorialの大量の関数から、データの読み込みと前処理の基本的な部分を抜き出してきたものです。より多くの処理を施す場合はcifar10_input.pyを参照してください。

tensorflow_experiment3.py

#coding:utf-8
# Cifar10の image file を読み込んでテンソルに変換するまで.
import tensorflow as tf


FLAGS = tf.app.flags.FLAGS
tf.app.flags.DEFINE_integer('max_steps', 1,
                            """Number of batches to run.""")
tf.app.flags.DEFINE_integer('batch_size', 128,
                            """Number of images to process in a batch.""")
NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN = 50000



with tf.Graph().as_default(): 
    # 1. ファイル名のリスト
    filenames = ['/tmp/cifar10_data/cifar-10-batches-bin/data_batch_1.bin',
        '/tmp/cifar10_data/cifar-10-batches-bin/data_batch_2.bin',
        '/tmp/cifar10_data/cifar-10-batches-bin/data_batch_3.bin', 
        '/tmp/cifar10_data/cifar-10-batches-bin/data_batch_4.bin', 
        '/tmp/cifar10_data/cifar-10-batches-bin/data_batch_5.bin']
    # 2. ファイル名のシャッフルはなし
    # 3. epoch limitの設定もなし


    # 4. 「ファイル名リスト」のqueueの作成
    filename_queue = tf.train.string_input_producer(filenames)


    # 5. データのフォーマットにあったreaderの作成
    class CIFAR10Record(object):
        pass
    result = CIFAR10Record()

    label_bytes = 1 
    result.height = 32
    result.width = 32
    result.depth = 3
    image_bytes = result.height * result.width * result.depth
    record_bytes = label_bytes + image_bytes

    reader = tf.FixedLengthRecordReader(record_bytes=record_bytes)

    ##readerにqueueを渡してファイルを開く
    result.key, value = reader.read(filename_queue)


    # 6. readの結果からデータをdecode
    record_bytes = tf.decode_raw(value, tf.uint8)


    # 7. データの整形
    # 7-1. 基本的な整形
    result.label = tf.cast(tf.slice(record_bytes, [0], [label_bytes]), tf.int32)
    depth_major = tf.reshape(tf.slice(record_bytes, [label_bytes], [image_bytes]),
                                [result.depth, result.height, result.width])
    result.uint8image = tf.transpose(depth_major, [1, 2, 0])

    read_input = result
    reshaped_image = tf.cast(read_input.uint8image, tf.float32)
    float_image = reshaped_image

    # 7-2. データのシャッフルの準備
    min_fraction_of_examples_in_queue = 0.4
    min_queue_examples = int(NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN *
                            min_fraction_of_examples_in_queue)
    print ('Filling queue with %d CIFAR images before starting to train. '
            'This will take a few minutes.' % min_queue_examples)

    # 7-3. バッチの作成(シャッフル有り)
    batch_size = FLAGS.batch_size
    num_preprocess_threads = 16
    images, label_batch = tf.train.shuffle_batch(
    [float_image, read_input.label],
        batch_size=batch_size,
        num_threads=num_preprocess_threads,
        capacity=min_queue_examples + 3 * batch_size,
        min_after_dequeue=min_queue_examples)

    images=images
    labels = tf.reshape(label_batch, [batch_size])


    # 8. 実行
    sess = tf.Session()
    tf.train.start_queue_runners(sess=sess)
    for step in xrange(FLAGS.max_steps):
        img_label = sess.run([images, labels])
        print(img_label)
    print("FIN.")