tf.data.Dataset APIはストリーミング方式の入力パイプラインを記述する便利なAPIですが、このAPIで使えるsample_from_datasetsというメソッドが強力だったので紹介したいと思います。(このメソッドに関しては既に公式のチュートリアル記事やクラスキャットさんの記事がありますが、今回は画像データの多クラス分類についてオーグメンテーションのやり方も含めて説明しようと思います)
tf.dataの基本的な使い方
まずは検証データセットを用意しましょう。(以下、Xは画像のファイルパス、yはラベルを表す0以上の整数値としています)
まず、from_tensor_slicesメソッドを使って画像のデータセットを作ります。
path_ds = tf.data.Dataset.from_tensor_slices(X_test)
image_ds = path_ds.map(load_and_preprocess_image)
ここで登場するload_and_preprocess_imageは画像の読み込みと整形を行う、以下のような関数です。
def preprocess_image(image):
image = tf.image.decode_jpeg(image, channels=3)
image = tf.image.resize(image, [height, width])
return image
def load_and_preprocess_image(path):
image = tf.io.read_file(path)
return preprocess_image(image)
次にラベルについてもデータセットを作ります。
label_ds = tf.data.Dataset.from_tensor_slices(y_test)
二つのデータセットをzipすることでペアのデータセットを作ります。この時、画像とラベルが同じ順番で並んでいることに注意してください。
test_ds = tf.data.Dataset.zip((image_ds, label_ds))
test_ds = test_ds.batch(BATCH_SIZE)
test_ds = test_ds.prefetch(BUFFER_SIZE)
sample_from_datasetsの使い方
次に訓練データセットを用意します。今回はラベルごとの出現頻度を調整するためにsample_from_datasetsメソッドを使用します。
ラベル毎のパイプライン
まずはsample_from_datasetsメソッドに渡すためのデータセットの配列を作ります。ここで登場しているX_train_groupsとy_train_groupsは、それぞれ画像ファイルのパスとラベルを、ラベルごとにグループ化したものです。
train_ds = []
for paths, labels in zip(X_train_groups, y_train_groups):
path_ds = tf.data.Dataset.from_tensor_slices(paths)
image_ds = path_ds.map(load_and_preprocess_image)
image_ds = image_ds.map(augmentation_layer)
label_ds = tf.data.Dataset.from_tensor_slices(labels)
image_and_label_ds = tf.data.Dataset.zip((image_ds, label_ds))
image_and_label_ds = image_and_label_ds.repeat()
train_ds.append(image_and_label_ds)
個々のデータセットの作り方は検証データの時とほとんど同じですが、変更箇所が二箇所だけあります。一つ目はrepeat変換の存在です。sample_from_datasetsメソッドの挙動の関係で、この位置にrepeatが必要になります(詳しくは後述)。二つ目はaugmentation_layer関数をmapしている部分です。この関数について次項で説明します。
オーグメンテーション
tf.data.Datasetでは前処理レイヤーを適用することで、前処理されたデータのバッチを生成することができます。今回はそのうちデータオーグメンテーションに関するレイヤーを適用します。
def augmentation_layer(x):
x = tf.keras.layers.RandomRotation((-1/6, 1/6))(x)
x = tf.keras.layers.RandomTranslation((-0.1, 0.1), (-0.1, 0.1))(x)
x = tf.keras.layers.RandomZoom((-0.1, 0.1))(x)
x = tf.keras.layers.RandomFlip(mode="horizontal")(x)
return x
sample_from_datasets
sample_from_datasetsメソッドにデータセットと各データセットのウェイトを渡します。第一引数は先ほど作成したデータセットの配列で、引数weightsにはそれぞれのデータセットのピック確率のリストです。今回は等確率にしたいのでクラス数の逆数を並べたものを使います。
train_ds = tf.data.Dataset.sample_from_datasets(
train_ds, weights=[1/num_classes]*num_classes)
train_ds = train_ds.batch(BATCH_SIZE)
train_ds = train_ds.prefetch(BUFFER_SIZE)
注意点
sample_from_datasetsはデータセットのリストから、各データセットのウェイトに従ってサンプリングしています。ところが、個々のデータセットは一周するとそれ以上データを取り出せなくなります。オーバーサンプリングをするためには1epoch中に同じデータを何度も取り出す必要があり、それを実現するのがrepeat変換です。
データセットは通常であればデータを全て取り出し終えた時にエポックの終わりを伝達しますが、repeat変換は指定された回数だけそれを無視して再びデータセットの頭からデータの取り出しを始めるというものになります。sample_from_detasetsに渡した各データセットは長さが異なりますが、それぞれのデータセットに対してrepeat変換を適用することでサンプル数の少ないデータセットから何周でもデータを取り出せるようになります。
ただし、前述の通りrepeatを使うとデータセットはエポックの終わりを伝達しなくなるので、学習時にはsteps_per_epochを必ず指定してください。