Posted at

tf.data.Datasetを入力にImageDataGeneratorを使ってData Augmentation(水増し)を行う


はじめに

前回、Fashion-MNISTをCNNで分類し、Test Accuracy: 91.3でした。そこで次はData Augmentationを行うことで、どれだけ精度が上がるのか確認してみようと思ったのですが、tf.data.DatasetをData Augmentationする方法がよくわからなかったので、まずはそれを調べてみることにしました。


環境


  • Google Colaboratory

  • TensorFlow 2.0 Alpha


コード

こちらです。


コード解説


データセット取得

import tensorflow_datasets as tfds

batch_size = 64

dataset = tfds.load('cifar10')
dataset = dataset['train']
dataset = dataset.batch(batch_size)

今回はCIFAR-10を題材にしてみました。本当は'cats_vs_dogs'にしようとしたのですが、バッチ単位に縦横のサイズを合わせないと、Data Augmentationでエラーになってしまったので諦めました。


Data Augmentationの設定

世の中にはたくさんデータはあるもののラベル付きのデータというのはそう多くありません。新しくラベルを付けようとするとコストがかかるのですが、今ラベルの付いているデータを少し加工してやれば簡単にラベル付きのデータを増やすことができます。そうして増やしたより多くのデータを学習することにより汎化性能が高まり精度が上がることが期待できます。データの水増しなどとも言います。これをKerasのImageDataGeneratorを使って実現します。

import tensorflow as tf

from tensorflow.keras.preprocessing.image import ImageDataGenerator

datagen = ImageDataGenerator(rotation_range = 30, horizontal_flip = True, zoom_range = 0.2)

今回は回転、左右反転、ズームを行うことにしました。本当はcutoutも実現したかったのですが、公式ドキュメントを読む限り実現できないような気がします。


Data Augmentation、結果表示

import math

import numpy as np
import matplotlib.pyplot as plt

column_size = math.floor(math.sqrt(batch_size))
row_size = math.ceil(batch_size / column_size)
fig, ax = plt.subplots(row_size, column_size, figsize = (row_size * 2, column_size * 2), subplot_kw = {'xticks': (), 'yticks': ()})

for axis in ax:
for a in axis:
a.set_axis_off()

for data_list in dataset:
image_list = data_list['image']
label_list = data_list['label']

row, column = 0, 0
for x, y in datagen.flow(image_list, label_list, batch_size = batch_size):
print(x.shape)
for _x in x:
_x = tf.cast(_x, tf.uint8)
ax[row, column].imshow(_x)
if column == column_size - 1:
column = 0
row += 1
else:
column += 1
break
break

data_listbatch_sizeのイメージとラベルのセットを持っています。image_listlabel_listはndarrayで、image_listのshapeは(64, 32, 32, 3)です。これらをdatagen.flowの引数に与えることで、入力のイメージをランダムに回転、左右反転、ズームした画像が生成されxに代入されます。image_listxは1対1の関係になっています。あとは結果を表示しているだけです。matplotlibで枠線や目盛りを消す方法を調べるのに地味に時間がかかりましたが・・・。datagen.flowは無限ループしますので1つ目のbreakは必須です。


出力結果



何となく回転やズームはわかるかと思います。左右反転はわからないですね・・・。


最後に

tf.data.Datasetを入力にImageDataGeneratorを使っているコードがあまり世の中にないような気がしたので作成してみました。どなたかのお役に立てると幸いです。次はこれを使いどれだけ精度が上がるか確認してみます。