LoginSignup
1
1

More than 3 years have passed since last update.

TensorFlow+KerasでCutoutを実装/評価する

Last updated at Posted at 2020-11-29

はじめに

別記事で各種Data Augmentationを実装した際に、TensorFlow Addonsを使うと簡単にCutoutを実装できることに気づいた。
ここではDatasetに対してmap適用する実装とその評価を行う。

環境

  • TensorFlow(2.3.0)
  • tf.keras(2.4.0)
  • tensorflow_addons(0.11.2)

Cutoutとは

画像のデータ拡張手法として提案されたもので、ランダムな座標の矩形を単色で塗りつぶす。
似たような手法にRandom Erasingがあるが、Cutoutのほうがより単純。
TensorFlow Addonsにはtfa.image.cutoutが用意されている。

実装

下記関数をDatasetのmap関数で適用する。

import tensorflow_addons as tfa
@tf.function
def cutout_augmentation(images, cutout_size, cval=0):
    img_shape = images.shape[-3:]
    mask_size = (int(img_shape[0]*cutout_size[0])//2*2, int(img_shape[1]*cutout_size[1])//2*2)
    images = tfa.image.random_cutout(images, mask_size, constant_values=cval)
    return images

cutout_sizeは塗りつぶす矩形のサイズを指定する。元画像からの比率を[H,W]としてfloatで指定する。32x32画像の場合は[0.5,0.5]で16x16の範囲を塗りつぶす。
cvalは塗りつぶす色。0.0~1.0の画像の場合はcval=0.5で灰色となる。

tensorflow-addonsのバージョンが古いと使えないので、Google Colabでは!pip install --upgrade tensorflow-addonsを事前に実行する。

実行結果はこちら。
cutout.png

評価

矩形サイズの設定値の影響や、他のAugmentationとの比較を行う。

  • Google ColabのTPUでCIFAR10の学習を3回行い、ValidationのAccuracyの中間値で比較
  • Vertical Flipのみ使用したものをベースとして、そこに以下のAugmentation処理を追加して実施
    • Cutout(0.4~0.8)
    • Width/Height Shift (range=0.25,fill_mode='constant'および'reflect' )
    • Cutout + Width/Height Shift('reflect')
  • SGD(lr=0.01/momentum=0.9)で100epoch、100~150epochはlr=0.001に変更

結果はこちら。

拡張法 cutout fill_mode Accuracy
Vertical Flip - - 89.15
Cutout 0.4 - 91.80
Cutout 0.5 - 92.31
Cutout 0.6 - 92.50
Cutout 0.7 - 92.63
Cutout 0.8 - 92.53
Shift - constant 92.86
Shift - reflect 93.39
Cutout+Shift 0.5 reflect 94.43

Cutoutのサイズ比較では0.7が最も良好だった。0.7*0.7=0.49なので、ほぼ半分が消えることになる。
元論文ではCIFAR10では16x16つまり0.5が最も良いとされていたので結果にはずれがあった。モデルの違いが影響しているのか?。
同論文ではCIFAR100の場合はもっとサイズを下げたほうが好結果となっている。学習内容によってパラメータ調整する余地があるようだ。
Shiftとの比較になると、Shiftのほうが好結果だった。
全体としては、ShiftとCutoutを混ぜたものが最も好結果となった。Shift(reflect)からは1%の認識率向上があるが、このレベルでは結構大きな差なので同時に適用する価値は十分あると思われる。

実験用コードはこちら

参考

ImageDataGeneratorを拡張しcutoutを実装する
PyTorchでデータ水増し(Data Augmentation)する方法

1
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
1