1
3

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

超解像用の自作ImageDataGeneratorを作る

Posted at

TensorFlow(keras)のImageDataGeneratorは便利ですよね。

ただ超解像タスクや、セグメンテーションタスクのように入力と出力が共に画像で、同じ変形をしたい時などはそのままだと使えません。

なのでクラス継承を使って超解像用のImageDataGeneratorを作ります。

環境

Google Colabを使います。
最近Pro+に課金しました。

自作ジェネレータ

結論を言うとImageDataGeneratorとその中のflowやflow_from_directoryを継承すればokです。

今回はndarray形式のデータを水増しするということを考えてflowを継承します。
水増し手法は何もしない、左右反転、90,180,270度回転の組み合わせで合計8種類をランダムに行うものとします。

以下そのコードです。

import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator

class MyGenerator(ImageDataGenerator):
    def __init__(self, 
                 basic_aug=False, *args, **kwargs):
        super().__init__(*args, **kwargs) # 継承
        self.basic_aug = basic_aug
       
    def img_basic_aug(self, LR, HR):
        mode = np.random.randint(0,8) # 0~7までの整数をランダムに選択
        if mode == 0: # 左右反転
            low_flip_img = tf.image.flip_left_right(LR).numpy()
            high_flip_img = tf.image.flip_left_right(HR).numpy()
            return low_flip_img, high_flip_img
        elif mode == 1: # 180度回転
            low_rotate_180_img = tf.image.rot90(LR, k=2).numpy()
            high_rotate_180_img = tf.image.rot90(HR, k=2).numpy()
            return low_rotate_180_img, high_rotate_180_img
        elif mode == 2: 左右反転&180度回転
            low_rotate_180_flip_img = tf.image.rot90(tf.image.flip_left_right(LR), k=2).numpy()
            high_rotate_180_flip_img = tf.image.rot90(tf.image.flip_left_right(HR), k=2).numpy()
            return low_rotate_180_flip_img, high_rotate_180_flip_img
        elif mode == 3: # 90度回転
            low_rotate_90_img = tf.image.rot90(LR, k=1).numpy()
            high_rotate_90_img = tf.image.rot90(HR, k=1).numpy()
            return low_rotate_90_img, high_rotate_90_img
        elif mode == 4: # 270度回転
            low_rotate_270_img = tf.image.rot90(LR, k=3).numpy()
            high_rotate_270_img = tf.image.rot90(HR, k=3).numpy()
            return low_rotate_270_img, high_rotate_270_img
        elif mode == 5: # 左右反転&90度回転
            low_rotate_90_flip_img = tf.image.rot90(tf.image.flip_left_right(LR), k=1).numpy()
            high_rotate_90_flip_img = tf.image.rot90(tf.image.flip_left_right(HR), k=1).numpy()
            return low_rotate_90_flip_img, high_rotate_90_flip_img
        elif mode == 6: # 左右反転&270度回転
            low_rotate_270_flip_img = tf.image.rot90(tf.image.flip_left_right(LR), k=3).numpy()
            high_rotate_270_flip_img = tf.image.rot90(tf.image.flip_left_right(HR), k=3).numpy()
            return low_rotate_270_flip_img, high_rotate_270_flip_img
        elif mode == 7: # 何もしない
            return LR, HR

    
    def flow(self, *args, **kwargs):
        batches = super().flow(*args, **kwargs) # flow関数を継承

        while True: # 無限ループ
            batch_LR, batch_HR = next(batches) # ミニバッチ1つを取り出す
            if self.basic_aug == True:
                for i in range(batch_LR.shape[0]):
                    # ミニバッチの中の画像それぞれに対して水増しを適用
                    batch_LR[i], batch_HR[i] = self.img_basic_aug(batch_LR[i], batch_HR[i])
            yield(batch_LR, batch_HR)

これを基本として、cutmixなどの流行りの手法も関数化して__self__の中とflowの中に書けばオリジナルのGeneratorが作成できます。

まとめ

既存のフレームワークになくてもクラス継承を使うことでやりたいことが行える。
自分で実装してみることが大事だと思う。

1
3
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
3

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?