KerasのImageDataGeneratorはData Augmentation(水増し)にとても便利ですが、最近水増しの方法がいろいろ研究されてきて物足りないなと感じることがあります。KerasオリジナルのImageDataGeneratorを継承・拡張して、独自のImageDataGeneratorを作る方法を考えてみました。
クラス継承なのでオリジナルのImageDataGeneratorと全く同じ感覚で使えます。
参考
以下のサイトを参考にしました。
-
Kerasでデータ拡張(Data Augmentation)後の画像を表示する
flow_from_dirctoryで流した画像をpyplotで表示させる方法について解説されています。 -
新たなdata augmentation手法mixupを試してみた
Mix-upの原著論文やその実装について解説されています。 -
Extending Keras' ImageDataGenerator to Support Random Cropping
Random CroppingをImageDataGeneratorからつなげる方法が解説されています。これを改良しました。
Kaggleのcat-datasetをサンプル画像としました。サイズの同一の画像4枚を「image/cats」というフォルダに置いておきます。フォルダ名は何でもいいんですが、サブディレクトリに入れておかないとKerasがうまく認識してくれません。
コード
こんな感じでImageDataGeneratorを継承させます。引数がやたら多いですが、オーバーライドするときにVisualStudioの入力補完が働くので特に大変だったという印象はないです。
import numpy as np
from keras.preprocessing.image import ImageDataGenerator
class MyImageDataGenerator(ImageDataGenerator):
def __init__(self, featurewise_center = False, samplewise_center = False,
featurewise_std_normalization = False, samplewise_std_normalization = False,
zca_whitening = False, zca_epsilon = 1e-06, rotation_range = 0.0, width_shift_range = 0.0,
height_shift_range = 0.0, brightness_range = None, shear_range = 0.0, zoom_range = 0.0,
channel_shift_range = 0.0, fill_mode = 'nearest', cval = 0.0, horizontal_flip = False,
vertical_flip = False, rescale = None, preprocessing_function = None, data_format = None, validation_split = 0.0,
random_crop = None, mix_up_alpha = 0.0):
# 親クラスのコンストラクタ
super().__init__(featurewise_center, samplewise_center, featurewise_std_normalization, samplewise_std_normalization, zca_whitening, zca_epsilon, rotation_range, width_shift_range, height_shift_range, brightness_range, shear_range, zoom_range, channel_shift_range, fill_mode, cval, horizontal_flip, vertical_flip, rescale, preprocessing_function, data_format, validation_split)
# 拡張処理のパラメーター
# Mix-up
assert mix_up_alpha >= 0.0
self.mix_up_alpha = mix_up_alpha
# Random Crop
assert random_crop == None or len(random_crop) == 2
self.random_crop_size = random_crop
# ランダムクロップ
# 参考 https://jkjung-avt.github.io/keras-image-cropping/
def random_crop(self, original_img):
# Note: image_data_format is 'channel_last'
assert original_img.shape[2] == 3
if original_img.shape[0] < self.random_crop_size[0] or original_img.shape[1] < self.random_crop_size[1]:
raise ValueError(f"Invalid random_crop_size : original = {original_img.shape}, crop_size = {self.random_crop_size}")
height, width = original_img.shape[0], original_img.shape[1]
dy, dx = self.random_crop_size
x = np.random.randint(0, width - dx + 1)
y = np.random.randint(0, height - dy + 1)
return original_img[y:(y+dy), x:(x+dx), :]
# Mix-up
# 参考 https://qiita.com/yu4u/items/70aa007346ec73b7ff05
def mix_up(self, X1, y1, X2, y2):
assert X1.shape[0] == y1.shape[0] == X2.shape[0] == y2.shape[0]
batch_size = X1.shape[0]
l = np.random.beta(self.mix_up_alpha, self.mix_up_alpha, batch_size)
X_l = l.reshape(batch_size, 1, 1, 1)
y_l = l.reshape(batch_size, 1)
X = X1 * X_l + X2 * (1-X_l)
y = y1 * y_l + y2 * (1-y_l)
return X, y
def flow_from_directory(self, directory, target_size = (256,256), color_mode = 'rgb',
classes = None, class_mode = 'categorical', batch_size = 32, shuffle = True,
seed = None, save_to_dir = None, save_prefix = '', save_format = 'png',
follow_links = False, subset = None, interpolation = 'nearest'):
# 親クラスのflow_from_directory
batches = super().flow_from_directory(directory, target_size, color_mode, classes, class_mode, batch_size, shuffle, seed, save_to_dir, save_prefix, save_format, follow_links, subset, interpolation)
# 拡張処理
while True:
if self.mix_up_alpha > 0:
while True:
batch_x_2, batch_y_2 = next(batches)
m1, m2 = batch_x.shape[0], batch_x_2.shape[0]
if m1 < m2:
batch_x_2 = batch_x_2[:m1]
batch_y_2 = batch_y_2[:m1]
break
elif m1 == m2:
break
batch_x, batch_y = self.mix_up(batch_x, batch_y, batch_x_2, batch_y_2)
# Random crop
if self.random_crop_size != None:
x = np.zeros((batch_x.shape[0], self.random_crop_size[0], self.random_crop_size[1], 3))
for i in range(batch_x.shape[0]):
x[i] = self.random_crop(batch_x[i])
batch_x = x
# 返り値
yield (batch_x, batch_y)
※ミニバッチで端数出たときに止まるバグあったので直しました
Mix-upだろうがRandom croppingだろうが、親クラスのflow_from_directoryから取得したバッチをどんどん加工していけばいいだけです。Random erasingなんかも簡単に実装できます。Numpy配列をいじるだけなんで。
テスト
オリジナルのImageDataGeneratorを可視化
まずはKerasオリジナルのImageDataGeneratorをflow_from_directoryして可視化してみましょう。
import matplotlib.pyplot as plt
# 参考:https://qiita.com/takurooo/items/c06365dd43914c253240
def show_imgs(imgs, row, col):
if len(imgs) != (row * col):
raise ValueError("Invalid imgs len:{} col:{} row:{}".format(len(imgs), row, col))
fig = plt.figure(figsize=(8,8))
fig.subplots_adjust(hspace=0.05, wspace=0.05)
for i, img in enumerate(imgs):
plot_num = i+1
ax = fig.add_subplot(row, col, plot_num, xticks=[], yticks=[])
ax.imshow(img)
plt.show()
datagen = ImageDataGenerator(
rescale=1/255.0)
max_img_num = 12
imgs = []
for d in datagen.flow_from_directory("images", batch_size=1, target_size=(375, 500), classes=None):
# target_size = (height, width)なのに注意
imgs.append(np.squeeze(d[0], axis=0))
if (len(imgs) % max_img_num) == 0:
break
show_imgs(imgs, row=4, col=3)
スケールを255で割っただけの画像をそのまま表示するスクリプトです。
flow_from_directroyはデフォルトでshuffle=Trueされているので、4枚の画像を1ループの間にシャッフルされながら表示されているのが確認できます。
独自のImageDataGeneratorのテスト
上記のImageDataGeneratorを継承したものに置き換えるだけです。とても単純です。
datagen = MyImageDataGenerator(
rescale=1/255.0,
mix_up_alpha=2,
random_crop=(375, 375))
結果はこちら。Random croppingとMix-upが同時に機能しているのがわかります。
もちろん継承しているので、オリジナルのImageDataGeneratorの処理を使うこともできます。ここではrescaleはオリジナル側でやらせています。水平方向の回転(horizontal_flip)なども相変わらず使えます。
ここでは効果がわかりやすいようにMix-upのα=2としましたが、通常はα=0.2などもっと小さい数のほうがいいと思います。Mix-upの乱数はベータ分布$Be(\alpha, \alpha)$で生成されますので、ベータ分布の値を変えるとどんなグラフになるか気になる方はこちらで試してみてください。
まとめ
KerasのImageDataGeneratorを継承させると容易に拡張でき、最新のData Augmentationもラクラク試せることが確認できました。
MNISTやCIFARなど変数に格納されているデータを使いたい場合は、flow_from_directoryではなくflowのほうをオーバーライドすればOK(なはず)です。