LoginSignup
2
6

More than 3 years have passed since last update.

TensorFlowにおける画像の正規化

Last updated at Posted at 2021-01-17

この記事について

kerasのImageDataGeneratorで画素値をリスケーリングする際にrescaleを用いていたのですが、[-1 ~ 1]の範囲で正規化したいときに使用できなかったで調べたことをメモとして残しておきます。

引数 - rescale

一般的に

from tensorflow.keras.preprocessing.image import ImageDataGenerator, load_img, img_to_array
datagen = ImageDataGenerator(rescale=1./255)

のように使うことが多いかと思います。元の画素値とrescaleに与えられた値の積をとります。[0 ~ 255]を[0 ~ 1]に正規化する際に便利です。ImageDataGeneratorでは、他の変換(augmentation)を行う前にこの関数が適用されます。

引数 - preprocessing_function

まずは、使い方から確認してみます。

from tensorflow.keras.preprocessing.image import ImageDataGenerator, load_img, img_to_array
datagen = ImageDataGenerator(
    preprocessing_function=preprocess_input,
    )

preprocess_fucntion : 各入力に適用される関数。この関数は他の変更が行われる前に実行される。この関数は3次元のNumpyテンソルを引数にとり、同じshapeのテンソルを出力するように定義する必要あり。

tf.keras.applicationsのモデルを使用する際は、モデルによって[0 ~ 1] or [-1 ~ 1]など正規化の値の範囲が決まっています。[-1 ~ 1]の範囲なのに[0 ~ 1]で学習させても学習は上手くいくことが多いです(あまりよろしくはないが、、、、)。

そこで、各モデルにpreprocess_inputという便利な関数があります。
以下は、mobilenetv2の例です。

from tensorflow.keras.applications.mobilenet_v2 import preprocess_input
datagen = ImageDataGenerator(
    preprocessing_function=preprocess_input,
    )

このように記述することで、モデルに適した値に正規化してくれます。
単体でも使用できます。変換前と変換後の値が以下のコードで確認可能です。

import numpy as np
import tensorflow as tf 
from tensorflow.keras.applications.mobilenet_v2 import preprocess_input
from tensorflow.keras.preprocessing import image

model = tf.keras.applications.MobileNetV2()

img_path = 'XXXXX.jpg'
img = image.load_img(img_path, target_size=(224, 224))
x = image.img_to_array(img)
x = np.expand_dims(x, axis=0)
print(x.max())
print(x.min())
x = preprocess_input(x)
print(x.max())
print(x.min())

tf.keras.layers.experimental.preprocessing.Rescaling()

別の正規化の方法として、tf.keras.layers.experimental.preprocessing.Rescaling()があります。
引数として、scaleoffsetがあります。scaleの値は、入力値にかける。offsetの値は、足すイメージです。

[0, 255]->[-1, 1]のときは、Rescaling(scale=1./127.5, offset=-1)でokayです。



base_model = tf.keras.applications.MobileNetV2(input_shape=IMG_SHAPE,
                                          include_top=False,
                                          weights='imagenet')

model = tf.keras.models.Sequential([
        tf.keras.layers.experimental.preprocessing.Rescaling(scale=1./127.5, offset=-1),
        base_model,
        tf.keras.layers.GlobalAveragePooling2D(),
        tf.keras.layers.Dense(2, activation='softmax'),
        ])


終わりに

普段、TFを使用しないので知らないことが多いなと思いました。あとは、作成したモデルに対して画像を入力際はどの値の範囲に正規化すべきかシビアになりましょう。正規化の範囲が違う場合は、予測がかなり変わってしまいます。webのアプリケーションとかにAIモデルを組み込む際は特に注意が必要ですね。

参考文献

2
6
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
6