12
9

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

【画像前処理】tf.keras.preprocessing.image.ImageDataGenerator

Last updated at Posted at 2020-09-13

ImageDataGenerator とは

Generate batches of tensor image data with real-time data augmentation.
リアルタイムのデータ増強でテンソル画像データのバッチを生成。

このリアルタイムというのは学習実行時のことです。
学習実行時にファイル(データ)を逐次読み込んでくれます。

flow_from_directory(), flow_from_dataframe()を使用することで
学習時にメモリに乗り切らない大量の画像も学習可能になります。

メリット

  • OpenCV, Pillow 不要
  • 画像読み込み、ラベル付け、NumPy 変換、正規化、データ分割を一度にできる
  • 大量のデータを扱う場合の学習時メモリ不足解消
  • データ拡張(水増し)可能
  • 従来の手法と学習時間の差は無し

犬猫分類で実装

  • dog_cat_data
    • train
      • cat:150 枚
      • dog:150 枚
    • test
      • cat:100 枚
      • dog:100 枚
from tensorflow.keras.preprocessing.image import ImageDataGenerator
ImageDataGeneratorクラスのインスタンス
datagen = ImageDataGenerator(
    rescale=1./255,
    validation_split=0.3
    )
  • rescale で正規化
  • validation_split で検証用データセット分割可能

他にも、データ拡張(水増し)もできます。
詳しく知りたい方はこちらKIKAGAKU を参考にしてみてください。

画像データの読み込み
batch_size=10

train_generator = datagen.flow_from_directory(
    '/content/dog_cat_data/train',
    target_size=(224, 224),
    class_mode='binary',
    batch_size=batch_size,
    subset='training',
)

val_generator = datagen.flow_from_directory(
    '/content/dog_cat_data/train',
    target_size=(224, 224),
    class_mode='binary',
    batch_size=batch_size,
    subset='validation'
)

>>> Found 210 images belonging to 2 classes.
>>> Found 90 images belonging to 2 classes.
  • target_size:デフォルト(256, 256)
  • class_modebinary で 2 クラス分類、categorical で多クラス。None でディレクトリから自動でラベル付け。
  • batch_size:デフォルト 32:学習時のバッチサイズにも影響する。
  • subsettraining で学習用、validation で検証用。ImageGenerator クラスのインスタンスで validation_split を指定した場合に可能。

今回は train に犬猫 150 枚ずつ、計 300 枚。
バッチサイズを 10 にしたためこうなる。

確認
# train 画像データ
train_generator.image_shape
>>> (224, 224, 3)

# train ラベル
train_generator.class_indices
>>> {'cat': 0, 'dog': 1}

# train 全てのラベル
train_generator.classes
>>> array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], dtype=int32)

1 つのバッチ分を可視化して確認してみます。

可視化
training_images, training_labels = next(train_generator)

def plotImages(images_arr):
    fig, axes = plt.subplots(1, 5, figsize=(20,20))
    axes = axes.flatten()
    for img, ax in zip( images_arr, axes):
        ax.imshow(img)
        ax.axis('off')
    plt.tight_layout()
    plt.show()

print(training_labels)
plotImages(training_images[:5])

>>> [1. 1. 0. 1. 0. 1. 0. 0. 0. 1.]

スクリーンショット 2020-09-14 1.35.25.png

しっかりラベル付け、シャッフルできています。
データ準備がこれで完了です!

あとはモデルを構築して、学習
ImageDataGenerator の挙動確認のため、精度は気にしない。

学習

fit_generator が tf 2.1 から .fit に統合されました。
そのためこれまで通り .fit で学習できます。

学習
history = model.fit(
    train_generator,
    epochs=10,
    validation_data=val_generator
)
>>> Epoch 1/10
21/21 [==============================] - 1s 52ms/step - loss: 0.9283 - sparse_categorical_accuracy: 0.5333 - val_loss: 0.7018 - val_sparse_categorical_accuracy: 0.5000
.
.
.

1epoch で 21/21 なので ImageDataGenerator のバッチサイズが適用されていることがわかります。

テスト用データセットで検証

test には犬猫 200 枚が用意されています。

test_datagen = ImageDataGenerator(rescale=1./255)

test_generator = datagen.flow_from_directory(
    '/content/dog_cat_data/test',
    target_size=(224, 224),
    class_mode='binary',
    batch_size=batch_size
)

model.evaluate(test_generator)

>>> 20/20 [==============================] - 0s 22ms/step - loss: 0.6969 - sparse_categorical_accuracy: 0.5000
[0.6969258189201355, 0.5]
12
9
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
12
9

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?