79
57

More than 5 years have passed since last update.

tf.dataを完全に理解してイケてるデータローダを作るつもりだった

Last updated at Posted at 2018-12-06

はじめに

ここ1年くらいTensorFlowはtf.dataを強く押し出していて,そろそろ実務で使えるようになりたい.「過去のコードの置き換えられる柔軟でイケてるデータローダを作ってやるぜ!」…と思ってましたが時間と能力が枯渇したのでドキュメントさらうだけになりそう.すいません.

TL;DR

  • tf.dataの解説
  • データローダの作成 ※画像です

前提知識

TensorFlowの実行スタイル

皆さんご存知かと思いますが,TensorFlowの実行スタイルは「Define and Run」と言われます.メモリ上に静的なデータフローグラフを定義した後,そのグラフに対してデータを流すというスタイルです.静的グラフを定義するため分散実行性能が高く,速度的にメリットがあります.

Feeding

グラフにデータを流し込む伝統的な方法にFeedingがあります.擬似コードは下記のようになります.Feedingでは,データフローグラフの入力部分にtf.placeholderノードを配し,そこに向かってtf.Session()オブジェクトを介して実データ(主にnp.ndarray)を投げつけていく方法です.方法そのものはかなり分かりやすく,「Define and Run」を説明するのに最も適してると思います.

feeding_example.py
#----- グラフ定義 -----#
x = tf.placeholder(tf.float32, shape=(None, None, None, 3))
y = smt(x) # 任意のグラフ定義

#----- データ入力 -----#
with tf.Srssion() as sess:
    x_batch = DataLoader().get_next() # np.ndarrayを返してくる
    print(sess.run(y, feed_dict={x: x_batch})) # xに向かってx_batchを流して,yを出力

Feedingの欠点

しかし,FeedingにはTensorFlowの長所を打ち消すような特徴があります.Dataloader.get_next()の部分,ここは基本的にNumPyで記述され,グラフの外で処理されます.せっかくの分散実行による高速化が,Python命令ベースのデータ読込がボトルネックになってしまいます.
そういうわけで,現在Feedingは中の人から非推奨にされています.とはいってもFeedingを使ったExampleは生きてますし,既存のモデルは相当数こっちを使っているので,そんなに敵視するほどではないです.

tf.data

tf.dataは,TensorFlowが提供する入力パイプライン設計用のモジュールです.FeedingやQueue Runner(TFRecordsのExampleでよく見る)を置き換えて,こちらを使うのが今の推奨です.tf.data.Datasetを用いて簡単にデータ入力パイプラインが作れるらしい.色々種類があるので,使えそうな範囲でガイドを噛み砕きます.

メインコンセプト

データローダごとグラフ上に定義して,ボトルネックを解決する.(※個人の意見です)

データセットからデータを取得する例を下記に示します.

main_concept.py
data = ... # 何かしらのdataリスト取得
dataset = tf.data.Dataset.from_tensor_slices(data) #データセット作成
dataset = dataset.batch(1) # ミニバッチ化
iterator = dataset.make_one_shot_iterator() # イテレータ作成
batch = iterator.get_next() # 次のバッチを取得

with tf.Srssion() as sess:
    x = sess.run(batch)
    print(type(x)) # => np.ndarray
    print(x.shape) # => (batch_size, height, width, channel)

datasetiteratorbatchはそれぞれノードとしてグラフ上に定義されていて,変数そのものに実データは格納されていません.実データは,tf.Session()を通してbatchを実行することで初めて取得できます.学習モデルを構築する際には,tf.placeholderbatchに置き換えるだけで概ね動きます.sess.run()ごとに勝手にデータが流れるので,推論時にfeed_dictの記述がいらず,コードもスッキリします.これが地味に嬉しい.

基本的な使い方

下記はファイル名から画像を読み込んで,ミニバッチを返す例です.

Dataset

Datasetの作成

filenames = glob.glob(DATA_DIR) # filenameのリスト取得 (CSVとかから読んでも良い)
dataset = tf.data.Dataset.from_tensor_slices(filenames) # filenamesからデータセット作成

filenamesはただのファイル名のリストで,globで取得するも良し,CSVから読み込むも良し,なんでもいいです.filenamestf.data.Dataset.from_tensor_slices()に突っ込むだけでDatasetオブジェクトを作ってくれます.他にも,対応したラベルとかがあるんだったら,次みたいにタプルにして渡してあげればよしなにやってくれます.SparseTensorもいけるらしい.

dataset = tf.data.Dataset.from_tensor_slices((filenames, labels))

データの読み込み,前処理

def _parse_fn(filename):
    image = tf.image.decode_jpeg(tf.read_file(filename)) # ファイル名 => 画像
    image = tf.random_crop(image, (height, width, channel)) # Random Crop
    image = ... # 他にも色々
    return tf.cast(image, tf.float32)

dataset = dataset.map(_parse_fn)

先ほどのfilenamesを画像データとして読み込んだり,前処理をするにはDataset.map()を利用します.一般的な高階関数と使い方は同じで,_parse_fnの中身は最上位次元を削減した形で記述するので,一定の注意は必要です.ここで使うためにtf.imageの関数が3次元テンソルしか受け付けない様に設計されているんだと妙に腑に落ちました.

ミニバッチ化,リピート,シャッフル

dataset = dataset.batch(10)
dataset = dataset.repeat(1)
dataset = dataset.shuffle(1000)

一気に行きます.

  • ミニバッチ化
    • Dataset.batch()を使用
    • 指定したサイズのミニバッチに分割してくれる
    • 割り切れない場合,最後が(データ数) %(バッチサイズ)の長さになる
      • drop_remainder=Trueで切り捨てることもできる
  • リピート
    • Dataset.repeat()を使用
    • 指定した回数データセットをリピート(指定しないと無限リピート)
    • 使うメリットはあんまりわからない
      • Validationに確率的な前処理(Random Cropとか)が必要な場合に,複数回繰り返せるくらい(?)
  • シャッフル
    • Dataset.shuffle()を使用
    • 指定したバッファサイズによってランダムさが決まるらしい
      • 過学習したくなかったら大きめに設定しておくと吉

これでDatasetオブジェクトについては終わりです.

Iterator

tf.data.Iteratorオブジェクトは,データセットをget_next()を用いてミニバッチ単位でデータを取得してきてくれる,まぁ普通のイテレータです.末尾のミニバッチを読みだした後にミニバッチを取得しようとすると,tf.errors.OutOfRangeErrorを吐きます.

Iteratorの作成

tf.data.Iteratorオブジェクトにはいくつか種類があって,上以下で1つずつ見て行きます.

One Shot Iterator
iterator = dataset.make_one_shot_iterator()
batch = iterator.get_next()

1度きりしか回せないイテレータ.testデータみたいに1回しか回さないデータに対して使います.実行するにあたって特にやらなきゃいけないこともないので簡単なデバッグにも使えます.Dataset.repeat()を使用して擬似的に複数回ループ回せますが,オススメはしないですね.

Initializable Iterator
iterator = dataset.make_initializable_iterator()
init_op = iterator.initializer
batch = iterator.get_next() # 次のバッチを取得

Initialize + able,つまり初期化できるイテレータです.初期化することで複数回ループを回せます.学習時はこちらが使えます.ループ回す前に必ずinit_opを実行して初期化が必要で,しないとtf.errors.FailedPreconditionErrorを吐きます.

ほんとかよって思ったんですが,末尾まで行ってOutOfRangeErrorをトリガにするらしくて,学習ループを回す際は下記のような実行方法になります.tf.train.MonitoredSession使う手もあるらしいけど,Reinitializableの方で使いづらそうなのでとりあえず保留.デベロッパやSessionが上から監視するんじゃなくてIterator自体にやって欲しいなっていう不満を垂れておきます.

train_op = tf.train.Optimizer().minimize(loss)

with tf.Srssion() as sess:
    for epoch in range(NUM_EPOCH):
        sess.run(init_op) # 毎エポックの頭に初期化
        while True: # tf.errors.OutOfRangeErrorを吐くまでループ
            try:
                sess.run(train_op)
            except tf.errors.OutOfRangeError: # 末尾まで行ったらループを抜ける
                break
Reinitializable Iterator

正直ネーミングがよくわからんですが,複数のinit_opを用いてデータセットを切り替えできるイテレータです.Switchable Iteratorとかじゃいけないのかなと,拙い英語しか喋れないながら思います.学習時にtrain/valを切り替えながら使う時ことを考えると,これを使うのが一番いいように感じます.Initializable Iteratorを用いる場合,weight-sharingして別のグラフを作らなきゃいけない?ともかく,上記2つの定義方法とは違う方法で作成します.

train_set = tf.data.Dataset.from_tensor_slices(...) # trainデータ
val_set = tf.data.Dataset.from_tensor_slices(...) # validationデータ

# データセットの構造からiterator作成
iterator = tf.data.Iterator.from_structure(train_set.output_types,
                                           train_set.output_shapes)
batch = iterator.net_next() # 共通のbatch

train_init_op = iterator.make_initializer(train_set) # train_setに切り替えるinit_op
val_init_op = iterator.make_initializer(val_set) # val_setに切り替えるinit_op

...

# 実行方法もまとめて示します. 例のごとく OutOfRangeError がトリガ.
with tf.Srssion() as sess:
    for epoch in range(NUM_EPOCH):
        sess.run(train_init_op) # train_setを使って初期化
        while True:
            try:
                sess.run(train_op)
            except tf.errors.OutOfRangeError:
                break

        sess.run(val_init_op) # val_setを使って初期化
        while True:
            try:
                sess.run(val_op)
            except tf.errors.OutOfRangeError:
                break
Feedable Iterator

handle経由で初期化せずに途中で切り替え出来るらしい.どこかの記事で使い道がわからんって言われてた.疲れたので割愛します.

データローダの作成

tf.dataについてだらだら解説してきましたが,ここではモデルの部分はいじらずそのままFeedingの実装にぶち込めるようなDatasetクラスを作ります.できるだけAugmentationも色々入れ込みたい.普段pix2pi的なモデルを触ってるので,それ用です.NYU Depthみたいに画像ペアが拡張子違いで同じファイル名にデータが保存されている想定です.train-test-splitは元データをディレクトリに分けてtf.data.Datasetオブジェクトを分けてくれれば簡単にできます.ペアワイズされてるデータセットは基本的にサイズが揃ってるので.

[root]
  |-[data]
  |   |-IMAGE.jpg
  |   |-IMAGE.png
  |-dataset.py
  |-main.py
dataset.py
import tensorflow as tf
from tensorflow.image import ResizeMethod
from tensorflow.contrib.image import rotate

from os import path
from glob import glob
from math import pi

class Dataset():
    def __init__(self, config):
        # store command-line argumments
        self.data_dir = config.data_dir
        self.batch_size = config.batch_size
        self.image_size = (config.image_size,) * 2
        self.input_depth = config.input_depth
        self.output_depth = config.output_depth
        self.total_depth = self.input_depth + self.output_depth

        self._build_pipline()

    def _build_pipline(self):
        src_paths = glob(path.join(self.data_dir, '*.jpg'))
        tgt_paths = glob(path.join(self.data_dir, '*.png'))

        dataset = tf.data.Dataset.from_tensor_slices((src_paths, tgt_paths))
        dataset = dataset.map(self._imread_fn)
        dataset = dataset.map(self._augment_fn)
        # dataset = dataset.repeat()
        dataset = dataset.shuffle(1000)
        dataset = dataset.batch(self.batch_size)

        iterator = dataset.make_initializable_iterator()
        self.initializer = iterator.initializer
        self.batch = iterator.get_next()

    def _imread_fn(self, src_path, tgt_path):
        src_img = tf.image.decode_jpeg(tf.read_file(src_path))
        tgt_img = tf.image.decode_png(tf.read_file(tgt_path))
        return tf.cast(src_img, tf.float32), tf.cast(tgt_img, tf.float32)

    def _augment_fn(self, src_img, tgt_img):
        # augmentation for source image
        src_img = tf.image.random_brightness(src_img, max_delta=0.5)
        src_img = tf.image.random_contrast(src_img, lower=0.2, upper=1.8)
        src_img = src_img + tf.random_normal(tf.shape(src_img), stddev=8)

        # concatenate images
        img = tf.concat([src_img, tgt_img], axis=-1)

        # resize and crop
        height, width = tf.unstack(tf.shape(src_img)[:-1])
        sf = tf.random_uniform((), minval=0.8, maxval=1.2)
        size = tf.cast(tf.shape(img)[:-1], tf.float32) * sf
        img = tf.image.resize_images(img, tf.cast(size, tf.int32), ResizeMethod.BICUBIC)
        img = tf.image.resize_image_with_crop_or_pad(img, height, width)
        img = tf.random_crop(img, (*self.image_size, self.total_depth))

        # random flip
        img = tf.image.random_flip_left_right(img)
        img = tf.image.random_flip_up_down(img)

        # random rotate
        img = rotate(img, tf.random_uniform((), minval=-1 / 4, maxval=1 / 4) * pi)

        src_img = img[..., :self.input_depth] / 255
        tgt_img = img[..., -self.output_depth:] / 255

        return src_img, tgt_img

mainからこれを叩いてみる.

from dataset import Dataset

import argparse
parser = argparse.ArgumentParser()
parser.add_argument('--data_dir', type=str, default='./data/NYUD')
parser.add_argument('--batch_size', type=int, default=1)
parser.add_argument('--image_size', type=int, default=240)
parser.add_argument('--input_depth', type=int, default=3)
parser.add_argument('--output_depth', type=int, default=1)
config = parser.parse_args()

if __name__ == '__main__':
    tf.reset_default_graph()
    dataset = Dataset(config)

    with tf.Session() as sess:
        for epoch in range(5):
            sess.run(dataset.initializer)
            while True:
                try:
                    x, y = sess.run(dataset.batch)
                    x[x > 1], x[x < 0] = 1., 0.
                    y[y > 1], y[y < 0] = 1., 0.
                    cv2.imwrite('IMAGE_{}_src.jpg'.format(epoch), np.squeeze(x * 255)[:, :, ::-1])
                    cv2.imwrite('IMAGE_{}_tgt.jpg'.format(epoch), np.squeeze(y * 255)[:, :])
                except tf.errors.OutOfRangeError:
                    print('finished!')
                    break

出力結果がこちら.いい感じ.
IMAGE_0_src.jpg IMAGE_0_tgt.jpg
IMAGE_1_src.jpg IMAGE_1_tgt.jpg
IMAGE_2_src.jpg IMAGE_2_tgt.jpg
これならdataset.batchを既存のモデルクラスに組み込むだけで済みそう.

まとめ

  • tf.dataをちょっぴり理解した
  • TensorFlowに慣れてるとかなり簡単に書けて,テンソルで記述するので柔軟性もあるように感じる.
  • _parse_fn頑張れ
79
57
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
79
57