はじめに
前回、Fashion-MNISTをCNNで分類し、Test Accuracy: 91.3でした。そこで次はData Augmentationを行うことで、どれだけ精度が上がるのか確認してみようと思ったのですが、tf.data.DatasetをData Augmentationする方法がよくわからなかったので、まずはそれを調べてみることにしました。
環境
- Google Colaboratory
- TensorFlow 2.0 Alpha
コード
こちらです。
コード解説
データセット取得
import tensorflow_datasets as tfds
batch_size = 64
dataset = tfds.load('cifar10')
dataset = dataset['train']
dataset = dataset.batch(batch_size)
今回はCIFAR-10を題材にしてみました。本当は'cats_vs_dogs'
にしようとしたのですが、バッチ単位に縦横のサイズを合わせないと、Data Augmentationでエラーになってしまったので諦めました。
Data Augmentationの設定
世の中にはたくさんデータはあるもののラベル付きのデータというのはそう多くありません。新しくラベルを付けようとするとコストがかかるのですが、今ラベルの付いているデータを少し加工してやれば簡単にラベル付きのデータを増やすことができます。そうして増やしたより多くのデータを学習することにより汎化性能が高まり精度が上がることが期待できます。データの水増しなどとも言います。これをKerasのImageDataGeneratorを使って実現します。
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator
datagen = ImageDataGenerator(rotation_range = 30, horizontal_flip = True, zoom_range = 0.2)
今回は回転、左右反転、ズームを行うことにしました。本当はcutoutも実現したかったのですが、公式ドキュメントを読む限り実現できないような気がします。
Data Augmentation、結果表示
import math
import numpy as np
import matplotlib.pyplot as plt
column_size = math.floor(math.sqrt(batch_size))
row_size = math.ceil(batch_size / column_size)
fig, ax = plt.subplots(row_size, column_size, figsize = (row_size * 2, column_size * 2), subplot_kw = {'xticks': (), 'yticks': ()})
for axis in ax:
for a in axis:
a.set_axis_off()
for data_list in dataset:
image_list = data_list['image']
label_list = data_list['label']
row, column = 0, 0
for x, y in datagen.flow(image_list, label_list, batch_size = batch_size):
print(x.shape)
for _x in x:
_x = tf.cast(_x, tf.uint8)
ax[row, column].imshow(_x)
if column == column_size - 1:
column = 0
row += 1
else:
column += 1
break
break
data_list
はbatch_size
のイメージとラベルのセットを持っています。image_list
やlabel_list
はndarrayで、image_list
のshapeは(64, 32, 32, 3)です。これらをdatagen.flow
の引数に与えることで、入力のイメージをランダムに回転、左右反転、ズームした画像が生成されx
に代入されます。image_list
とx
は1対1の関係になっています。あとは結果を表示しているだけです。matplotlibで枠線や目盛りを消す方法を調べるのに地味に時間がかかりましたが・・・。datagen.flow
は無限ループしますので1つ目のbreak
は必須です。
出力結果
何となく回転やズームはわかるかと思います。左右反転はわからないですね・・・。
最後に
tf.data.Datasetを入力にImageDataGeneratorを使っているコードがあまり世の中にないような気がしたので作成してみました。どなたかのお役に立てると幸いです。次はこれを使いどれだけ精度が上がるか確認してみます。