LoginSignup
19
5

More than 3 years have passed since last update.

TensorFlowのDataset.shuffleのbuffer_size

Posted at

この記事は

Tensorflowのtf.data.Dataset.shuffleの引数であるbuffer_sizeについて、公式ドキュメントを読んだだけではいまいち理解できなかったので、実際に動かして確認した結果のメモです。
※確認はver 2.0.0で行っています。

ドキュメントに書いてあること

buffer_sizeで指定された要素数のバッファを用意して、そこからランダムにサンプリングする。サンプリングされた値は次の要素で置き換えられる。
完全にシャッフルしたい場合はデータ数よりも大きい値を指定すべし。

確認用コード

0~999の数字を順に生成して、shuffleでどのように並べ替えられるか確かめる。

import tensorflow as tf
import matplotlib.pyplot as plt

fig, ax = plt.subplots(4,2, figsize=(8,12))

for i, size in enumerate([1, 100, 500, 10000]):
    seq = [v for v in range(1000)]
    val = [v.numpy() for v in tf.data.Dataset.range(1000).shuffle(size)]
    ax[i][0].scatter(seq, val, s=3, label="buffer_size: "+str(size))
    ax[i][0].plot(seq, color="black")
    ax[i][0].legend(loc="upper left")
    ax[i][1].hist(val[500:])
    ax[i][1].set_xlim(0, 1000)
plt.show()

結果

download.png

左側の図
- 横軸:shuffle後のインデックス
- 縦軸:shuffle前のインデックス

右側の図:shuffle後の後半データ(500~999番目)の元のインデックスのヒストグラム

buffer_size=500で説明すると、shuffle後に最初に値を取得すると、元のインデックスで0~499番目のデータがサンプルされる。次のデータは0~500番目から選ばれる。ただし、すでに選ばれたものは除外される。
これを繰り返すと、n番目のサンプルデータは元のインデックスで0~(500+n)番目のデータから選ばれる。ただし、前半はすでにサンプルされてしまっている可能性が高いため、前半インデックスによって偏りが生じる。

ゆえに、均等にシャッフルするためには、buffer_sizeをデータ数より大きくしなければならない。

結論

データが多すぎてメモリに乗らないなどの特別な理由がなければbuffer_sizeはデータ数と同じ値にしとけばよさそうです。

19
5
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
19
5