はじめに
Keras使用者はImageDataGenaratorを使ってデータ拡張を行うことが一般的なはずだが、TPU環境ではImageDataGeneratorが直接使えない。
そこで、tf.data.Datasetにmapとして適用する関数を各種用意した。ImageDataGenaratorのほぼすべての機能に対応しているが、data_formatは'channels_last'を前提としている。
環境は、TensorFlow(2.3.0)/tf.keras(2.4.0)。
基本機能のみ
とりあえずflipとshiftだけあれば良い場合に使う。
パラメータはImageDataGeneratorと同じなので説明は省略。
ただし、fill_modeは'constant'と'reflect'しか使えない。
import tensorflow as tf
@tf.function
def simple_augmentation(images,
rescale = 1.0,
width_shift_range=0.,
height_shift_range=0.,
horizontal_flip=False,
vertical_flip=False,
cval=0.0,
fill_mode='constant'
):
img_shape = images.shape[-3:]
if rescale != 1.0:
images *= rescale
if horizontal_flip:
images = tf.image.random_flip_left_right(images)
if vertical_flip:
images = tf.image.random_flip_up_down(images)
if width_shift_range != 0 or height_shift_range != 0:
width_shift = int(img_shape[1] * width_shift_range)
height_shift = int(img_shape[0] * height_shift_range)
pad = tf.constant([[0, 0], [height_shift, height_shift], [width_shift, width_shift], [0, 0]])
images = tf.pad(images, pad, mode=fill_mode, constant_values=cval)
images = tf.map_fn(lambda image: tf.image.random_crop(image, size=[img_shape[0], img_shape[1], img_shape[2]]), images)
return images
色変更
ImageDataGeneratorのbrightness_rangeやchannel_shift_rangeのように明るさや色彩を変更したい場合に使う。
ただし、channel_shift_rangeは謎仕様で使いづらいので、別の機能で置き換える。
コントラスト/彩度/色相も変化させられるので、結果的にImageDataGeneratorより高機能になっている。
引数の効果等はtf.imageのオンラインヘルプの該当する項目参照のこと。
tf.imageの内部処理的に色情報は0.0~1.0の範囲でないと機能しないようなので、0~255の画像の場合はrescaleを設定するか、事前に変換したほうが良い。
当然ながら、モノクロ画像ではhueやsaturationは変更できない。
@tf.function
def color_augmentation(images,
rescale = 1.0,
brightness_range=0.0,
hue_range=0.0,
contrast_range=[1.0,1.0],
saturation_range=[1.0,1.0],
clip_range=[0.,1.]
):
if rescale != 1.0:
images *= rescale
if (images.get_shape()[-1] == 1):
hue_range = 0.0
saturation_range = [1.0,1.0]
if hue_range>0.5:
hue_range=0.5
def color_aug(image):
if brightness_range != 0.0:
image = tf.image.random_brightness(image, brightness_range)
if hue_range != 0.0:
image = tf.image.random_hue(image, hue_range)
if saturation_range[0] != saturation_range[1]:
image = tf.image.random_saturation(image,saturation_range[0], saturation_range[1])
if contrast_range[0] != contrast_range[1]:
image = tf.image.random_contrast(image, contrast_range[0], contrast_range[1])
if clip_range[0] != clip_range[1]:
image = tf.clip_by_value(image, clip_range[0], clip_range[1])
return image
images = tf.map_fn(lambda image: color_aug(image), images)
return images
変形
回転や拡大縮小のように画像を変形させたい場合に使う。基本機能として実装した機能も含んでいる。
アフィン変換のためにtensorflow-addonsを使用しているが、記事作成時のGoogle Colab環境ではupgradeが必要になるので、事前に!pip install --upgrade tensorflow-addons
をセルで実行しておく。(upgrade後はtensorflow_addonsは0.11.2以上になるはず)
パラメータはImageDataGeneratorと同じなので説明は省略。
入力画像が正方形(widthとheightが同じ)であることが前提なので注意。
記事作成時でのtensorflow_addonsでtransformがfill_modeを実装していないためpaddingの処理が必要になっているが、将来的にはfill_modeが実装されるはずなので、その場合はもう少し簡単なコードにできる。
@tf.function
def transform_augmentation(images,
rescale=1.0,
zoom_range=[1.,1.],
rotation_range=0.,
shear_range=0.,
interpolation='BILINEAR',
width_shift_range=0., height_shift_range=0.,
horizontal_flip=False, vertical_flip=False,
fill_mode='constant',
cval=0.0
):
if rescale!=1.0:
images = images*rescale
if isinstance(zoom_range,float):
zoom_lower = 1.0-zoom_range
zoom_upper = 1.0+zoom_range
else:
zoom_lower, zoom_upper = zoom_range
if shear_range>45.0:
shear_range = 45
if width_shift_range>0.5:
width_shift_range=0.5
if height_shift_range>0.5:
height_shift_range=0.5
if zoom_upper>1.5:
zoom_upper = 1.5
img_width = images.get_shape()[1]
center = img_width/2
if fill_mode.lower()=='constant':
fill_mode_no=0
else:
fill_mode_no=1
margin = 0
if fill_mode_no == 1 or cval != 0.0:
expand = 1.0
if shear_range != 0.0:
expand += math.tan(shear_range*3.141519/180)
if rotation_range != 0.0:
expand *= math.sqrt(2)
expand *= zoom_upper
margin += int(center*expand - center)
shift_max = int(max((width_shift_range*img_width, height_shift_range*img_width)))
margin+=shift_max
if margin >= img_width:
margin=img_width-1
center = center+float(margin)
def transform(image):
angle = tf.random.uniform(shape=[], minval=-rotation_range, maxval=rotation_range)*3.141519/180
if horizontal_flip:
mirror_x = tf.cast(tf.random.uniform(shape=[], minval=0, maxval=2, dtype=tf.dtypes.int32)*2-1, tf.float32)
else:
mirror_x = 1.0
if vertical_flip:
mirror_y = tf.cast(tf.random.uniform(shape=[], minval=0, maxval=2, dtype=tf.dtypes.int32)*2-1, tf.float32)
else:
mirror_y = 1.0
zoom_x = tf.random.uniform(shape=[], minval=zoom_lower, maxval=zoom_upper)
zoom_y = tf.random.uniform(shape=[], minval=zoom_lower, maxval=zoom_upper)
width_shift = tf.random.uniform(shape=[], minval=-width_shift_range, maxval=width_shift_range)*img_width
height_shift = tf.random.uniform(shape=[], minval=-height_shift_range, maxval=height_shift_range)*img_width
shear_val = tf.random.uniform(shape=[], minval=-shear_range, maxval=shear_range)*3.141519/180
if fill_mode_no == 1:
image = tf.pad(image, tf.constant([[margin, margin], [margin, margin], [0, 0]]), mode="REFLECT")
elif cval != 0.0:
image = tf.pad(image, tf.constant([[margin, margin], [margin, margin], [0, 0]]), mode="CONSTANT", constant_values=cval)
sinval = tf.sin(angle)
cosval = tf.cos(angle)
center_mat = [1.0, 0.0, center, 0.0, 1.0, center, 0.0, 0.0]
shear_mat=[1.0, 0.0, 0.0, tf.tan(shear_val), 1.0, 0.0, 0.0, 0.0]
rotate_mat = [cosval, -sinval, 0.0, sinval, cosval, 0.0, 0.0, 0.0]
zoom_mat = [zoom_x*mirror_x, 0.0, 0.0, 0.0, zoom_y*mirror_y, 0.0, 0.0, 0.0]
center_mat_inv = [1.0, 0.0, width_shift-center, 0.0, 1.0, height_shift-center, 0.0, 0.0]
matrix = [center_mat, shear_mat, rotate_mat, zoom_mat, center_mat_inv]
composed_matrix = tfa.image.transform_ops.compose_transforms(matrix)
image = tfa.image.transform(image, composed_matrix, interpolation=interpolation)
if fill_mode_no == 1 or cval != 0.0:
image = tf.image.resize_with_crop_or_pad(image, img_width, img_width)
return image
images = tf.map_fn(lambda image: transform(image), images)
return images
標準化
NormalizationやWhiteningの機能が必要な場合に使う。
ImageDataGeneratorでは事前にfitが必要なので、こちらではクラス化してインスタンス作成時にfit処理を実行している。
こちらはTrainingデータだけではなくValidationデータにも適用しないといけないので注意。
import numpy as np
import tensorflow as tf
import scipy
from scipy import linalg
class ImageStandardization():
def __init__(self, x,
rescale = 1.0,
samplewise_center=False,
samplewise_std_normalization=False,
featurewise_center=False,
featurewise_std_normalization=False,
zca_whitening=False,
zca_epsilon=1e-6
):
self.samplewise_center = samplewise_center
self.samplewise_std_normalization = samplewise_std_normalization
self.rescale = rescale
if zca_whitening:
featurewise_center = True
featurewise_std_normalization = False
if featurewise_std_normalization:
featurewise_center = True
if self.samplewise_std_normalization:
self.samplewise_center = True
x = np.array(x, dtype=np.float32)
if self.rescale != 1.0:
x *= self.rescale
broadcast_shape = [1, 1, x.shape[3]]
if featurewise_center:
self.mean = np.mean(x, axis=(0, 1, 2))
self.mean = np.reshape(self.mean, broadcast_shape)
x -= self.mean
else:
self.mean = None
if featurewise_std_normalization:
self.std = np.std(x, axis=(0, 1, 2))
self.std = np.reshape(self.std, broadcast_shape)
x /= (self.std + 1e-6)
else:
self.std = None
if zca_whitening:
flat_x = np.reshape(
x, (x.shape[0], x.shape[1] * x.shape[2] * x.shape[3]))
sigma = np.dot(flat_x.T, flat_x) / flat_x.shape[0]
u, s, _ = linalg.svd(sigma)
s_inv = 1. / np.sqrt(s[np.newaxis] + zca_epsilon)
self.principal_components = (u * s_inv).dot(u.T)
self.flatshape = (-1, np.prod(x.shape[-3:]))
self.shape = (-1, x.shape[1], x.shape[2], x.shape[3])
else:
self.principal_components = None
@tf.function
def standardize(self, x):
if self.rescale != 1.0:
x *= self.rescale
if self.samplewise_center:
def center(image):
mean = tf.math.reduce_mean(image, axis=None, keepdims=True)
image -= mean
return image
x = tf.map_fn(lambda image: center(image), x)
if self.samplewise_std_normalization:
def normalize(image):
std = tf.math.reduce_std(image, axis=None, keepdims=True)
image *= 1./(std + 1e-6)
return image
x = tf.map_fn(lambda image: normalize(image), x)
if self.mean is not None:
x -= self.mean
if self.std is not None:
x /= (self.std + 1e-6)
if self.principal_components is not None:
flatx = tf.reshape(x, self.flatshape)
whitex = tf.tensordot(flatx, tf.convert_to_tensor(self.principal_components), axes=1)
x = tf.reshape(whitex, self.shape)
return x
使用例
batch_size設定後にmapで適用することを想定している。
下記のように連続して呼び出すこともできる。rescaleを使う場合は一度だけにすること。
uint8の画像が入力されることを想定。rescaleはtransform_augmentationだけで実施している。
def filter_img(images):
images = tf.cast(images, tf.float32)
images = transform_augmentation(images,
rescale=1.0/255.0,
rotation_range=30,
shear_range=30,
zoom_range=[0.8, 1.2],
horizontal_flip=True,
width_shift_range=0.1,
height_shift_range=0.1,
fill_mode='REFLECT')
images = color_augmentation(images,
hue_range=0.1,
brightness_range=0.2,
saturation_range=[0.8,1.2],
contrast_range=[0.8,1.2])
return images
ds = tf.data.Dataset.from_tensor_slices((x,y)).repeat().batch(batch_size)
ds = ds.map(lambda image, label: (filter_img(image), tf.cast(label, tf.float32)))
以下は、標準化の使用例。
uint8の画像が入力されることを想定している。Standardizationを他のAugumentationと同時利用する場合はrescaleは使用するとちょっと面倒なので、生成時にデータを渡す際に255で割っておき、filter_img内でも同じ操作をしている。単体で使う場合はrescaleを使用しても問題ない。
std = ImageStandardization(x/255.0, zca_whitening=True)
def filter_img(images):
images = tf.cast(images, tf.float32)/255.0 # cast & rescale
images = simple_augmentation(images,
horizontal_flip=True,
width_shift_range=0.1,
height_shift_range=0.1)
images = std.standardize(images)
return images
ds = tf.data.Dataset.from_tensor_slices((x,y)).repeat().batch(batch_size)
ds = ds.map(lambda image, label: (filter_img(image), tf.cast(label, tf.float32)))
実験用のColabノートブックはこちら
CIFAR100出力サンプル(元画像と拡張後の画像を上下に連結している)
参考
TensorFlow & TensorFlow Addons
https://www.tensorflow.org/api_docs/python/tf
https://www.tensorflow.org/addons/api_docs/python/tfa
ImageDataGenerator
https://github.com/keras-team/keras-preprocessing/blob/master/keras_preprocessing/image/image_data_generator.py
アフィン変換関係
https://qiita.com/koshian2/items/c133e2e10c261b8646bf
https://imagingsolution.blog.fc2.com/blog-entry-284.html