はじめに
KerasのImage Data GeneratorにCustom関数を実装する方法です。
前回の記事はこちらです。
やり方
Image Data Generatorのクラスを継承しカスタム関数を実装する方法は、下記の記事が有名です。
今回は、別の方法で実装します。今回のほうが短いコードで済みます。
preprocessing_functionを利用します。
実はImageDataGeneratorの引数として提供されている機能です。(あまり知られていない印象ですが)
下記のコードを参照してください。
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator
# Generator
train_datagen = ImageDataGenerator(
preprocessing_function=None
# 各入力に適用される関数です.この関数は他の変更が行われる前に実行されます.この関数は3次元のNumpyテンソルを引数にとり,同じshapeのテンソルを出力するように定義する必要があります.
)
preprocessing_functionの書き方
例えば、Gaussian Blurを行うCustom Functionを定義します。
覚えることは二つです。
- 関数の引数のimgはnumpy arrayです。
- 関数の戻り値のdstもnumpy arrayです。
numpy arrayとしての画像を受け取り、必要な処理を行った後、numpy arrayとしての画像を返す。これだけです。
def our_preprocessing_function(img):
dst = cv2.GaussianBlur(img, (11,11), 3)
return dst
そしたら、上記のgeneratorには下記のように記述します。
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator
# Generator
train_datagen = ImageDataGenerator(
preprocessing_function=our_preprocessing_function #ここの関数の名称を記入
)
テストコード
MVTecADのデータセットを読み取り、Gaussian Blurを行うカスタム関数を付けたImage Data Generatorを作ります。
# 01.Buildin
import os, time, math, random, pickle
# 02.2nd source
import numpy as np
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import cv2
# Base path
base_path = 'E:\\mvtec_anomaly_detection\\carpet\\test'
# Variables
batch_size = 16
epochs_threshold = 10
def our_preprocessing_fun(numpy_img):
dst = cv2.GaussianBlur(numpy_img, (11,11), 3)
return dst
# Generator
train_datagen = ImageDataGenerator(
rotation_range=0, # 画像をランダムに回転する回転範囲
width_shift_range=0.0, # 浮動小数点数(横幅に対する割合).ランダムに水平シフトする範囲
height_shift_range=0.0, # 浮動小数点数(縦幅に対する割合).ランダムに垂直シフトする範囲.
rescale=1. / 255, # 画素値のリスケーリング係数.デフォルトはNone.Noneか0ならば,適用しない.それ以外であれば,(他の変換を行う前に) 与えられた値をデータに積算する.
preprocessing_function= our_preprocessing_function
)
train_generator = train_datagen.flow_from_directory(
directory=base_path,#ディレクトリへのパス.クラスごとに1つのサブディレクトリを含み,サブディレクトリはPNGかJPGかBMPかPPMかTIF形式の画像を含まなければいけません.
target_size=(256, 256),#整数のタプル(height, width).
color_mode='rgb',#"grayscale"か"rbg"の一方.
classes=None,#クラスサブディレクトリのリスト.(例えば,['dogs', 'cats'])
class_mode='categorical',
# "categorical"か"binary"か"sparse"か"input"か"None",
# "categorical"は2次元のone-hotにエンコード化されたラベル,
# "binary"は1次元の2値ラベル,
# "sparse"は1次元の整数ラベル,
# "input"は入力画像と同じ画像になります(主にオートエンコーダで用いられます).
# Noneであれば,ラベルを返しません(ジェネレーターは画像のバッチのみ生成するため,model.predict_generator()やmodel.evaluate_generator()などを使う際に有用).class_modeがNoneの場合,正常に動作させるためにはdirectoryのサブディレクトリにデータが存在する必要があることに注意してください.
batch_size=batch_size,#データのバッチのサイズ
shuffle=False,#データをシャッフルするかどうか
)
#全体のデータ
count_files = 0
for current_dir, sub_dirs, files_list in os.walk(base_path):
print(current_dir)
# print(sub_dirs)
print(files_list)
count_files = count_files + len(files_list)
print(count_files)
#Calculate the batch number per each epoch
steps_per_epochs = math.ceil(count_files/batch_size)
print('steps_per_epochs:', steps_per_epochs)
now_epoch = 0
imgs = []
for i, (data_batch, labels_batch) in enumerate(train_generator):
print('Data Batch shape:', data_batch.shape)
#epochsのカウント
if i % steps_per_epochs == 0:
now_epoch = now_epoch + 1
print('epochs:',now_epoch)
# 判断
if now_epoch > epochs_threshold:
break
#可視化
cv2.imshow('test',data_batch[0] )
cv2.waitKey(1)
print('Data Batch shape:', data_batch.shape)
print('Labels Batch shape:', labels_batch.shape)
画像処理が多段になった場合
その時は、下記のように定義すればOKです。
def transform1(img):
#Applies a transformation such as horizontal flip and returns the image
return cv2.flip(img, 1)
def transform2(img):
#Applies a transformation such as vertical flip and returns the image
return cv2.flip(img, 0)
def transform3(img):
#Applies 180-degree rotation and returns the image
return cv2.rotate(img, cv2.ROTATE_180)
def our_preprocessing_function(img):
#Combines all the transformations
# img = cv2.imread(filename)
img1 = transform1(img)
img2 = transform2(img1)
final_img = transform3(img2)
return final_img
train_datagen = ImageDataGenerator(
preprocessing_function=our_preprocessing_function
)
#参考資料:ここの説明は、カスタム関数の入出力がnumpy arrayではなく、file nameになっており、2022年のTensorflowでは動かないです。
#https://www.analyticsvidhya.com/blog/2020/11/extending-the-imagedatagenerator-keras-tensorflow/