LoginSignup
2
5

ImageDataGeneratorを利用してCNNへデータを流し込む方法

Last updated at Posted at 2023-03-07

画像ファイルを直接学習データとしてモデルに流し込む際に、ImageDataGeneratorがよく利用されています。同一の画像データをアレンジして学習データとして利用するツールとしても有名です。データの流れや使い方をしっかり理解できていなかったので、調べました。もしよく分からないという方がいらっしゃいましたら、参考になると思います。

スクリーンショット 2023-03-07 195410.jpg

シナリオ

以下の10枚の同一のモヤっと君画像と、10枚の同一のおねーさん画像が手元にある。
これらのデータをImageDataGeneratorで、良い感じにアレンジして、CNNに直接流し込む。
このシナリオで説明します。

image.png

image.png

画像データをラベル別に保管

以下のディレクトリ構成で保管します。

image.png

ImageDataGeneratorで画像データの返還方法を定義

まずはImageDataGeneratorをインポートしてインスタンス化します。
この時、画像データをどのようにアレンジしたいかを定義します。

from keras.preprocessing.image import ImageDataGenerator

# 画像データの返還方法
datagen = ImageDataGenerator(
    rescale=1/255,        # 画像のピクセル値を(0-255)から(0-1)の範囲に正規化
    rotation_range=100,   # ±100°の範囲でランダムに回転
    shear_range=0.5,      # ±20°の範囲で斜めに引き延ばし
    zoom_range=0.5,       # 0.5-1倍の範囲で縮小(結果大きくなる)
    horizontal_flip=True, # ランダムで左右反転
    )

CNNへの画像データの流し込み方の定義

次に、CNNへの画像データの流し込み方を定義します。

train_generator = datagen.flow_from_directory(
    'data',                         # データ格納フォルダ
    target_size=(150, 150),         # 流し込むときのデータのピクセル
    batch_size=6,                   # 一度に流し込むデータ数
    classes=['moyatto', 'onesan'],  # 流し込むデータのフォルダ
    )

出力👇 データ格納フォルダで確認された、画像データ数とラベルの種類数が表示されます。
image.png

次に、どのように流し込まれるか確認しましょう。
nextメソッドで1バッチ分の流れるデータを確認できます。
つまり、以下のスクリプトを実行すると6データが流出します。

train_generator.next()

このデータに含まれる画像を確認します。

train_generator.next()[0].shape

出力👇 6つの画像データで、150x150ピクセル×3色であることが分かります。
image.png

ラベルを確認します。

train_generator.next()[1]

出力👇 One-hotベクトルで出力されます。
image.png

データを可視化してみます。

image, label = train_generator.next()

import matplotlib.pyplot as plt
fig = plt.figure(figsize=(20, 20))
for i in range(6):
    fig.add_subplot(1, 6, i+1).set_title(label[i])
    plt.imshow(image[i])

出力👇 アレンジ画像で、ラベルもモヤっと君は[1,0], おねーさんは[0,1]であることが確認できます。
image.png

モデルの定義

モデルの定義は、一般的な方法で問題ありません。

from keras.models import Sequential
from keras.layers import Convolution2D, MaxPooling2D
from keras.layers import Activation, Dropout, Flatten, Dense

model = Sequential()
model.add(Convolution2D(32, (3, 3), input_shape=(150, 150, 3)))
model.add(Activation('relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Convolution2D(32, (3, 3)))
model.add(Activation('relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Convolution2D(64, (3, 3)))
model.add(Activation('relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Flatten())
model.add(Dense(64))
model.add(Activation('relu'))
model.add(Dropout(0.5))
model.add(Dense(2))
model.add(Activation('sigmoid'))

model.compile(loss='binary_crossentropy',
              optimizer='adam',
              metrics=['accuracy'])

学習

学習時に引数としてtrain_generatorをセットします。

history = model.fit(
        train_generator,
        epochs=10,
        )

出力👇 バッチサイズが6、データ総数が20なので、ステップ数は4となっていることも確認できます。(4ステップ目はデータ数が2つとなります)
image.png

おまけ(私が制作したものの紹介)

これで、ImageDataGeneratorをつかってCNNしてるソースコードは、お菓子たべながらでも読めそうです。画像データを引っ張ったりクルクルするのも、なんか楽しめそう。
あと、AI、機械学習は面白いよって、エンジニアだけじゃなく、ビジネスサイドにも広めていこうね。
👇私が制作したものです。もしよかったら、ご覧ください。
https://www.udemy.com/course/aiforbiz/?referralCode=67BB575DF596D8903B08

2
5
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
5