TensorFlow 2には、tf.data.Dataset
が用意されていて、データに関する操作や処理を簡単に行うことができます。
Datasetの使い方については公式チュートリアルを含め、多くの情報が存在しているので説明を割愛しますが、Datasetを使っていて「ハマるポイント」について紹介したいと思います。
TL;DR
- shuffleはデータを取り出す度に再シャッフルされる
- シャッフル順序を固定したければ
reshuffle_each_iteration=False
を指定 - shuffle後にcacheを使うとシャッフルが機能しなくなる(cache→shuffleの順番で使う)
- batch→shuffleだと、バッチ単位でシャッフル(バッチ内部の並び順はそのまま)
- shuffle→batchだと、バッチ内部のデータもシャッフルされる
- take/skipでtrain/test用のデータ分割を簡単に行える
- shuffle後にtake/skipを使うとエポック単位にデータが混ざるので注意
- shuffle後にtake/skipする場合は
reshuffle_each_iteration=False
を指定 - takeしたデータ内でシャッフルしたければ、再度shuffleすれば良い
Datasetに関する操作
今回は10個の数値(0~9)をDataset化して、shuffle
,cache
,take
,skip
,batch
関数を使ったときの挙動をまとめています。また、実際には、これらの関数を組み合わせて使う場合も多いため、組み合わせたときの挙動についても試しています。
操作はJupyter Labを使って試しているため、実際のソースと実行結果を、GitHubにNotebook形式でそのままアップしています。そのため、詳細については、以下のNotebookファイルをご参照ください。
https://github.com/KurozumiGH/tf2-note/blob/main/Notes/tf-data-Dataset.ipynb
よくある間違いパターンの抜粋(詳細は上記Notebook参照)
準備したDatasetをシャッフルして、それをtrain/test用に分けて使いたい場合、次のように書いてしまうとNGです。train/test用のデータは完全に分割して混ぜないのが鉄則ですが、次のような書き方では、train/testデータが混ざってしまいます。
# この書き方だと各エポック単位にtrain/test用のデータが混ざってしまう
ds = make_ds().shuffle(10)
ds_x = ds.take(7) # 最初の7つをtrain用
ds_y = ds.skip(7) # 残りの3つをtest用
for i in range(LOOPS):
print_ds(ds_x, ds_y)
上記コードの実行結果は以下のようになります。(1行が、1エポックで使われるデータセットに相当します)
見ての通り、train/test用のデータが完全に混ざってしまっています。
[0, 4, 6, 7, 3, 2, 8][4, 7, 9]
[1, 2, 8, 7, 3, 6, 5][8, 7, 9]
[0, 7, 2, 1, 4, 6, 9][8, 1, 7]
[5, 4, 9, 7, 2, 6, 0][9, 3, 7]
[5, 0, 3, 1, 8, 7, 9][0, 7, 4]
[6, 3, 0, 2, 1, 5, 4][4, 9, 5]
[6, 5, 7, 9, 2, 8, 0][9, 4, 7]
[0, 5, 8, 7, 1, 4, 9][1, 6, 0]
[9, 1, 4, 5, 3, 6, 7][1, 2, 5]
[4, 6, 0, 7, 1, 3, 8][4, 8, 3]
[4, 0, 8, 9, 5, 3, 1][4, 8, 9]
[9, 8, 4, 3, 5, 1, 6][2, 4, 5]