4
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

【備忘録】ImageData GeneratorにCustom関数を実装

Last updated at Posted at 2022-05-06

はじめに

KerasのImage Data GeneratorにCustom関数を実装する方法です。
前回の記事はこちらです。

やり方

Image Data Generatorのクラスを継承しカスタム関数を実装する方法は、下記の記事が有名です。

今回は、別の方法で実装します。今回のほうが短いコードで済みます。
preprocessing_functionを利用します。
実はImageDataGeneratorの引数として提供されている機能です。(あまり知られていない印象ですが)
下記のコードを参照してください。


import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator

# Generator
train_datagen = ImageDataGenerator(
        preprocessing_function=None
    # 各入力に適用される関数です.この関数は他の変更が行われる前に実行されます.この関数は3次元のNumpyテンソルを引数にとり,同じshapeのテンソルを出力するように定義する必要があります.
   )

preprocessing_functionの書き方

例えば、Gaussian Blurを行うCustom Functionを定義します。

覚えることは二つです。

  1. 関数の引数のimgはnumpy arrayです。
  2. 関数の戻り値のdstもnumpy arrayです。

numpy arrayとしての画像を受け取り、必要な処理を行った後、numpy arrayとしての画像を返す。これだけです。


def our_preprocessing_function(img):
    dst = cv2.GaussianBlur(img, (11,11), 3)
    return dst

そしたら、上記のgeneratorには下記のように記述します。


import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator

# Generator
train_datagen = ImageDataGenerator(
        preprocessing_function=our_preprocessing_function #ここの関数の名称を記入
   
   )

テストコード

MVTecADのデータセットを読み取り、Gaussian Blurを行うカスタム関数を付けたImage Data Generatorを作ります。


# 01.Buildin
import os, time, math, random, pickle

# 02.2nd source
import numpy as np
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import cv2


# Base path
base_path = 'E:\\mvtec_anomaly_detection\\carpet\\test'



# Variables
batch_size = 16
epochs_threshold = 10

def our_preprocessing_fun(numpy_img):
    dst = cv2.GaussianBlur(numpy_img, (11,11), 3)
    return dst

# Generator
train_datagen = ImageDataGenerator(

    rotation_range=0,  # 画像をランダムに回転する回転範囲
    width_shift_range=0.0,  # 浮動小数点数(横幅に対する割合).ランダムに水平シフトする範囲
    height_shift_range=0.0,  # 浮動小数点数(縦幅に対する割合).ランダムに垂直シフトする範囲.
    rescale=1. / 255,  # 画素値のリスケーリング係数.デフォルトはNone.Noneか0ならば,適用しない.それ以外であれば,(他の変換を行う前に) 与えられた値をデータに積算する.
    preprocessing_function= our_preprocessing_function
)


train_generator = train_datagen.flow_from_directory(
    directory=base_path,#ディレクトリへのパス.クラスごとに1つのサブディレクトリを含み,サブディレクトリはPNGかJPGかBMPかPPMかTIF形式の画像を含まなければいけません.
    target_size=(256, 256),#整数のタプル(height, width).
    color_mode='rgb',#"grayscale"か"rbg"の一方.
    classes=None,#クラスサブディレクトリのリスト.(例えば,['dogs', 'cats'])
    class_mode='categorical',
    # "categorical"か"binary"か"sparse"か"input"か"None",
    # "categorical"は2次元のone-hotにエンコード化されたラベル,
    # "binary"は1次元の2値ラベル,
    # "sparse"は1次元の整数ラベル,
    # "input"は入力画像と同じ画像になります(主にオートエンコーダで用いられます).
    # Noneであれば,ラベルを返しません(ジェネレーターは画像のバッチのみ生成するため,model.predict_generator()やmodel.evaluate_generator()などを使う際に有用).class_modeがNoneの場合,正常に動作させるためにはdirectoryのサブディレクトリにデータが存在する必要があることに注意してください.
    batch_size=batch_size,#データのバッチのサイズ
    shuffle=False,#データをシャッフルするかどうか

)


#全体のデータ

count_files = 0

for current_dir, sub_dirs, files_list in os.walk(base_path):
    print(current_dir)
    # print(sub_dirs)
    print(files_list)

    count_files = count_files + len(files_list)

print(count_files)

#Calculate the batch number per each epoch

steps_per_epochs = math.ceil(count_files/batch_size)
print('steps_per_epochs:', steps_per_epochs)

now_epoch = 0
imgs = []

for i, (data_batch, labels_batch) in enumerate(train_generator):
    print('Data Batch shape:', data_batch.shape)

    #epochsのカウント
    if i % steps_per_epochs == 0:
        now_epoch = now_epoch + 1
        print('epochs:',now_epoch)

    # 判断
    if now_epoch > epochs_threshold:
        break

    #可視化
    cv2.imshow('test',data_batch[0] )
    cv2.waitKey(1)

    print('Data Batch shape:', data_batch.shape)
    print('Labels Batch shape:', labels_batch.shape)


画像処理が多段になった場合

その時は、下記のように定義すればOKです。


def transform1(img):
    #Applies a transformation such as horizontal flip and returns the image
    return cv2.flip(img, 1)
def transform2(img):
    #Applies a transformation such as vertical flip and returns the image
    return cv2.flip(img, 0)
def transform3(img):
    #Applies 180-degree rotation and returns the image
    return cv2.rotate(img, cv2.ROTATE_180)

def our_preprocessing_function(img):
    #Combines all the transformations
    # img = cv2.imread(filename)
    img1 = transform1(img)
    img2 = transform2(img1)
    final_img = transform3(img2)
    return final_img


train_datagen = ImageDataGenerator(
        preprocessing_function=our_preprocessing_function
   
   )

#参考資料:ここの説明は、カスタム関数の入出力がnumpy arrayではなく、file nameになっており、2022年のTensorflowでは動かないです。
#https://www.analyticsvidhya.com/blog/2020/11/extending-the-imagedatagenerator-keras-tensorflow/

4
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
4
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?