Help us understand the problem. What is going on with this article?

Tensorflow dataのチューニング実験

はじめに

Better performance with the tf.data API
上のページを参考にCNN用のデータローダの設計を行ない、Tensorflowにおけるデータローダの速度のチューニングを行いました。
結論として、いくつかの高速化テクニックを試したのですが、残念ながらベースラインとなる実装より早くすることはできませんでした。

tf.data

Tensorflowにはtf.dataと呼ばれるインプットパイプライン用のAPIが用意されています。
画像ファイルなどといった、RAMに乗り切らないデータをモデルに読み込ませる時、tf.dataを使用すると、内部でデータの前処理とNNの学習が並列で実行されるため高速な処理が実現できます。
大雑把な仕組みは以下の通りです。

pipeline_performance.png

pythonのgeneratorなどで実装すると、CPUやGPUが稼働している間、もう一方はアイドル状態になるため効率が悪いのですが、tf.dataで実装すれば、アイドル状態の時間を短縮できるということになります。

実装はこちらで解説されており、比較的少ない手間で実装できます。

問題設定

ImageNetの画像をjpegで保存した状態でmobilenetに読み込ませます。
実験環境はGoogle Colaboratoryです。

ベースライン実装

はじめにtf.dataの基本的な使い方。
保存した画像のパスを1つずつ読み取り、244×244の画像にランダムにクロップします。

train_img_paths = glob.glob(os.path.join(IMAGE_DIR, '*.jpg'))
train_img_paths.sort()

num_train_imgs = len(train_img_paths)
train_label = [1 for path in train_img_paths]

m = tf.keras.Sequential([
    hub.KerasLayer("https://tfhub.dev/google/tf2-preview/mobilenet_v2/classification/4", output_shape=[1], trainable=True)
])
m.build([None, IMAGE_SIZE[0], IMAGE_SIZE[1], IMAGE_SIZE[2]])
m.compile(loss=tf.keras.losses.BinaryCrossentropy(), optimizer='Adam')

def preprocessing(img_path, label):
  img = tf.image.decode_image(tf.io.read_file(img_path))
  img = tf.image.random_crop(img, size=IMAGE_SIZE)
  img = tf.cast(img, tf.float32)
  img = img / 255.0
  label = tf.cast(label, tf.float32)
  img.set_shape(IMAGE_SIZE)
  return img, label

train_data = tf.data.Dataset.from_tensor_slices((train_img_paths, train_label))
train_data = train_data.shuffle(num_train_imgs).map(preprocessing).repeat().batch(batch_size).prefetch(buffer_size=AUTOTUNE)

time_start = time.time()
m.fit(train_data, epochs=epochs, steps_per_epoch=steps_per_epoch)

time_end = time.time()

print(f'Total time:{(time_end-time_start)/60.0:.3f}[min]')
print(f'Time per step:{(time_end-time_start)/steps_per_epoch*epochs:.3f} [sec]')

結果

Total time:0.446[min]
Time per step:0.803 [sec]

だいたい1ステップあたり0.8秒かかりました。
ここから工夫して学習の高速化を実現しようと思います。

並列マッピング

Datasetオブジェクトのmap関数を並列で動かします。
データの抽出部分をマルチプロセスで処理するので高速になるはず。

ソースコード

前節の

train_data = tf.data.Dataset.from_tensor_slices((train_img_paths, train_label))
train_data = train_data.shuffle(num_train_imgs).map(preprocessing).repeat().batch(batch_size).prefetch(buffer_size=AUTOTUNE)

の部分を以下のように書き換えます。

train_data = tf.data.Dataset.from_tensor_slices((train_img_paths, train_label))
train_data = train_data.shuffle(num_train_imgs).repeat().map(preprocessing, num_parallel_calls=AUTOTUNE).batch(batch_size).prefetch(buffer_size=AUTOTUNE)

結果

Total time:3.726[min]
Time per step:6.707 [sec]

なぜか遅くなってしまいました。
Google Colaboratoryの仕様でしょうか?(要調査)

キャッシング

ソースコード

キャッシングとは読み込んだデータをRAMやストレージなどに一時的に保持しておく機能です。

train_data = tf.data.Dataset.from_tensor_slices((train_img_paths, train_label))
train_data = train_data.shuffle(num_train_imgs).repeat().map(preprocessing, num_parallel_calls=AUTOTUNE).batch(batch_size).cache()

結果

Total time:7.014[min]
Time per step:12.625 [sec]

またもや、高速化することはできませんでした。
原因としてはmap関数上では画像の読み取りと画像データの変換を一度に実行する仕様になっていることが悪いのかなと思います。
画像の読み取りと画像データの変換を分離してやる構造が必要ですね。(今後の課題)

ベクトル化マッピング

ユーザ定義のmap関数だと、処理の都合上オーバーヘッドが発生するそうです。
そこでユーザ定義のmap関数をベクトル化、すなわち入力を一度に処理するように変更するとより早くなるそうです。
具体的にはデータの変換→バッチ処理の順番でなく、バッチ処理→データの変換で実装することが推奨されています。

時間の都合上まだ実験できていませんが、記事の冒頭で示したURLでの実験では、最大で30倍の速さになっています。

まとめ

map関数の並列化、キャッシングを実験してみましたが、いずれも高速化には繋がりませんでした。

原因は複数存在すると思うので、今後調査が必要ですね。
アドバイスなどありましたら、教えていただけけると嬉しいです。

Why not register and get more from Qiita?
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away
Comments
Sign up for free and join this conversation.
If you already have a Qiita account
Why do not you register as a user and use Qiita more conveniently?
You need to log in to use this function. Qiita can be used more conveniently after logging in.
You seem to be reading articles frequently this month. Qiita can be used more conveniently after logging in.
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away