LoginSignup
10
8

More than 3 years have passed since last update.

時間のない方むけ

学習データセットの設定は下記が良い。

ds = ds.shuffle(count)
ds = ds.batch(batch_size, drop_remainder=True)
ds = ds.prefetch(buffer_size=AUTOTUNE)

テスト用はshuffleをぬいて、drop_remainderをFalseにする。
(drop_remainderはpytorchで言えばdrop_lastと同じものです)

やりたいこと

大抵のニューラルネットワークの学習時には1エポックで、下記のようにバッチを取得して学習したいのではないでしょうか。下記は全20データを3バッチずつとってくるイメージ。これをtensorflowのdsではどうやったらできるか?というのが今回の目的。2021年に書くことじゃないだろ、と思うんですが、最近知ったこともあるので備忘録として・・
image.png

この記事が機能の詳細は、かなり細かく載せていますので、参考に。
TensorFlowで使えるデータセット機能が強かった話

Tutorialに従ってみる

Flowerのデータセットを例にとったTutorialでのデータセットのコードを見てみましょう。
このshuffle_and_repeatというやつが、一体なにをしてくれるんでしょうか。

ds = image_label_ds.apply(
  tf.data.experimental.shuffle_and_repeat(buffer_size=image_count))
ds = ds.batch(BATCH_SIZE)
ds = ds.prefetch(buffer_size=AUTOTUNE)

具体的に100データに対して、batch_size=16でこれを動かしてみます。
(無限ループになるので10ループでbreakしています)

import tensorflow as tf

AUTOTUNE = tf.data.experimental.AUTOTUNE

ds = tf.data.Dataset.from_tensor_slices(tf.range(100))

ds2 = ds.apply(tf.data.experimental.shuffle_and_repeat(buffer_size=100))
ds2 = ds2.batch(16)
ds2 = ds2.prefetch(buffer_size=AUTOTUNE)
for i, data in enumerate(ds2):
    print(i)
    print(data)

    if i > 10:
        break

あっという間に同じデータが繰り返されているのがわかります。
tensorflowの公式tutorialt通りにやると、挙動が欲しいものと違う気がします。なんでtensorflowはこの方法を一番シンプルな公式tutorialに載せてるんだろう。と少しだけ感じますが、考察は下の方に書いています。

0
tf.Tensor([49 46 72 99  1 17 78 71 89 97 48 12 55 94 35 61], shape=(16,), dtype=int32)
1
tf.Tensor([29 52 79 84  4  2 87 44 56 42 16 33 86  6 14 80], shape=(16,), dtype=int32)
2
tf.Tensor([38 70 11 25 20  5 68 81 58  9 53 24 34 63 91 90], shape=(16,), dtype=int32)
3
tf.Tensor([13 30 92 18 67  0  8 74 65 40 47 15 43 85 28 22], shape=(16,), dtype=int32)
4
tf.Tensor([88  7 96 83 31 36 41 23 37  3 51 64 19 54 62 73], shape=(16,), dtype=int32)
5
tf.Tensor([57 93 21 59 98 69 95 32 75 60 66 26 10 82 45 50], shape=(16,), dtype=int32)
6
tf.Tensor([76 27 39 77 36 45 73 99 12 11  0 57 42 51 23  9], shape=(16,), dtype=int32)
7
tf.Tensor([98 39  4 17 91 50 15 24 65 86 52 81 35 18  3 26], shape=(16,), dtype=int32)
8
tf.Tensor([33 70 89 88 67 31  6 53 55 14 97 64 84 43 54 37], shape=(16,), dtype=int32)
9
tf.Tensor([25 32  5 92 75 21 90 29 19 40 28 79 69 48 66 46], shape=(16,), dtype=int32)
10
tf.Tensor([58 74 93 82 20 96  8 95 47 30 76 85  7 94 22  2], shape=(16,), dtype=int32)
11
tf.Tensor([83 16 68 62  1 41 13 78 38 44 10 87 63 49 34 59], shape=(16,), dtype=int32)

解決策

すごくシンプルに下記を書けばOK。
これは冒頭にも書きました。シャッフルとバッチだけでいいんです。

ds = ds.shuffle(count)
ds = ds.batch(batch_size, drop_remainder=True)
ds = ds.prefetch(buffer_size=AUTOTUNE)

なぜrepeat書いてあるのか?問題

では公式tutorialではそもそもなぜrepeatをしているのでしょう。
これは推測ですが、1epochって別に厳密な定義がなくて、1epochで何ステップ回すか、というのは自由に考えれば良いという意味だと思います。例えばaugmentationをする場合、1epochで全画像を1周することにどれだけ意味があるか?という問題があります。画像がloadされるたびに、RandomにAugmentationされるとしたら、データセット1周ってなんでしょう。

そういう観点から言っても、1epochで何step学習するかまで好きに設計して学習をしたほうが良いという、tensorflowのtutorial開発者の少し深い考え方かなぁ、と。下記のkerasのfit関数も、steps_per_epochは設定しなければcount/batch_sizeで計算されますが、設定してあげれば自由に設定が可能です。

model.fit(ds, epochs=1, steps_per_epoch=100)

ちなみにrepeatを使わないdsを引き渡して、count/batch_size以上のsteps_per_epochを設定すると下記のようなエラーがでます。

29/36 [=======================>......] - ETA: 11s - loss: 4.4318 - accuracy: 0.2248WARNING:tensorflow:Your input ran out of data; interrupting training. Make sure that your dataset or generator can generate at least `steps_per_epoch * epochs` batches (in this case, 36 batches). You may need to use the repeat() function when building your dataset.
36/36 [==============================] - 53s 1s/step - loss: 4.0827 - accuracy: 0.2321
<tensorflow.python.keras.callbacks.History at 0x10c9b3750>

augmentationを使わないような例では、むしろわかりづらいので、今回書いたshuffleとbatchだけを使ったほうがシンプルで良いかなぁ、と考えます。他にもなにか情報や考え方をご存じの方は是非教えて下さい。datasetを使って良いニューラルネットワークライフを!

10
8
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
10
8