はじめに
keras.preprocessing.image.ImageDataGenerator
はいろんな前処理が実装されています。ただtensorflowでモダンな書き方をする場合、こいつを使うと遅くなります(参考)。
モダンな書き方というのは、(僕の中だけかもしれないですが)tf.data+前処理をlayerに入れてしまうという書き方のことです。詳しくは以下をご覧ください。
ImageDataGeneratorは便利だけど、Preprocessing Layerとして実装したい…ということで、実装してみました。
実装
ImageDataGeneratorでできる画像の変形(transformation)とpreprocessingでの対応関係は次の通りです。
(各効果についてはこちらを参照)
Transformation | 効果 | Preprocessing Layer |
---|---|---|
Shift | 画像をランダムに縦横にシフト | RandomTranslation |
Shear | 指定した範囲内の角度で画像をランダムに引っ張る | 無し |
Random brightness | 画像の明るさをランダムに変更 | 無し(tf.image には有り) |
Channel Shift | 各チャネルに指定したパラメータで加算、加算前の最小値、最大値でクリッピング | 無し |
Zoom | ランダムに画像を拡大 | RandomZoom |
Flip | ランダムに画像を縦/横に反転 | RandomFlip |
Shear, Random brightness, Channel Shiftについては対応するlayerがないので、自前で実装します。ShiftはRandomTranslationでもいいんですが、0を指定できない(shiftしない場合)ので、自由度をあげるために自分で実装します。
前準備
2.6.0以降はexperimentalがなくても大丈夫です。
書き方は標準実装のpreprocessing layerの書き方に従います。
import math
from typing import Union
import tensorflow as tf
from tensorflow.keras import Sequential, backend, layers
# Compare the version numerically. A plain string comparison sorts '2.10.0'
# before '2.6.0', which would send TF >= 2.10 down the legacy import path.
_TF_MAJOR_MINOR = tuple(int(part) for part in tf.__version__.split('.')[:2])
if _TF_MAJOR_MINOR >= (2, 6):
    # From 2.6.0 on the preprocessing layers no longer need `experimental`.
    from keras.layers.preprocessing.image_preprocessing import transform
    from tensorflow.keras.layers import RandomFlip, RandomRotation, RandomZoom
else:
    # Older releases: layers live under `experimental.preprocessing` and
    # `transform` is imported directly from the bundled keras package.
    from tensorflow.keras.layers.experimental.preprocessing import (
        RandomFlip, RandomRotation, RandomZoom)
    from tensorflow.python.keras.layers.preprocessing.image_preprocessing import \
        transform
# Default seed used by every layer defined in this file.
SEED = 0
# Module-level random generator. Layers constructed with seed == SEED reuse
# this single instance instead of creating their own generator.
rng = tf.random.Generator.from_seed(SEED)
def smart_cond(pred, true_fn=None, false_fn=None, name=None):
    """Dispatch to `true_fn` or `false_fn` depending on `pred`.

    A `tf.Variable` predicate is routed to `tf.cond`; any other predicate is
    handed to TensorFlow's internal `smart_cond` helper.

    Args:
        pred: Predicate selecting the branch.
        true_fn: Callable used when `pred` is true.
        false_fn: Callable used when `pred` is false.
        name: Optional name forwarded to the underlying op.

    Returns:
        Whatever the selected conditional implementation returns.
    """
    branch_kwargs = dict(true_fn=true_fn, false_fn=false_fn, name=name)
    if not isinstance(pred, tf.Variable):
        return tf.__internal__.smart_cond.smart_cond(pred, **branch_kwargs)
    return tf.cond(pred, **branch_kwargs)
プロット用:
import cv2
import matplotlib.pyplot as plt
# NOTE(review): `...` is presumably a placeholder for the path of a sample
# image; replace it before running.
image = cv2.imread(...)
# Convert BGR -> RGB and prepend a leading axis of size 1.
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)[None,...]
def plot_layer(p_layer, nrows, ncols):
    """Show a grid of outputs of `p_layer` applied to the global `image`.

    Args:
        p_layer: Callable layer; invoked as `p_layer(x, training=True)`.
        nrows: Number of grid rows.
        ncols: Number of grid columns.
    """
    plt.figure(figsize=(ncols * 3, nrows * 3))
    # Walk the grid cell by cell; subplot indices are 1-based.
    for cell in range(1, nrows * ncols + 1):
        plt.subplot(nrows, ncols, cell)
        plt.tick_params(
            labelbottom=False, labelleft=False,
            labelright=False, labeltop=False)
        # Without training=True the layer does nothing.
        augmented = p_layer((image / 255.), training=True)
        plt.imshow(augmented[0])
    plt.show()
    plt.close()
Shift
`class RandomShift(layers.Layer):...`
class RandomShift(layers.Layer):
    """Randomly translate a batch of images along the width and height axes.

    The layer only acts when `training` is true; otherwise the inputs are
    returned unchanged. One shift is drawn per call and applied to every
    image of the batch.

    Args:
        width_shift_range: Largest horizontal shift as a fraction of the image
            width. The shift is drawn uniformly from
            [-width_shift_range, width_shift_range]; 0 disables it.
        height_shift_range: Same, for the vertical shift as a fraction of the
            image height.
        fill_mode: Forwarded to `transform`.
        fill_value: Forwarded to `transform`.
        seed: Seed of the random generator. With the default `SEED` the
            module-level generator `rng` is shared instead of creating one.
        **kwargs: Forwarded to `layers.Layer`.
    """

    def __init__(
            self,
            width_shift_range: float,
            height_shift_range: float,
            fill_mode: str = 'nearest',
            fill_value: Union[float, int] = 0.,
            seed: int = SEED,
            **kwargs):
        super(RandomShift, self).__init__(**kwargs)
        self.width_shift_range = width_shift_range
        self.height_shift_range = height_shift_range
        self.fill_mode = fill_mode
        self.fill_value = fill_value
        self.seed = seed
        if seed != SEED:
            self.rng = tf.random.Generator.from_seed(seed)
        else:
            self.rng = rng

    def call(self, inputs, training=None):
        """Shift `inputs` when training, else return them untouched.

        `inputs` is indexed as (batch, height, width, ...): axis 1 is read as
        the height and axis 2 as the width.
        """
        if training is None:
            training = backend.learning_phase()
        original_shape = inputs.shape

        def random_shift():
            input_shape = tf.shape(inputs)
            height = tf.cast(input_shape[1], tf.float32)
            width = tf.cast(input_shape[2], tf.float32)
            # Each offset is scaled by the extent of its own axis.
            dy = self.rng.uniform(
                (), -self.height_shift_range, self.height_shift_range) * height
            dx = self.rng.uniform(
                (), -self.width_shift_range, self.width_shift_range) * width
            # In the projective-transform layout the third entry of the first
            # row is the x (width) offset and that of the second row is the
            # y (height) offset.
            shift_matrix = tf.convert_to_tensor(
                [[1., 0., dx],
                 [0., 1., dy],
                 [0., 0., 1.]])
            # The matrix is used as is; multiplying it by itself would double
            # the translation. `transform` takes the first 8 coefficients.
            transform_matrix = tf.reshape(shift_matrix, [-1])[:8]
            transform_matrix = tf.tile(
                tf.expand_dims(transform_matrix, 0), [input_shape[0], 1])
            return transform(
                inputs, transform_matrix,
                fill_mode=self.fill_mode, fill_value=self.fill_value)

        outputs = smart_cond(
            training, true_fn=random_shift, false_fn=lambda: inputs)
        # Restore the static shape, which the conditional may have dropped.
        outputs.set_shape(original_shape)
        return outputs

    def get_config(self):
        """Return the constructor arguments merged into the base config."""
        config = {
            'width_shift_range': self.width_shift_range,
            'height_shift_range': self.height_shift_range,
            'fill_mode': self.fill_mode,
            'fill_value': self.fill_value,
            'seed': self.seed,
        }
        base_config = super(RandomShift, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
# Demo: draw a 1x4 grid of randomly shifted samples.
plot_layer(RandomShift(width_shift_range=0.2, height_shift_range=0.2), 1, 4)
Shear
`class RandomShear(layers.Layer):...`
class RandomShear(layers.Layer):
    """Randomly shear a batch of images.

    The layer only acts when `training` is true; otherwise the inputs are
    returned unchanged. One shear angle is drawn per call and applied to
    every image of the batch.

    Args:
        shear_range: Largest shear angle in degrees. The angle is drawn
            uniformly from [-shear_range, shear_range].
        fill_mode: Forwarded to `transform`.
        fill_value: Forwarded to `transform`.
        seed: Seed of the random generator. With the default `SEED` the
            module-level generator `rng` is shared instead of creating one.
        **kwargs: Forwarded to `layers.Layer`.
    """

    def __init__(self,
                 shear_range: float,
                 fill_mode: str = 'nearest',
                 fill_value: Union[float, int] = 0.,
                 seed: int = SEED,
                 **kwargs):
        super(RandomShear, self).__init__(**kwargs)
        self.shear_range = shear_range
        self.fill_mode = fill_mode
        self.fill_value = fill_value
        self.seed = seed
        if seed != SEED:
            self.rng = tf.random.Generator.from_seed(seed)
        else:
            self.rng = rng

    def call(self, inputs, training=None):
        """Shear `inputs` when training, else return them untouched."""
        if training is None:
            training = backend.learning_phase()
        original_shape = inputs.shape

        def random_shear():
            # Convert the bound to radians first, then sample inside it.
            max_shear = self.shear_range / 180 * math.pi
            shear = self.rng.uniform((), -max_shear, max_shear)
            shear_matrix = tf.convert_to_tensor(
                [[1., -tf.sin(shear), 0.],
                 [0., tf.cos(shear), 0.],
                 [0., 0., 1.]])
            # The matrix is used as is; multiplying it by itself would not
            # be a shear by the sampled angle. `transform` takes the first
            # 8 coefficients.
            transform_matrix = tf.reshape(shear_matrix, [-1])[:8]
            transform_matrix = tf.tile(
                tf.expand_dims(transform_matrix, 0), [tf.shape(inputs)[0], 1])
            return transform(
                inputs, transform_matrix,
                fill_mode=self.fill_mode, fill_value=self.fill_value)

        outputs = smart_cond(
            training, true_fn=random_shear, false_fn=lambda: inputs)
        # Restore the static shape, which the conditional may have dropped.
        outputs.set_shape(original_shape)
        return outputs

    def get_config(self):
        """Return the constructor arguments merged into the base config."""
        config = {
            'shear_range': self.shear_range,
            'fill_mode': self.fill_mode,
            'fill_value': self.fill_value,
            'seed': self.seed,
        }
        base_config = super(RandomShear, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
# Demo: draw a 1x4 grid of randomly sheared samples.
plot_layer(RandomShear(shear_range=0.3), 1, 4)
Random brightness
tf.image.adjust_brightness
を利用します。
`class RandomBrightness(layers.Layer):...`
class RandomBrightness(layers.Layer):
    """Randomly change the brightness of a batch of images.

    The layer only acts when `training` is true; otherwise the inputs are
    returned unchanged. One delta is drawn per call and passed to
    `tf.image.adjust_brightness`.

    Args:
        brightness: Either a non-negative number `b`, giving the delta range
            (-b, b), or a `(low, high)` pair with `low < high` and
            `high >= 0`.
        seed: Seed of the random generator. With the default `SEED` the
            module-level generator `rng` is shared instead of creating one.
        **kwargs: Forwarded to `layers.Layer`.
    """

    def __init__(self,
                 brightness: Union[float, tuple] = 0.2,
                 seed: int = SEED,
                 **kwargs):
        super(RandomBrightness, self).__init__(**kwargs)
        if isinstance(brightness, (list, tuple)):
            assert brightness[0] < brightness[1] and brightness[1] >= 0.0, 'brightness must be non-negative'
            self.brightness = brightness
        else:
            assert brightness >= 0., 'brightness must be non-negative'
            self.brightness = (-brightness, brightness)
        # Keep the seed itself: the generator object does not expose it, and
        # get_config needs it to rebuild the layer.
        self.seed = seed
        if seed != SEED:
            self.rng = tf.random.Generator.from_seed(seed)
        else:
            self.rng = rng

    def call(self, inputs, training=None):
        """Adjust brightness when training, else return `inputs` untouched."""
        if training is None:
            training = backend.learning_phase()
        original_shape = inputs.shape

        def random_brightness():
            delta = self.rng.uniform(
                (), self.brightness[0], self.brightness[1], dtype=inputs.dtype)
            return tf.image.adjust_brightness(inputs, delta)

        outputs = smart_cond(
            training, true_fn=random_brightness, false_fn=lambda: inputs)
        # Restore the static shape, which the conditional may have dropped.
        outputs.set_shape(original_shape)
        return outputs

    def get_config(self):
        """Return the constructor arguments merged into the base config."""
        config = {
            'brightness': self.brightness,
            'seed': self.seed,
        }
        base_config = super(RandomBrightness, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
# Demo: draw a 1x4 grid of samples with random brightness.
plot_layer(RandomBrightness(brightness=0.2), 1, 4)
Channel Shift
`class RandomChannelShift(layers.Layer):...`
class RandomChannelShift(layers.Layer):
    """Add one random offset to the inputs and clip to their original range.

    The clipping bounds are the minimum and maximum of the inputs taken over
    axes 1 and 2 before the offset is added. The layer only acts when
    `training` is true.

    Args:
        shift_range: The offset is drawn uniformly from
            [-shift_range, shift_range].
        seed: Seed of the random generator. With the default `SEED` the
            module-level generator `rng` is shared instead of creating one.
        **kwargs: Forwarded to `layers.Layer`.
    """

    def __init__(self,
                 shift_range: float,
                 seed: int = SEED,
                 **kwargs):
        super().__init__(**kwargs)
        self.shift_range = shift_range
        self.seed = seed
        # Only a non-default seed gets a dedicated generator.
        self.rng = rng if seed == SEED else tf.random.Generator.from_seed(seed)

    def call(self, inputs, training=None):
        """Shift and clip when training, else return `inputs` untouched."""
        if training is None:
            training = backend.learning_phase()
        static_shape = inputs.shape

        def shift_and_clip():
            reduce_axes = [1, 2]
            lower = tf.reduce_min(inputs, axis=reduce_axes, keepdims=True)
            upper = tf.reduce_max(inputs, axis=reduce_axes, keepdims=True)
            offset = self.rng.uniform(
                (), -self.shift_range, self.shift_range, dtype=inputs.dtype)
            return tf.clip_by_value(inputs + offset, lower, upper)

        shifted = smart_cond(
            training, true_fn=shift_and_clip, false_fn=lambda: inputs)
        shifted.set_shape(static_shape)
        return shifted

    def get_config(self):
        """Return the constructor arguments merged into the base config."""
        merged = dict(super().get_config())
        merged.update({
            'shift_range': self.shift_range,
            'seed': self.seed,
        })
        return merged
# Demo: draw a 1x4 grid of samples with a random channel shift.
plot_layer(RandomChannelShift(shift_range=0.5), 1, 4)
まとめる
引数の名前はImageDataGenerator
と同じにしています。
`class ImageTransformation(layers.Layer):...`
class ImageTransformation(layers.Layer):
    """Chain the random augmentation layers behind one set of arguments.

    The argument names follow `ImageDataGenerator`. The sub-layers are
    created in `build` and applied in this order: rotation, shift, shear,
    zoom, optional flip, channel shift, optional brightness.

    Args:
        rotation_range: Forwarded unchanged to `RandomRotation`.
            NOTE(review): RandomRotation reads its factor as a fraction of a
            full turn rather than degrees — confirm the intended unit.
        width_shift_range: Forwarded to `RandomShift`; must be a number.
        height_shift_range: Forwarded to `RandomShift`; must be a number.
        brightness_range: Forwarded to `RandomBrightness`; `None` skips the
            brightness layer.
        shear_range: Forwarded to `RandomShear`.
        zoom_range: Forwarded unchanged to `RandomZoom`.
        channel_shift_range: Forwarded to `RandomChannelShift`; must be a
            number.
        fill_mode: Fill mode for the geometric layers; a falsy value falls
            back to 'nearest'.
        horizontal_flip: Enable random horizontal flips.
        vertical_flip: Enable random vertical flips.
        fill_value: Fill value for the geometric layers. When `None`, the
            `cval` keyword is used if given, otherwise 0.
        seed: Seed handed to every sub-layer.
        **kwargs: Forwarded to `layers.Layer`; `cval` is consumed here.
    """

    def __init__(self,
                 rotation_range: Union[float, tuple, list] = 0,
                 width_shift_range: float = 0.,
                 height_shift_range: float = 0.,
                 brightness_range: Union[float, tuple] = None,
                 shear_range: float = 0.,
                 zoom_range: Union[float, tuple, list] = 0.,
                 channel_shift_range: float = 0.,
                 fill_mode: str = 'nearest',
                 horizontal_flip: bool = False,
                 vertical_flip: bool = False,
                 fill_value: Union[float, int] = None,
                 seed: int = SEED,
                 **kwargs):
        # Always take `cval` out of kwargs so it never reaches
        # Layer.__init__, which rejects unknown keywords. An explicit
        # `fill_value` takes precedence over it.
        cval = kwargs.pop('cval', None)
        super(ImageTransformation, self).__init__(**kwargs)
        self.rotation_range = rotation_range
        self.width_shift_range = width_shift_range
        self.height_shift_range = height_shift_range
        self.brightness_range = brightness_range
        self.shear_range: float = shear_range
        self.zoom_range = zoom_range
        self.channel_shift_range = channel_shift_range
        self.fill_mode = fill_mode or 'nearest'
        # Yields '', 'horizontal', 'vertical' or 'horizontal_and_vertical'.
        flips = []
        if horizontal_flip:
            flips.append('horizontal')
        if vertical_flip:
            flips.append('vertical')
        self.flip_mode = '_and_'.join(flips)
        if fill_value is None:
            fill_value = cval
        # Fall back to 0 so the sub-layers never receive None.
        self.fill_value = fill_value or 0.
        self.seed = seed

    def build(self, input_shape):
        """Create the sub-layers and wrap them in a `Sequential`."""
        # Named `stages` to avoid shadowing the imported `layers` module.
        stages = [
            RandomRotation(
                self.rotation_range, fill_mode=self.fill_mode,
                fill_value=self.fill_value, seed=self.seed),
            RandomShift(
                self.width_shift_range, self.height_shift_range,
                self.fill_mode, fill_value=self.fill_value, seed=self.seed),
            RandomShear(
                self.shear_range, self.fill_mode,
                fill_value=self.fill_value, seed=self.seed),
            RandomZoom(
                self.zoom_range, fill_mode=self.fill_mode,
                fill_value=self.fill_value, seed=self.seed),
        ]
        if self.flip_mode != '':
            stages.append(RandomFlip(self.flip_mode, seed=self.seed))
        stages.append(
            RandomChannelShift(self.channel_shift_range, seed=self.seed))
        if self.brightness_range is not None:
            stages.append(
                RandomBrightness(self.brightness_range, seed=self.seed))
        self.preprocessing = Sequential(layers=stages)
        super(ImageTransformation, self).build(input_shape)

    def call(self, inputs, training=None):
        """Run the sub-layer pipeline on `inputs`."""
        if training is None:
            training = backend.learning_phase()
        return self.preprocessing(inputs, training=training)

    def get_config(self):
        """Return the constructor arguments merged into the base config."""
        config = {
            'rotation_range': self.rotation_range,
            'width_shift_range': self.width_shift_range,
            'height_shift_range': self.height_shift_range,
            'brightness_range': self.brightness_range,
            'shear_range': self.shear_range,
            'zoom_range': self.zoom_range,
            'channel_shift_range': self.channel_shift_range,
            'fill_mode': self.fill_mode,
            'horizontal_flip': self.flip_mode.startswith('horizontal'),
            'vertical_flip': self.flip_mode.endswith('vertical'),
            'fill_value': self.fill_value,
            'seed': self.seed,
        }
        base_config = super(ImageTransformation, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
テスト
学習最適化のための損失関数とOptimizer & MRI画像を使った比較で使っていた前処理をlayerにしてしまいましょう。
# Augmentation settings, keyed by the ImageTransformation argument names.
train_generator_args = {
    'rotation_range': 0.2,
    'width_shift_range': 0.05,
    'height_shift_range': 0.05,
    'shear_range': 0.05,
    'zoom_range': 0.05,
    'horizontal_flip': True,
    'fill_mode': 'nearest',
}
image_transform = ImageTransformation(**train_generator_args)
# Preview a 3x3 grid of augmented samples.
plot_layer(image_transform, 3, 3)

うまく機能していそうですね。
実際に使う場合は、Keras前処理レイヤーを使用するための2つのオプションの一部を置き換えて次のようにすればokです。
# Assemble the model one layer at a time.
model = tf.keras.Sequential()
# Add the preprocessing layers you created earlier.
model.add(resize_and_rescale)
model.add(image_transform)  # <- this is the line to change
model.add(layers.Conv2D(16, 3, padding='same', activation='relu'))
model.add(layers.MaxPooling2D())
# Rest of your model.
Apply the preprocessing layers to the datasets
# Use data augmentation only on the training set.
# NOTE(review): `augment`, `ds` and `AUTOTUNE` are not defined in this
# snippet; presumably they come from the surrounding data-pipeline code —
# confirm against the caller.
if augment:
    # training=True is passed explicitly to the layer inside the map.
    ds = ds.map(lambda x, y: (image_transform(x, training=True), y),
                num_parallel_calls=AUTOTUNE)
全体のコード
↓に上げておきます。
おわりに
tensorflowは今までつかってた手法が急に時代遅れになることがよくあります。pytorchは互換性も含めてそういうこともないので結構便利だなって思ってます。
まぁ、書くのが楽しいのがtensorflowなんですけどね。