LoginSignup
4
2

More than 3 years have passed since last update.

Google ColabのTPU環境でImageDataGeneratorライクなデータ拡張を実現する

Last updated at Posted at 2020-11-26

はじめに

Keras使用者はImageDataGenaratorを使ってデータ拡張を行うことが一般的なはずだが、TPU環境ではImageDataGeneratorが直接使えない。
そこで、tf.data.Datasetにmapとして適用する関数を各種用意した。ImageDataGenaratorのほぼすべての機能に対応しているが、data_formatは'channels_last'を前提としている。
環境は、TensorFlow(2.3.0)/tf.keras(2.4.0)。

基本機能のみ

とりあえずflipとshiftだけあれば良い場合に使う。
パラメータはImageDataGeneratorと同じなので説明は省略。
ただし、fill_modeは'constant'と'reflect'しか使えない。

import tensorflow as tf
@tf.function
def simple_augmentation(images, 
                    rescale = 1.0,
                    width_shift_range=0.,
                    height_shift_range=0.,
                    horizontal_flip=False, 
                    vertical_flip=False,
                    cval=0.0,
                    fill_mode='constant'
    ):

    img_shape = images.shape[-3:]

    if rescale != 1.0:
        images *= rescale
    if horizontal_flip:
        images = tf.image.random_flip_left_right(images)
    if vertical_flip:
        images = tf.image.random_flip_up_down(images)

    if width_shift_range != 0 or height_shift_range != 0:
        width_shift = int(img_shape[1] * width_shift_range)
        height_shift = int(img_shape[0] * height_shift_range)
        pad = tf.constant([[0, 0], [height_shift, height_shift], [width_shift, width_shift], [0, 0]])
        images = tf.pad(images, pad, mode=fill_mode, constant_values=cval)
        images = tf.map_fn(lambda image: tf.image.random_crop(image, size=[img_shape[0], img_shape[1], img_shape[2]]), images)

    return images

色変更

ImageDataGeneratorのbrightness_rangeやchannel_shift_rangeのように明るさや色彩を変更したい場合に使う。
ただし、channel_shift_rangeは謎仕様で使いづらいので、別の機能で置き換える。
コントラスト/彩度/色相も変化させられるので、結果的にImageDataGeneratorより高機能になっている。
引数の効果等はtf.imageのオンラインヘルプの該当する項目参照のこと。
tf.imageの内部処理的に色情報は0.0~1.0の範囲でないと機能しないようなので、0~255の画像の場合はrescaleを設定するか、事前に変換したほうが良い。
当然ながら、モノクロ画像ではhueやsaturationは変更できない。

@tf.function
def color_augmentation(images,
        rescale = 1.0,
        brightness_range=0.0,
        hue_range=0.0,
        contrast_range=[1.0,1.0],
        saturation_range=[1.0,1.0],
        clip_range=[0.,1.]
    ):
    if rescale != 1.0:
        images *= rescale

    if (images.get_shape()[-1] == 1):
        hue_range = 0.0
        saturation_range = [1.0,1.0]

    if hue_range>0.5:
        hue_range=0.5

    def color_aug(image):
        if brightness_range != 0.0:
            image = tf.image.random_brightness(image, brightness_range)
        if hue_range != 0.0:
            image = tf.image.random_hue(image, hue_range)
        if saturation_range[0] != saturation_range[1]:
            image = tf.image.random_saturation(image,saturation_range[0], saturation_range[1])
        if contrast_range[0] != contrast_range[1]:
            image = tf.image.random_contrast(image, contrast_range[0], contrast_range[1])
        if clip_range[0] != clip_range[1]:
            image = tf.clip_by_value(image, clip_range[0], clip_range[1])
        return image

    images = tf.map_fn(lambda image: color_aug(image), images)
    return images

変形

回転や拡大縮小のように画像を変形させたい場合に使う。基本機能として実装した機能も含んでいる。
アフィン変換のためにtensorflow-addonsを使用しているが、記事作成時のGoogle Colab環境ではupgradeが必要になるので、事前に!pip install --upgrade tensorflow-addonsをセルで実行しておく。(upgrade後はtensorflow_addonsは0.11.2以上になるはず)

パラメータはImageDataGeneratorと同じなので説明は省略。
入力画像が正方形(widthとheightが同じ)であることが前提なので注意。
記事作成時でのtensorflow_addonsでtransformがfill_modeを実装していないためpaddingの処理が必要になっているが、将来的にはfill_modeが実装されるはずなので、その場合はもう少し簡単なコードにできる。

@tf.function
def transform_augmentation(images, 
        rescale=1.0, 
        zoom_range=[1.,1.],
        rotation_range=0.,
        shear_range=0.,
        interpolation='BILINEAR',
        width_shift_range=0., height_shift_range=0.,
        horizontal_flip=False, vertical_flip=False,
        fill_mode='constant',
        cval=0.0
        ):

    if rescale!=1.0:
        images = images*rescale

    if isinstance(zoom_range,float):
        zoom_lower = 1.0-zoom_range
        zoom_upper = 1.0+zoom_range
    else:
        zoom_lower, zoom_upper = zoom_range

    if shear_range>45.0:
        shear_range = 45
    if width_shift_range>0.5:
        width_shift_range=0.5
    if height_shift_range>0.5:
        height_shift_range=0.5
    if zoom_upper>1.5:
        zoom_upper = 1.5

    img_width = images.get_shape()[1]
    center = img_width/2

    if fill_mode.lower()=='constant':
        fill_mode_no=0
    else:
        fill_mode_no=1

    margin = 0
    if fill_mode_no == 1 or cval != 0.0:
        expand = 1.0
        if shear_range != 0.0:
            expand += math.tan(shear_range*3.141519/180)
        if rotation_range != 0.0:
            expand *= math.sqrt(2)
        expand *= zoom_upper
        margin += int(center*expand - center)
        shift_max = int(max((width_shift_range*img_width, height_shift_range*img_width)))
        margin+=shift_max
        if margin >= img_width:
            margin=img_width-1

        center = center+float(margin)
    def transform(image):
        angle = tf.random.uniform(shape=[], minval=-rotation_range, maxval=rotation_range)*3.141519/180
        if horizontal_flip:
            mirror_x = tf.cast(tf.random.uniform(shape=[], minval=0, maxval=2, dtype=tf.dtypes.int32)*2-1, tf.float32)
        else:
            mirror_x = 1.0
        if vertical_flip:
            mirror_y = tf.cast(tf.random.uniform(shape=[], minval=0, maxval=2, dtype=tf.dtypes.int32)*2-1, tf.float32)
        else:
            mirror_y = 1.0
        zoom_x = tf.random.uniform(shape=[], minval=zoom_lower, maxval=zoom_upper)
        zoom_y = tf.random.uniform(shape=[], minval=zoom_lower, maxval=zoom_upper)
        width_shift = tf.random.uniform(shape=[], minval=-width_shift_range, maxval=width_shift_range)*img_width
        height_shift = tf.random.uniform(shape=[], minval=-height_shift_range, maxval=height_shift_range)*img_width
        shear_val = tf.random.uniform(shape=[], minval=-shear_range, maxval=shear_range)*3.141519/180

        if fill_mode_no == 1:
            image = tf.pad(image, tf.constant([[margin, margin], [margin, margin], [0, 0]]), mode="REFLECT")
        elif  cval != 0.0:
            image = tf.pad(image, tf.constant([[margin, margin], [margin, margin], [0, 0]]), mode="CONSTANT", constant_values=cval)

        sinval = tf.sin(angle)
        cosval = tf.cos(angle)
        center_mat = [1.0, 0.0, center, 0.0, 1.0, center, 0.0, 0.0]
        shear_mat=[1.0, 0.0, 0.0, tf.tan(shear_val), 1.0, 0.0, 0.0, 0.0]
        rotate_mat = [cosval, -sinval, 0.0, sinval, cosval, 0.0, 0.0, 0.0]
        zoom_mat = [zoom_x*mirror_x, 0.0, 0.0, 0.0, zoom_y*mirror_y, 0.0, 0.0, 0.0]
        center_mat_inv = [1.0, 0.0, width_shift-center, 0.0, 1.0, height_shift-center, 0.0, 0.0]

        matrix = [center_mat, shear_mat, rotate_mat, zoom_mat, center_mat_inv]
        composed_matrix = tfa.image.transform_ops.compose_transforms(matrix)
        image = tfa.image.transform(image, composed_matrix, interpolation=interpolation)

        if fill_mode_no == 1 or cval != 0.0:
            image = tf.image.resize_with_crop_or_pad(image, img_width, img_width)

        return image

    images = tf.map_fn(lambda image: transform(image), images)

    return images

標準化

NormalizationやWhiteningの機能が必要な場合に使う。
ImageDataGeneratorでは事前にfitが必要なので、こちらではクラス化してインスタンス作成時にfit処理を実行している。
こちらはTrainingデータだけではなくValidationデータにも適用しないといけないので注意。

import numpy as np
import tensorflow as tf
import scipy
from scipy import linalg
class ImageStandardization():
    def __init__(self, x, 
            rescale = 1.0,
            samplewise_center=False, 
            samplewise_std_normalization=False, 
            featurewise_center=False, 
            featurewise_std_normalization=False, 
            zca_whitening=False,
            zca_epsilon=1e-6
        ):
        self.samplewise_center = samplewise_center
        self.samplewise_std_normalization = samplewise_std_normalization
        self.rescale = rescale

        if zca_whitening:
            featurewise_center = True
            featurewise_std_normalization = False
        if featurewise_std_normalization:
            featurewise_center = True
        if self.samplewise_std_normalization:
            self.samplewise_center = True

        x = np.array(x, dtype=np.float32)
        if self.rescale != 1.0:
            x *= self.rescale

        broadcast_shape = [1, 1, x.shape[3]]
        if featurewise_center:
            self.mean = np.mean(x, axis=(0, 1, 2))
            self.mean = np.reshape(self.mean, broadcast_shape)
            x -= self.mean
        else:
            self.mean = None

        if featurewise_std_normalization:
            self.std = np.std(x, axis=(0, 1, 2))
            self.std = np.reshape(self.std, broadcast_shape)
            x /= (self.std + 1e-6)
        else:
            self.std = None

        if zca_whitening:
            flat_x = np.reshape(
                x, (x.shape[0], x.shape[1] * x.shape[2] * x.shape[3]))
            sigma = np.dot(flat_x.T, flat_x) / flat_x.shape[0]
            u, s, _ = linalg.svd(sigma)
            s_inv = 1. / np.sqrt(s[np.newaxis] + zca_epsilon)
            self.principal_components = (u * s_inv).dot(u.T)
            self.flatshape = (-1, np.prod(x.shape[-3:]))
            self.shape = (-1, x.shape[1], x.shape[2], x.shape[3])
        else:
            self.principal_components = None

    @tf.function
    def standardize(self, x):
        if self.rescale != 1.0:
            x *= self.rescale
        if self.samplewise_center:
            def center(image):
                mean = tf.math.reduce_mean(image, axis=None, keepdims=True)
                image -= mean
                return image
            x = tf.map_fn(lambda image: center(image), x)
        if self.samplewise_std_normalization:
            def normalize(image):
                std = tf.math.reduce_std(image, axis=None, keepdims=True)
                image *= 1./(std + 1e-6)
                return image
            x = tf.map_fn(lambda image: normalize(image), x)

        if self.mean is not None:
            x -= self.mean

        if self.std is not None:
            x /= (self.std + 1e-6)

        if self.principal_components is not None:
            flatx = tf.reshape(x, self.flatshape)
            whitex = tf.tensordot(flatx, tf.convert_to_tensor(self.principal_components), axes=1)
            x = tf.reshape(whitex, self.shape)
        return x

使用例

batch_size設定後にmapで適用することを想定している。
下記のように連続して呼び出すこともできる。rescaleを使う場合は一度だけにすること。
uint8の画像が入力されることを想定。rescaleはtransform_augmentationだけで実施している。

def filter_img(images):
    images = tf.cast(images, tf.float32)
    images = transform_augmentation(images,
                        rescale=1.0/255.0,
                        rotation_range=30,
                        shear_range=30,
                        zoom_range=[0.8, 1.2],
                        horizontal_flip=True, 
                        width_shift_range=0.1,
                        height_shift_range=0.1,
                        fill_mode='REFLECT')
    images = color_augmentation(images,
                     hue_range=0.1,
                     brightness_range=0.2,
                     saturation_range=[0.8,1.2],
                     contrast_range=[0.8,1.2])
    return images
ds = tf.data.Dataset.from_tensor_slices((x,y)).repeat().batch(batch_size)
ds = ds.map(lambda image, label: (filter_img(image), tf.cast(label, tf.float32)))

以下は、標準化の使用例。
uint8の画像が入力されることを想定している。Standardizationを他のAugumentationと同時利用する場合はrescaleは使用するとちょっと面倒なので、生成時にデータを渡す際に255で割っておき、filter_img内でも同じ操作をしている。単体で使う場合はrescaleを使用しても問題ない。

std = ImageStandardization(x/255.0, zca_whitening=True) 
def filter_img(images):
    images = tf.cast(images, tf.float32)/255.0  # cast & rescale
    images = simple_augmentation(images,
                        horizontal_flip=True, 
                        width_shift_range=0.1,
                        height_shift_range=0.1)
    images = std.standardize(images)
    return images

ds = tf.data.Dataset.from_tensor_slices((x,y)).repeat().batch(batch_size)
ds = ds.map(lambda image, label: (filter_img(image), tf.cast(label, tf.float32)))

実験用のColabノートブックはこちら

CIFAR100出力サンプル(元画像と拡張後の画像を上下に連結している)
augment.png

参考

TensorFlow & TensorFlow Addons
https://www.tensorflow.org/api_docs/python/tf
https://www.tensorflow.org/addons/api_docs/python/tfa

ImageDataGenerator
https://github.com/keras-team/keras-preprocessing/blob/master/keras_preprocessing/image/image_data_generator.py

アフィン変換関係
https://qiita.com/koshian2/items/c133e2e10c261b8646bf
https://imagingsolution.blog.fc2.com/blog-entry-284.html

4
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
4
2