Implementing the image transformations of Keras's ImageDataGenerator as Preprocessing Layers

Posted at 2022-01-29

Introduction

keras.preprocessing.image.ImageDataGenerator comes with all sorts of preprocessing built in. However, if you write TensorFlow in the modern style, using it slows things down (reference).
By "modern style" I mean (perhaps only in my own head) tf.data plus moving the preprocessing into layers. See the following for details.

ImageDataGenerator is convenient, but I want it as Preprocessing Layers... so I went ahead and implemented it.
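
To make the target concrete, here is a minimal sketch of what I mean by the tf.data + preprocessing-layer style. The dataset, layer choices, and model here are placeholder assumptions (train_images, train_labels are hypothetical), not code from this article, and the layer names assume TF 2.6+:

import tensorflow as tf
from tensorflow.keras import layers

# Augmentation lives inside the model as layers (active only when training=True).
augmentation = tf.keras.Sequential([
    layers.RandomFlip('horizontal'),
    layers.RandomRotation(0.1),
])

# tf.data handles shuffling, batching and prefetching of the raw tensors.
ds = (tf.data.Dataset.from_tensor_slices((train_images, train_labels))
      .shuffle(1000)
      .batch(32)
      .prefetch(tf.data.AUTOTUNE))

model = tf.keras.Sequential([
    augmentation,
    layers.Conv2D(16, 3, padding='same', activation='relu'),
    layers.GlobalAveragePooling2D(),
    layers.Dense(10),
])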

Implementation

The correspondence between the image transformations ImageDataGenerator can apply and the preprocessing layers is as follows.
(See here for what each transformation does.)

Transformation | Effect | Preprocessing Layer
Shift | Randomly shifts the image vertically/horizontally | RandomTranslation
Shear | Randomly shears the image within a specified angle range | None
Random brightness | Randomly changes the image brightness | None (tf.image has one)
Channel Shift | Adds a given value to each channel, then clips to the pre-shift min/max | None
Zoom | Randomly zooms the image | RandomZoom
Flip | Randomly flips the image vertically/horizontally | RandomFlip

Shear, Random brightness, and Channel Shift have no corresponding layer, so I implement them myself. For Shift, RandomTranslation would also work, but since it does not let you specify 0 (i.e. no shift), I implement it myself for extra flexibility.
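
For reference, the transformations that already have built-in layers can be used directly, roughly like this (my sketch, assuming TF 2.6+; it is not part of the implementation below):

import tensorflow as tf
from tensorflow.keras import layers

builtin_augmentation = tf.keras.Sequential([
    layers.RandomTranslation(height_factor=0.1, width_factor=0.1),  # Shift
    layers.RandomZoom(height_factor=0.1),                           # Zoom
    layers.RandomFlip('horizontal_and_vertical'),                   # Flip
])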

For the sample image, I use Watame.
result.png

Setup

From TF 2.6.0 onward, the experimental namespace is no longer needed.
The code below follows the style of the built-in preprocessing layer implementations.

import math
from typing import Union

import tensorflow as tf
from tensorflow.keras import Sequential, backend, layers

# Compare versions numerically; a plain string comparison misorders e.g. '2.10.0'.
TF_VERSION = tuple(int(v) for v in tf.__version__.split('.')[:2])

if TF_VERSION >= (2, 6):
    from keras.layers.preprocessing.image_preprocessing import transform
    from tensorflow.keras.layers import RandomFlip, RandomRotation, RandomZoom
else:
    # TF < 2.6: these layers still live under experimental.preprocessing
    from tensorflow.keras.layers.experimental.preprocessing import (
        RandomFlip, RandomRotation, RandomZoom)
    from tensorflow.python.keras.layers.preprocessing.image_preprocessing import \
        transform

SEED = 0
rng = tf.random.Generator.from_seed(SEED)

def smart_cond(pred, true_fn=None, false_fn=None, name=None):
    # Variables always need a graph-mode tf.cond; other predicates go through
    # TF's smart_cond, which takes a plain Python branch when `pred` is a static bool.
    if isinstance(pred, tf.Variable):
        return tf.cond(
            pred, true_fn=true_fn, false_fn=false_fn, name=name)
    return tf.__internal__.smart_cond.smart_cond(
        pred, true_fn=true_fn, false_fn=false_fn, name=name)

For plotting:

import cv2
import matplotlib.pyplot as plt

image = cv2.imread(...)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)[None,...]

def plot_layer(p_layer, nrows, ncols):
    plt.figure(figsize=(ncols * 3, nrows * 3))
    for i in range(nrows):
        for j in range(ncols):
            plt.subplot(nrows, ncols, i * ncols + j + 1)
            plt.tick_params(labelbottom=False, labelleft=False, labelright=False, labeltop=False)
            plt.imshow(p_layer((image / 255.), training=True)[0])  # does nothing unless training=True
    plt.show()
    plt.close()

Shift

class RandomShift(layers.Layer):
    def __init__(
        self,
        width_shift_range: float,
        height_shift_range: float,
        fill_mode: str = 'nearest',
        fill_value: Union[float, int] = 0.,
        seed: int = SEED,
        **kwargs):
        super(RandomShift, self).__init__(**kwargs)
        self.width_shift_range = width_shift_range
        self.height_shift_range = height_shift_range
        self.fill_mode = fill_mode
        self.fill_value = fill_value
        self.seed = seed
        if seed != SEED:
            self.rng = tf.random.Generator.from_seed(seed)
        else:
            self.rng = rng


    def call(self, inputs, training=None):
        if training is None:
            training = backend.learning_phase()
        original_shape = inputs.shape

        def random_shift():
            input_shape = tf.shape(inputs)
            h, w = input_shape[1], input_shape[2]
            h = tf.cast(h, tf.float32)
            w = tf.cast(w, tf.float32)
            tx = self.rng.uniform((), -self.height_shift_range, self.height_shift_range)*h
            ty = self.rng.uniform((), -self.width_shift_range, self.width_shift_range)*w

            shift_matrix = tf.convert_to_tensor(
                [[1, 0, tx],
                [0, 1, ty],
                [0, 0, 1]])
            transform_matrix = shift_matrix
            transform_matrix = tf.reshape(transform_matrix, [-1])[:8]
            transform_matrix = tf.tile(tf.expand_dims(transform_matrix, 0), [tf.shape(inputs)[0], 1])
            return transform(inputs, transform_matrix, fill_mode=self.fill_mode, fill_value=self.fill_value)

        outputs = smart_cond(
            training, true_fn=random_shift, false_fn=lambda: inputs)
        outputs.set_shape(original_shape)
        return outputs

    def get_config(self):
        config = {
            'width_shift_range': self.width_shift_range,
            'height_shift_range': self.height_shift_range,
            'fill_mode': self.fill_mode,
            'fill_value': self.fill_value,
            'seed': self.seed,
        }
        base_config = super(RandomShift, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))

plot_layer(RandomShift(width_shift_range=0.2, height_shift_range=0.2), 1, 4)

output.png

Shear

class RandomShear(layers.Layer):
    def __init__(self,
                 shear_range: float,
                 fill_mode: str = 'nearest',
                 fill_value: Union[float, int] = 0.,
                 seed: int = SEED,
                 **kwargs):
        super(RandomShear, self).__init__(**kwargs)
        self.shear_range = shear_range
        self.fill_mode = fill_mode
        self.fill_value = fill_value
        self.seed = seed
        if seed != SEED:
            self.rng = tf.random.Generator.from_seed(seed)
        else:
            self.rng = rng

    def call(self, inputs, training=None):
        if training is None:
            training = backend.learning_phase()
        original_shape = inputs.shape

        def random_shear():
            # shear_range is given in degrees (as in ImageDataGenerator) and converted to radians
            shear = self.shear_range / 180 * math.pi
            shear = self.rng.uniform((), -shear, shear)
            shear_matrix = tf.convert_to_tensor(
                [[1, -tf.sin(shear), 0],
                [0, tf.cos(shear), 0],
                [0, 0, 1]])
            transform_matrix = shear_matrix
            transform_matrix = tf.reshape(transform_matrix, [-1])[:8]
            transform_matrix = tf.tile(tf.expand_dims(transform_matrix, 0), [tf.shape(inputs)[0], 1])
            return transform(inputs, transform_matrix, fill_mode=self.fill_mode, fill_value=self.fill_value)

        outputs = smart_cond(training, true_fn=random_shear, false_fn=lambda: inputs)
        outputs.set_shape(original_shape)
        return outputs

    def get_config(self):
        config = {
            'shear_range': self.shear_range,
            'fill_mode': self.fill_mode,
            'fill_value': self.fill_value,
            'seed': self.seed,
        }
        base_config = super(RandomShear, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))

plot_layer(RandomShear(shear_range=0.3), 1, 4)

output.png

Random brightness

We use tf.image.adjust_brightness.
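
As a quick illustration (my example, not from the article): adjust_brightness simply adds a delta to every pixel value, so for float images in the [0, 1] range the delta is on that same scale.

import tensorflow as tf

x = tf.constant([[[0.2, 0.5, 0.8]]])      # a single pixel with three channels
y = tf.image.adjust_brightness(x, 0.1)    # -> [[[0.3, 0.6, 0.9]]]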

class RandomBrightness(layers.Layer):
    def __init__(self,
                 brightness: Union[float, tuple] = 0.2,
                 seed: int = SEED,
                 **kwargs):
        super(RandomBrightness, self).__init__(**kwargs)
        if isinstance(brightness, (list, tuple)):
            assert brightness[0] < brightness[1] and brightness[1] >= 0.0, 'brightness must be non-negative'
            self.brightness = brightness
        else:
            assert brightness >= 0., 'brightness must be non-negative'
            self.brightness = (-brightness, brightness)
        self.seed = seed
        if seed != SEED:
            self.rng = tf.random.Generator.from_seed(seed)
        else:
            self.rng = rng

    def call(self, inputs, training=None):
        if training is None:
            training = backend.learning_phase()
        original_shape = inputs.shape

        def random_brightness():
            delta = self.rng.uniform(
                (), self.brightness[0], self.brightness[1], dtype=inputs.dtype)
            return tf.image.adjust_brightness(inputs, delta)

        outputs = smart_cond(
            training, true_fn=random_brightness, false_fn=lambda: inputs)
        outputs.set_shape(original_shape)
        return outputs

    def get_config(self):
        config = {
            'brightness': self.brightness,
            'seed': self.seed,
        }
        base_config = super(RandomBrightness, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))

plot_layer(RandomBrightness(brightness=0.2), 1, 4)

output.png

Channel Shift

class RandomChannelShift(layers.Layer):
    def __init__(self,
                 shift_range: float,
                 seed: int = SEED,
                 **kwargs):
        super(RandomChannelShift, self).__init__(**kwargs)
        self.shift_range = shift_range
        self.seed = seed
        if seed != SEED:
            self.rng = tf.random.Generator.from_seed(seed)
        else:
            self.rng = rng

    def call(self, inputs, training=None):
        if training is None:
            training = backend.learning_phase()
        original_shape = inputs.shape

        def random_channel_shift():
            cmin = tf.reduce_min(
                inputs, axis=[1, 2], keepdims=True)
            cmax = tf.reduce_max(
                inputs, axis=[1, 2], keepdims=True)
            avalue = self.rng.uniform((), -self.shift_range, self.shift_range, dtype=inputs.dtype)
            clipped = tf.clip_by_value(inputs + avalue, cmin, cmax)
            return clipped

        outputs = smart_cond(
            training, true_fn=random_channel_shift, false_fn=lambda: inputs)
        outputs.set_shape(original_shape)
        return outputs

    def get_config(self):
        config = {
            'shift_range': self.shift_range,
            'seed': self.seed,
        }
        base_config = super(RandomChannelShift, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))

plot_layer(RandomChannelShift(shift_range=0.5), 1, 4)

output.png

Putting it all together

The argument names are kept the same as in ImageDataGenerator.

class ImageTransformation(layers.Layer):
    def __init__(self,
                 rotation_range: Union[float, tuple, list]=0,
                 width_shift_range: Union[float, tuple, list]=0.,
                 height_shift_range: Union[float, tuple, list]=0.,
                 brightness_range: Union[float, tuple]=None,
                 shear_range: float=0.,
                 zoom_range: Union[float, tuple, list]=0.,
                 channel_shift_range: Union[float, tuple, list]=0.,
                 fill_mode: str='nearest',
                 horizontal_flip: bool=False,
                 vertical_flip: bool=False,
                 fill_value: Union[float, int] = None,
                 seed: int = SEED,
                 **kwargs):
        self.rotation_range = rotation_range
        self.width_shift_range = width_shift_range
        self.height_shift_range = height_shift_range
        self.brightness_range = brightness_range
        self.shear_range: float = shear_range
        self.zoom_range = zoom_range
        self.channel_shift_range = channel_shift_range
        self.fill_mode = fill_mode or 'nearest'
        self.flip_mode = ''
        if horizontal_flip:
            self.flip_mode += 'horizontal'
        if vertical_flip:
            if self.flip_mode:
                self.flip_mode += '_and_'
            self.flip_mode += 'vertical'
        if fill_value is None and 'cval' in kwargs:
            self.fill_value = kwargs.pop('cval')
        else:
            self.fill_value = fill_value or 0.
        self.seed = seed
        super(ImageTransformation, self).__init__(**kwargs)


    def build(self, input_shape):
        # Use a name other than `layers` so the imported module isn't shadowed.
        aug_layers = [
            RandomRotation(
                self.rotation_range, fill_mode=self.fill_mode, fill_value=self.fill_value, seed=self.seed),
            RandomShift(
                self.width_shift_range, self.height_shift_range, self.fill_mode, fill_value=self.fill_value, seed=self.seed),
            RandomShear(
                self.shear_range, self.fill_mode, fill_value=self.fill_value, seed=self.seed),
            RandomZoom(
                self.zoom_range, fill_mode=self.fill_mode, fill_value=self.fill_value, seed=self.seed),
            ]
        if self.flip_mode != '':
            aug_layers.append(RandomFlip(self.flip_mode, seed=self.seed))
        aug_layers.append(RandomChannelShift(self.channel_shift_range, seed=self.seed))
        if self.brightness_range is not None:
            aug_layers.append(RandomBrightness(self.brightness_range, seed=self.seed))
        self.preprocessing = Sequential(layers=aug_layers)

    def call(self, inputs, training=None):
        if training is None:
            training = backend.learning_phase()
        return self.preprocessing(inputs, training=training)

    def get_config(self):
        config = {
            'rotation_range': self.rotation_range,
            'width_shift_range': self.width_shift_range,
            'height_shift_range': self.height_shift_range,
            'brightness_range': self.brightness_range,
            'shear_range': self.shear_range,
            'zoom_range': self.zoom_range,
            'channel_shift_range': self.channel_shift_range,
            'fill_mode': self.fill_mode,
            'horizontal_flip': self.flip_mode.startswith('horizontal'),
            'vertical_flip': self.flip_mode.endswith('vertical'),
            'fill_value': self.fill_value,
            'seed': self.seed,
        }
        base_config = super(ImageTransformation, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
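
Since each custom layer above returns exactly its __init__ arguments from get_config, it can be round-tripped through its config, which is what lets Keras rebuild the layer when saving and loading a model with custom objects. A quick sanity check (my snippet, not from the original):

# from_config(config) internally calls cls(**config), so the layer is rebuilt as-is.
shift = RandomShift(width_shift_range=0.2, height_shift_range=0.1)
restored = RandomShift.from_config(shift.get_config())
assert restored.width_shift_range == 0.2 and restored.height_shift_range == 0.1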

Test

Let's turn the preprocessing I used in my earlier article (学習最適化のための損失関数とOptimizer & MRI画像を使った比較) into a layer.

train_generator_args = dict(rotation_range=0.2,
                            width_shift_range=0.05,
                            height_shift_range=0.05,
                            shear_range=0.05,
                            zoom_range=0.05,
                            horizontal_flip=True,
                            fill_mode='nearest')
image_transform = ImageTransformation(**train_generator_args)
plot_layer(image_transform, 3, 3)

result.png

It looks like it is working as intended.
To use it in practice, just swap this layer into either of the two options for using Keras preprocessing layers, like this:

model = tf.keras.Sequential([
  # Add the preprocessing layers you created earlier.
  resize_and_rescale,
  image_transform, # <- swap this in
  layers.Conv2D(16, 3, padding='same', activation='relu'),
  layers.MaxPooling2D(),
  # Rest of your model.
])

Or apply the preprocessing layer to the dataset with tf.data:

  # Use data augmentation only on the training set.
  if augment:
    ds = ds.map(lambda x, y: (image_transform(x, training=True), y), 
                num_parallel_calls=AUTOTUNE)
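
The fragment above is only part of the tf.data option; a fuller sketch of how it might sit inside a prepare function (AUTOTUNE and resize_and_rescale are assumed to be defined as in the TF data-augmentation tutorial) could look like this:

AUTOTUNE = tf.data.AUTOTUNE

def prepare(ds, shuffle=False, augment=False):
    # Resize/rescale every dataset (train and eval).
    ds = ds.map(lambda x, y: (resize_and_rescale(x), y),
                num_parallel_calls=AUTOTUNE)
    if shuffle:
        ds = ds.shuffle(1000)
    ds = ds.batch(32)
    # Use data augmentation only on the training set.
    if augment:
        ds = ds.map(lambda x, y: (image_transform(x, training=True), y),
                    num_parallel_calls=AUTOTUNE)
    return ds.prefetch(buffer_size=AUTOTUNE)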

Full code

I've uploaded it below.

Closing thoughts

With TensorFlow, techniques you have been relying on can suddenly become obsolete. PyTorch rarely has that problem, backward compatibility included, so I find it quite convenient.
Still, TensorFlow is the one that's fun to write.
