Qiita Teams that are logged in
You are not logged in to any team

Log in to Qiita Team
Community
OrganizationAdvent CalendarQiitadon (β)
Service
Qiita JobsQiita ZineQiita Blog
3
Help us understand the problem. What is going on with this article?
@taichinakabeppu

学習済みモデルの比較と前処理

tensorflow.keras.applications

学習済みモデルの比較をします。
ImageNet で使用した前処理を適用します。

import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import tensorflow.keras.backend as K
from tensorflow.keras import models, layers
from tensorflow.keras.applications import DenseNet201, EfficientNetB7, InceptionResNetV2, InceptionV3, MobileNet, MobileNetV2, NASNetMobile, ResNet101, ResNet50, ResNet50V2, VGG16, VGG19, Xception

学習済みモデル一覧

Module:tf.keras.applications
に公開されているもの。

  • 1. densenet module: DenseNet models for Keras.
  • 2. efficientnet module: EfficientNet models for Keras.
  • 3. inception_resnet_v2 module: Inception-ResNet V2 model for Keras.
  • 4. inception_v3 module: Inception V3 model for Keras.
  • 5. mobilenet module: MobileNet v1 models for Keras.
  • 6. mobilenet_v2 module: MobileNet v2 models for Keras.
  • 7. nasnet module: NASNet-A models for Keras.
  • 8. resnet module: ResNet models for Keras.
  • 9. resnet50 module: Public API for tf.keras.applications.resnet50 namespace.
  • 10. resnet_v2 module: ResNet v2 models for Keras.
  • 11. vgg16 module: VGG16 model for Keras.
  • 12. vgg19 module: VGG19 model for Keras.
  • 13. xception module: Xception V1 model for Keras.

条件

データ

  • train
    • cat : 150
    • dog : 150
  • test
    • cat : 100
    • dog : 100

スクリーンショット 2020-09-25 22.05.13.png

前処理

  • preprocessing_function:tf.keras.applications.[学習済みモデル].preprocess_input
  • target_size=(224, 224):固定
ImageDataGeneratorインスタンス
# train, validation 用
datagen = ImageDataGenerator(
    preprocessing_function=tf.keras.applications.densenet.preprocess_input,
    validation_split=0.3)

# test 用
test_datagen = ImageDataGenerator(
    preprocessing_function=tf.keras.applications.densenet.preprocess_input)
generator生成
train_generator = datagen.flow_from_directory(
    '/content/dog_cat_data/train',
    target_size=(224, 224),
    class_mode='binary',
    subset='training')

val_generator = datagen.flow_from_directory(
    '/content/dog_cat_data/train',
    target_size=(224, 224),
    class_mode='binary',
    subset='validation')

test_generator = test_datagen.flow_from_directory(
    '/content/dog_cat_data/test',
    target_size=(224, 224),
    class_mode='binary')
前処理後の画像可視化
# train_generator 1 バッチ目の先頭 5 件のみ表示
train_images, train_labels = next(train_generator)

def plotImages(images_arr):
    fig, axes = plt.subplots(1, 5, figsize=(20,20))
    axes = axes.flatten()
    for img, ax in zip( images_arr, axes):
        ax.imshow(img)
        ax.axis('off')
    plt.tight_layout()
    plt.show()

print(train_labels[: 5])
plotImages(train_images[: 5])

学習

  • include_top=False で全結合層は除く。
  • weights='imagenet'
  • input_shape=(224, 224, 3)
  • パラメータはすべて凍結
  • batch_size=32:generator のデフォルト。
  • epochs=20
  • pooling='avg':GlobalAveragePooling2D を使用する。
学習済みモデル
base_model = DenseNet201(
    include_top=False,
    weights='imagenet',
    input_shape=(224, 224, 3),
    pooling='avg')
学習
K.clear_session()
tf.random.set_seed(0)

# パラメータ凍結
for layer in base_model.layers:
    layer.trainable = False

# モデルの構築
x = base_model.output
x = layers.Dense(2, activation='softmax')(x)
model = models.Model(base_model.input, x)

# コンパイル
optimizer = tf.keras.optimizers.Adam(lr=1e-4)
model.compile(loss='sparse_categorical_crossentropy',
              optimizer=optimizer,
              metrics=['accuracy'])

# 学習
history = model.fit(
    train_generator,
    epochs=20,
    verbose=1,
    validation_data=val_generator)
評価
model.evaluate(test_generator)

>>> [loss, accuracy]

あとは

  • ImageDataGenerator の preprocess_input
  • base_model のインスタンス化のモデル

を変えて実装していきます。

1. DenseNet201

前処理後の画像
スクリーンショット 2020-09-25 20.34.07.png
結果

[0.24632257223129272, 0.9449999928474426]

2. EfficientNetB7

前処理後の画像
スクリーンショット 2020-09-25 20.40.30.png

結果

[0.2533932626247406, 0.9900000095367432]

3. InceptionResNetV2

前処理後の画像
スクリーンショット 2020-09-25 20.55.27.png

結果

[0.11524547636508942, 0.9750000238418579]

4. InceptionV3

前処理後の画像
スクリーンショット 2020-09-25 21.00.49.png

結果

[0.2021591067314148, 0.9549999833106995]

5. MobileNet

前処理後の画像
スクリーンショット 2020-09-25 21.04.52.png
結果

[0.25840920209884644, 0.9200000166893005]

6. MobileNetV2

前処理後の画像
スクリーンショット 2020-09-25 21.07.50.png
結果

[0.3347647786140442, 0.9049999713897705]

7. NASNetMobile

前処理後の画像
スクリーンショット 2020-09-25 21.11.21.png
結果

[0.20991472899913788, 0.9599999785423279]

8. ResNet101

前処理後の画像
スクリーンショット 2020-09-25 21.16.43.png
結果

[0.164097860455513, 0.949999988079071]

9. ResNet50

前処理語の画像
スクリーンショット 2020-09-25 21.21.32.png
結果

[0.23483356833457947, 0.9150000214576721]

10. ResNet50V2

前処理後の画像
スクリーンショット 2020-09-25 21.25.43.png
結果

[0.2073519378900528, 0.9100000262260437]

11. VGG16

前処理後の画像
スクリーンショット 2020-09-25 21.28.25.png
結果

[1.3361703157424927, 0.7450000047683716]

12. VGG19

前処理後の画像
スクリーンショット 2020-09-25 21.30.34.png
結果

[1.7525851726531982, 0.6399999856948853]

13. Xception

前処理後の画像
スクリーンショット 2020-09-25 21.32.48.png
結果

[0.17928537726402283, 0.9750000238418579]

最後に

EfficientNet 前処理がよくわからないけれど強いですね。
速さはダントツで MobileNet。
VGG 系はもっとエポック数が必要ですね。

とりあえず、学習済みモデルを使う際も前処理は大事だと感じました。

3
Help us understand the problem. What is going on with this article?
Why not register and get more from Qiita?
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away

Comments

No comments
Sign up for free and join this conversation.
Sign Up
If you already have a Qiita account Login
3
Help us understand the problem. What is going on with this article?