この記事は
Tensorflowのtf.data.Dataset.shuffleの引数であるbuffer_sizeについて、公式ドキュメントを読んだだけではいまいち理解できなかったので、実際に動かして確認した結果のメモです。
※確認はver 2.0.0で行っています。
ドキュメントに書いてあること
buffer_sizeで指定された要素数のバッファを用意して、そこからランダムにサンプリングする。サンプリングされた値は次の要素で置き換えられる。
完全にシャッフルしたい場合はデータ数よりも大きい値を指定すべし。
確認用コード
0~999の数字を順に生成して、shuffleでどのように並べ替えられるか確かめる。
import tensorflow as tf
import matplotlib.pyplot as plt
fig, ax = plt.subplots(4,2, figsize=(8,12))
for i, size in enumerate([1, 100, 500, 10000]):
seq = [v for v in range(1000)]
val = [v.numpy() for v in tf.data.Dataset.range(1000).shuffle(size)]
ax[i][0].scatter(seq, val, s=3, label="buffer_size: "+str(size))
ax[i][0].plot(seq, color="black")
ax[i][0].legend(loc="upper left")
ax[i][1].hist(val[500:])
ax[i][1].set_xlim(0, 1000)
plt.show()
結果
左側の図
- 横軸:shuffle後のインデックス
- 縦軸:shuffle前のインデックス
右側の図:shuffle後の後半データ(500~999番目)の元のインデックスのヒストグラム
buffer_size=500で説明すると、shuffle後に最初に値を取得すると、元のインデックスで0~499番目のデータがサンプルされる。次のデータは0~500番目から選ばれる。ただし、すでに選ばれたものは除外される。
これを繰り返すと、n番目のサンプルデータは元のインデックスで0~(500+n)番目のデータから選ばれる。ただし、前半はすでにサンプルされてしまっている可能性が高いため、前半インデックスによって偏りが生じる。
ゆえに、均等にシャッフルするためには、buffer_sizeをデータ数より大きくしなければならない。
結論
データが多すぎてメモリに乗らないなどの特別な理由がなければbuffer_sizeはデータ数と同じ値にしとけばよさそうです。