3
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

【備忘録】KerasのImageData Generator

Last updated at Posted at 2022-05-04

はじめに

KerasのImageData Generatorの理解を深めたい。

Iteratorの動きに関しては、下記のポストをご参照ください。

やりたいこと

学習の時、model.fit()を使わず、for文でやりたい。そちらのほうがいろいろ処理の自由度が増すので。

Image Data Generatorの挙動の基礎

# 01.Buildin
import os, time, math, random, pickle

# 02.2nd source
import numpy as np
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator

# Base path
base_path = 'G:\\マイドライブ\\datasets\\mvtec_anomaly_detection\\bottle\\test'

# Generator
train_datagen = ImageDataGenerator(
    featurewise_center=False,  # データセット全体で,入力の平均を0にします
    samplewise_center=False,  # 各サンプルの平均を0にします
    featurewise_std_normalization=False,  # 入力をデータセットの標準偏差で正規化します
    samplewise_std_normalization=False,  # 各入力をその標準偏差で正規化します
    zca_whitening=False,  # ZCA白色化を適用します
    zca_epsilon=1e-06,  # ZCA白色化のイプシロン.デフォルトは1e-6
    rotation_range=0.0,  # 画像をランダムに回転する回転範囲
    width_shift_range=0.0,  # 浮動小数点数(横幅に対する割合).ランダムに水平シフトする範囲
    height_shift_range=0.0,  # 浮動小数点数(縦幅に対する割合).ランダムに垂直シフトする範囲.
    brightness_range=None,  #
    shear_range=0.0,  # 浮動小数点数.シアー強度(反時計回りのシアー角度)
    zoom_range=0.0,
    # 浮動小数点数または[lower,upper].ランダムにズームする範囲.浮動小数点数が与えられた場合,[lower, upper] = [1-zoom_range, 1+zoom_range]です.
    channel_shift_range=0.0,  # 浮動小数点数.ランダムにチャンネルをシフトする範囲.
    fill_mode='nearest',  #: {"constant", "nearest", "reflect", "wrap"}
    cval=0.0,  # 浮動小数点数または整数.fill_mode = "constant"のときに境界周辺で利用される値.
    horizontal_flip=False,  # 水平方向に入力をランダムに反転します.
    vertical_flip=False,  # 垂直方向に入力をランダムに反転します.
    rescale=1. / 255,  # 画素値のリスケーリング係数.デフォルトはNone.Noneか0ならば,適用しない.それ以外であれば,(他の変換を行う前に) 与えられた値をデータに積算する.
    preprocessing_function=None,
    # 各入力に適用される関数です.この関数は他の変更が行われる前に実行されます.この関数は3次元のNumpyテンソルを引数にとり,同じshapeのテンソルを出力するように定義する必要があります.
    data_format=None,  # {"channels_first", "channels_last"}のどちらか
    validation_split=0.0  # 浮動小数点数.検証のために予約しておく画像の割合
)

# Directoryか画像を読み込む

train_generator = train_datagen.flow_from_directory(
    directory=base_path,#ディレクトリへのパス.クラスごとに1つのサブディレクトリを含み,サブディレクトリはPNGかJPGかBMPかPPMかTIF形式の画像を含まなければいけません.
    target_size=(256, 256),#整数のタプル(height, width).
    color_mode='rgb',#"grayscale"か"rbg"の一方.
    classes=None,#クラスサブディレクトリのリスト.(例えば,['dogs', 'cats'])
    class_mode='categorical',# "categorical"か"binary"か"sparse"か"input"か"None""categorical"は2次元のone-hotにエンコード化されたラベル,"binary"は1次元の2値ラベル,"sparse"は1次元の整数ラベル,"input"は入力画像と同じ画像になります(主にオートエンコーダで用いられます).Noneであれば,ラベルを返しません(ジェネレーターは画像のバッチのみ生成するため,model.predict_generator()やmodel.evaluate_generator()などを使う際に有用).class_modeがNoneの場合,正常に動作させるためにはdirectoryのサブディレクトリにデータが存在する必要があることに注意してください.
    batch_size=32,#データのバッチのサイズ
    shuffle=True,#データをシャッフルするかどうか
    seed=None,#シャッフルや変換のためのオプションの乱数シード
    save_to_dir=None,# Noneまたは文字列(デフォルト: None).生成された拡張画像を保存するディレクトリを指定できます
    save_prefix='',#文字列.画像を保存する際にファイル名に付けるプリフィックス
    save_format='png',#"png"または"jpeg"
    follow_links=False,#
    subset=None,#
    interpolation='nearest'#

)


#Check
for data_batch, labels_batch in train_generator:

    print('Data Batch shape:', data_batch.shape)
    print('Labels Batch shape:', labels_batch.shape)

無限ループで、いつになってもデータを渡してくれる。

結果

Found 83 images belonging to 4 classes.
Data Batch shape: (32, 256, 256, 3)
Labels Batch shape: (32, 4)
Data Batch shape: (32, 256, 256, 3)
Labels Batch shape: (32, 4)
Data Batch shape: (19, 256, 256, 3)
Labels Batch shape: (19, 4)
Data Batch shape: (32, 256, 256, 3)
Labels Batch shape: (32, 4)
Data Batch shape: (32, 256, 256, 3)
Labels Batch shape: (32, 4)
Data Batch shape: (19, 256, 256, 3)
Labels Batch shape: (19, 4)
Data Batch shape: (32, 256, 256, 3)
Labels Batch shape: (32, 4)
Data Batch shape: (32, 256, 256, 3)
Labels Batch shape: (32, 4)
Data Batch shape: (19, 256, 256, 3)
Labels Batch shape: (19, 4)
Data Batch shape: (32, 256, 256, 3)
Labels Batch shape: (32, 4)
...

Image Data Generatorを利用し、データを渡すプログラム

動きはわかったので、10 epochsだけ回るプログラムを作成します。

# 01.Buildin
import os, time, math, random, pickle

# 02.2nd source
import numpy as np
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator

# Variables
batch_size = 16
epochs_threshold = 10

# Base path
base_path = 'G:\\マイドライブ\\datasets\\mvtec_anomaly_detection\\bottle\\test'



# Generator
train_datagen = ImageDataGenerator(

    rotation_range=1,  # 画像をランダムに回転する回転範囲
    width_shift_range=0.1,  # 浮動小数点数(横幅に対する割合).ランダムに水平シフトする範囲
    height_shift_range=0.1,  # 浮動小数点数(縦幅に対する割合).ランダムに垂直シフトする範囲.
    rescale=1. / 255,  # 画素値のリスケーリング係数.デフォルトはNone.Noneか0ならば,適用しない.それ以外であれば,(他の変換を行う前に) 与えられた値をデータに積算する.

)

# Directoryか画像を読み込む


train_generator = train_datagen.flow_from_directory(
    directory=base_path,#ディレクトリへのパス.クラスごとに1つのサブディレクトリを含み,サブディレクトリはPNGかJPGかBMPかPPMかTIF形式の画像を含まなければいけません.
    target_size=(256, 256),#整数のタプル(height, width).
    color_mode='rgb',#"grayscale"か"rbg"の一方.
    classes=None,#クラスサブディレクトリのリスト.(例えば,['dogs', 'cats'])
    class_mode='categorical',
    # "categorical"か"binary"か"sparse"か"input"か"None",
    # "categorical"は2次元のone-hotにエンコード化されたラベル,
    # "binary"は1次元の2値ラベル,
    # "sparse"は1次元の整数ラベル,
    # "input"は入力画像と同じ画像になります(主にオートエンコーダで用いられます).
    # Noneであれば,ラベルを返しません(ジェネレーターは画像のバッチのみ生成するため,model.predict_generator()やmodel.evaluate_generator()などを使う際に有用).class_modeがNoneの場合,正常に動作させるためにはdirectoryのサブディレクトリにデータが存在する必要があることに注意してください.
    batch_size=batch_size,#データのバッチのサイズ
    shuffle=False,#データをシャッフルするかどうか

)

#全体のデータ

count_files = 0

for current_dir, sub_dirs, files_list in os.walk(base_path):
    print(current_dir)
    # print(sub_dirs)
    print(files_list)

    count_files = count_files + len(files_list)

print(count_files)

#Calculate the batch number per each epoch

steps_per_epochs = math.ceil(count_files/batch_size)
print('steps_per_epochs:', steps_per_epochs)

now_epoch = 0

for i, (data_batch, labels_batch) in enumerate(train_generator):

    #epochsのカウント
    if i % steps_per_epochs == 0:
        now_epoch = now_epoch + 1
        print('epochs:',now_epoch)

    # 判断
    if now_epoch > epochs_threshold:
        break

    print('Data Batch shape:', data_batch.shape)
    print('Labels Batch shape:', labels_batch.shape)


結果

epochs: 10
Data Batch shape: (16, 256, 256, 3)
Labels Batch shape: (16, 4)
Data Batch shape: (16, 256, 256, 3)
Labels Batch shape: (16, 4)
Data Batch shape: (16, 256, 256, 3)
Labels Batch shape: (16, 4)
Data Batch shape: (16, 256, 256, 3)
Labels Batch shape: (16, 4)
Data Batch shape: (16, 256, 256, 3)
Labels Batch shape: (16, 4)
Data Batch shape: (3, 256, 256, 3)
Labels Batch shape: (3, 4)
epochs: 11

プロセスは終了コード 0 で終了しました

注意点

generatorのclass_modeの引数は、様々は使い方ができることを忘れないようにしましょう。

 class_mode='categorical',
    # "categorical"か"binary"か"sparse"か"input"か"None",
    # "categorical"は2次元のone-hotにエンコード化されたラベル,
    # "binary"は1次元の2値ラベル,
    # "sparse"は1次元の整数ラベル,
    # "input"は入力画像と同じ画像になります(主にオートエンコーダで用いられます).
    # Noneであれば,ラベルを返しません(ジェネレーターは画像のバッチのみ生成するため,model.predict_generator()やmodel.evaluate_generator()などを使う際に有用).class_modeがNoneの場合,正常に動作させるためにはdirectoryのサブディレクトリにデータが存在する必要があることに注意してください.
3
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
3
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?