Qiita Teams that are logged in
You are not logged in to any team

Log in to Qiita Team
Community
OrganizationAdvent CalendarQiitadon (β)
Service
Qiita JobsQiita ZineQiita Blog
2
Help us understand the problem. What is going on with this article?
@taichinakabeppu

【画像前処理】tf.keras.preprocessing.image.ImageDataGenerator

ImageDataGenerator とは

Generate batches of tensor image data with real-time data augmentation.
リアルタイムのデータ増強でテンソル画像データのバッチを生成。

このリアルタイムというのは学習実行時のことです。
学習実行時にファイル(データ)を逐次読み込んでくれます。

flow_from_directory(), flow_from_dataframe()を使用することで
学習時にメモリに乗り切らない大量の画像も学習可能になります。

メリット

  • OpenCV, Pillow 不要
  • 画像読み込み、ラベル付け、NumPy 変換、正規化、データ分割を一度にできる
  • 大量のデータを扱う場合の学習時メモリ不足解消
  • データ拡張(水増し)可能
  • 従来の手法と学習時間の差は無し

犬猫分類で実装

  • dog_cat_data
    • train
      • cat:150 枚
      • dog:150 枚
    • test
      • cat:100 枚
      • dog:100 枚
from tensorflow.keras.preprocessing.image import ImageDataGenerator
ImageDataGeneratorクラスのインスタンス
datagen = ImageDataGenerator(
    rescale=1./255,
    validation_split=0.3
    )
  • rescale で正規化
  • validation_split で検証用データセット分割可能

他にも、データ拡張(水増し)もできます。
詳しく知りたい方はこちらKIKAGAKU を参考にしてみてください。

画像データの読み込み
batch_size=10

train_generator = datagen.flow_from_directory(
    '/content/dog_cat_data/train',
    target_size=(224, 224),
    class_mode='binary',
    batch_size=batch_size,
    subset='training',
)

val_generator = datagen.flow_from_directory(
    '/content/dog_cat_data/train',
    target_size=(224, 224),
    class_mode='binary',
    batch_size=batch_size,
    subset='validation'
)

>>> Found 210 images belonging to 2 classes.
>>> Found 90 images belonging to 2 classes.
  • target_size:デフォルト(256, 256)
  • class_modebinary で 2 クラス分類、categorical で多クラス。None でディレクトリから自動でラベル付け。
  • batch_size:デフォルト 32:学習時のバッチサイズにも影響する。
  • subsettraining で学習用、validation で検証用。ImageGenerator クラスのインスタンスで validation_split を指定した場合に可能。

今回は train に犬猫 150 枚ずつ、計 300 枚。
バッチサイズを 10 にしたためこうなる。

確認
# train 画像データ
train_generator.image_shape
>>> (224, 224, 3)

# train ラベル
train_generator.class_indices
>>> {'cat': 0, 'dog': 1}

# train 全てのラベル
train_generator.classes
>>> array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], dtype=int32)

1 つのバッチ分を可視化して確認してみます。

可視化
training_images, training_labels = next(train_generator)

def plotImages(images_arr):
    fig, axes = plt.subplots(1, 5, figsize=(20,20))
    axes = axes.flatten()
    for img, ax in zip( images_arr, axes):
        ax.imshow(img)
        ax.axis('off')
    plt.tight_layout()
    plt.show()

print(training_labels)
plotImages(training_images[:5])

>>> [1. 1. 0. 1. 0. 1. 0. 0. 0. 1.]

スクリーンショット 2020-09-14 1.35.25.png

しっかりラベル付け、シャッフルできています。
データ準備がこれで完了です!

あとはモデルを構築して、学習
ImageDataGenerator の挙動確認のため、精度は気にしない。

学習

fit_generator が tf 2.1 から .fit に統合されました。
そのためこれまで通り .fit で学習できます。

学習
history = model.fit(
    train_generator,
    epochs=10,
    validation_data=val_generator
)
>>> Epoch 1/10
21/21 [==============================] - 1s 52ms/step - loss: 0.9283 - sparse_categorical_accuracy: 0.5333 - val_loss: 0.7018 - val_sparse_categorical_accuracy: 0.5000
.
.
.

1epoch で 21/21 なので ImageDataGenerator のバッチサイズが適用されていることがわかります。

テスト用データセットで検証

test には犬猫 200 枚が用意されています。

test_datagen = ImageDataGenerator(rescale=1./255)

test_generator = datagen.flow_from_directory(
    '/content/dog_cat_data/test',
    target_size=(224, 224),
    class_mode='binary',
    batch_size=batch_size
)

model.evaluate(test_generator)

>>> 20/20 [==============================] - 0s 22ms/step - loss: 0.6969 - sparse_categorical_accuracy: 0.5000
[0.6969258189201355, 0.5]
2
Help us understand the problem. What is going on with this article?
Why not register and get more from Qiita?
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away

Comments

No comments
Sign up for free and join this conversation.
Sign Up
If you already have a Qiita account Login
2
Help us understand the problem. What is going on with this article?