Posted at

ImageDataGeneratorを拡張しcutoutを実装する


はじめに

以前にImageDataGeneratorを使ってData Augmentation(水増し)を行ったのですが、ImageDataGeneratorが持っていない水増し方法も使いたいと思っていました。今回それを実現してみました。


環境


  • Google Colaboratory

  • TensorFlow 2.0 Alpha


コード

こちらです。


コード解説

import numpy as np

from tensorflow.keras.preprocessing.image import ImageDataGenerator

class CustomImageDataGenerator(ImageDataGenerator):
def __init__(self, cutout_mask_size = 0, *args, **kwargs):
super().__init__(*args, **kwargs)
self.cutout_mask_size = cutout_mask_size

def cutout(self, x, y):
return np.array(list(map(self._cutout, x))), y

def _cutout(self, image_origin):
# 最後に使うfill()は元の画像を書き換えるので、コピーしておく
image = np.copy(image_origin)
mask_value = image.mean()

h, w, _ = image.shape
# マスクをかける場所のtop, leftをランダムに決める
# はみ出すことを許すので、0以上ではなく負の値もとる(最大mask_size // 2はみ出す)
top = np.random.randint(0 - self.cutout_mask_size // 2, h - self.cutout_mask_size)
left = np.random.randint(0 - self.cutout_mask_size // 2, w - self.cutout_mask_size)
bottom = top + self.cutout_mask_size
right = left + self.cutout_mask_size

# はみ出した場合の処理
if top < 0:
top = 0
if left < 0:
left = 0

# マスク部分の画素値を平均値で埋める
image[top:bottom, left:right, :].fill(mask_value)
return image

def flow(self, *args, **kwargs):
batches = super().flow(*args, **kwargs)

# 拡張処理
while True:
batch_x, batch_y = next(batches)

if self.cutout_mask_size > 0:
result = self.cutout(batch_x, batch_y)
batch_x, batch_y = result

yield (batch_x, batch_y)

datagen = CustomImageDataGenerator(rotation_range=10, horizontal_flip=True, zoom_range=0.1, cutout_mask_size=16)

ImageDataGeneratorを継承したClassを作成し、flowメソッドをオーバーライドし、そこからcutoutを呼び出しています。


出力結果



ちゃんとcutoutが入っているようです。


参考にしたページ