TensorFlow(keras)のImageDataGeneratorは便利ですよね。
ただ超解像タスクや、セグメンテーションタスクのように入力と出力が共に画像で、同じ変形をしたい時などはそのままだと使えません。
なのでクラス継承を使って超解像用のImageDataGeneratorを作ります。
環境
Google Colabを使います。
最近Pro+に課金しました。
自作ジェネレータ
結論を言うとImageDataGeneratorとその中のflowやflow_from_directoryを継承すればokです。
今回はndarray形式のデータを水増しするということを考えてflowを継承します。
水増し手法は何もしない、左右反転、90,180,270度回転の組み合わせで合計8種類をランダムに行うものとします。
以下そのコードです。
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator
class MyGenerator(ImageDataGenerator):
def __init__(self,
basic_aug=False, *args, **kwargs):
super().__init__(*args, **kwargs) # 継承
self.basic_aug = basic_aug
def img_basic_aug(self, LR, HR):
mode = np.random.randint(0,8) # 0~7までの整数をランダムに選択
if mode == 0: # 左右反転
low_flip_img = tf.image.flip_left_right(LR).numpy()
high_flip_img = tf.image.flip_left_right(HR).numpy()
return low_flip_img, high_flip_img
elif mode == 1: # 180度回転
low_rotate_180_img = tf.image.rot90(LR, k=2).numpy()
high_rotate_180_img = tf.image.rot90(HR, k=2).numpy()
return low_rotate_180_img, high_rotate_180_img
elif mode == 2: 左右反転&180度回転
low_rotate_180_flip_img = tf.image.rot90(tf.image.flip_left_right(LR), k=2).numpy()
high_rotate_180_flip_img = tf.image.rot90(tf.image.flip_left_right(HR), k=2).numpy()
return low_rotate_180_flip_img, high_rotate_180_flip_img
elif mode == 3: # 90度回転
low_rotate_90_img = tf.image.rot90(LR, k=1).numpy()
high_rotate_90_img = tf.image.rot90(HR, k=1).numpy()
return low_rotate_90_img, high_rotate_90_img
elif mode == 4: # 270度回転
low_rotate_270_img = tf.image.rot90(LR, k=3).numpy()
high_rotate_270_img = tf.image.rot90(HR, k=3).numpy()
return low_rotate_270_img, high_rotate_270_img
elif mode == 5: # 左右反転&90度回転
low_rotate_90_flip_img = tf.image.rot90(tf.image.flip_left_right(LR), k=1).numpy()
high_rotate_90_flip_img = tf.image.rot90(tf.image.flip_left_right(HR), k=1).numpy()
return low_rotate_90_flip_img, high_rotate_90_flip_img
elif mode == 6: # 左右反転&270度回転
low_rotate_270_flip_img = tf.image.rot90(tf.image.flip_left_right(LR), k=3).numpy()
high_rotate_270_flip_img = tf.image.rot90(tf.image.flip_left_right(HR), k=3).numpy()
return low_rotate_270_flip_img, high_rotate_270_flip_img
elif mode == 7: # 何もしない
return LR, HR
def flow(self, *args, **kwargs):
batches = super().flow(*args, **kwargs) # flow関数を継承
while True: # 無限ループ
batch_LR, batch_HR = next(batches) # ミニバッチ1つを取り出す
if self.basic_aug == True:
for i in range(batch_LR.shape[0]):
# ミニバッチの中の画像それぞれに対して水増しを適用
batch_LR[i], batch_HR[i] = self.img_basic_aug(batch_LR[i], batch_HR[i])
yield(batch_LR, batch_HR)
これを基本として、cutmixなどの流行りの手法も関数化して__self__
の中とflow
の中に書けばオリジナルのGeneratorが作成できます。
まとめ
既存のフレームワークになくてもクラス継承を使うことでやりたいことが行える。
自分で実装してみることが大事だと思う。