11
6

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

学習済みモデルの比較と前処理

Posted at

tensorflow.keras.applications

学習済みモデルの比較をします。
ImageNet で使用した前処理を適用します。

import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import tensorflow.keras.backend as K
from tensorflow.keras import models, layers
from tensorflow.keras.applications import DenseNet201, EfficientNetB7, InceptionResNetV2, InceptionV3, MobileNet, MobileNetV2, NASNetMobile, ResNet101, ResNet50, ResNet50V2, VGG16, VGG19, Xception
```

## 学習済みモデル一覧
[Module:tf.keras.applications](https://www.tensorflow.org/api_docs/python/tf/keras/applications?hl=ja&authuser=0)
に公開されているもの

- **1. densenet** module: DenseNet models for Keras.
- **2. efficientnet** module: EfficientNet models for Keras.
- **3. inception_resnet_v2** module: Inception-ResNet V2 model for Keras.
- **4. inception_v3** module: Inception V3 model for Keras.
- **5. mobilenet** module: MobileNet v1 models for Keras.
- **6. mobilenet_v2** module: MobileNet v2 models for Keras.
- **7. nasnet** module: NASNet-A models for Keras.
- **8. resnet** module: ResNet models for Keras.
- **9. resnet50** module: Public API for tf.keras.applications.resnet50 namespace.
- **10. resnet_v2** module: ResNet v2 models for Keras.
- **11. vgg16** module: VGG16 model for Keras.
- **12. vgg19** module: VGG19 model for Keras.
- **13. xception** module: Xception V1 model for Keras.


## 条件
### データ
- train
    - cat : 150
    - dog : 150
- test
    - cat : 100
    - dog : 100

![スクリーンショット 2020-09-25 22.05.13.png](https://qiita-image-store.s3.ap-northeast-1.amazonaws.com/0/605917/2ae2c070-c2cc-b117-ef99-19e1f9ec4644.png)




### 前処理

- `preprocessing_function`:`tf.keras.applications.[学習済みモデル].preprocess_input`
- `target_size=(224, 224)`固定

```python:ImageDataGeneratorインスタンス
# train, validation 用
datagen = ImageDataGenerator(
    preprocessing_function=tf.keras.applications.densenet.preprocess_input,
    validation_split=0.3)

# test 用
test_datagen = ImageDataGenerator(
    preprocessing_function=tf.keras.applications.densenet.preprocess_input)
```


```python:generator生成
train_generator = datagen.flow_from_directory(
    '/content/dog_cat_data/train',
    target_size=(224, 224),
    class_mode='binary',
    subset='training')

val_generator = datagen.flow_from_directory(
    '/content/dog_cat_data/train',
    target_size=(224, 224),
    class_mode='binary',
    subset='validation')

test_generator = test_datagen.flow_from_directory(
    '/content/dog_cat_data/test',
    target_size=(224, 224),
    class_mode='binary')
```

```python:前処理後の画像可視化
# train_generator 1 バッチ目の先頭 5 件のみ表示
train_images, train_labels = next(train_generator)

def plotImages(images_arr):
    fig, axes = plt.subplots(1, 5, figsize=(20,20))
    axes = axes.flatten()
    for img, ax in zip( images_arr, axes):
        ax.imshow(img)
        ax.axis('off')
    plt.tight_layout()
    plt.show()

print(train_labels[: 5])
plotImages(train_images[: 5])
```
### 学習
- `include_top=False` で全結合層は除く
- `weights='imagenet'`
- `input_shape=(224, 224, 3)`
- パラメータはすべて凍結
- `batch_size=32`generator のデフォルト
- `epochs=20`
- `pooling='avg'`GlobalAveragePooling2D を使用する

```python:学習済みモデル
base_model = DenseNet201(
    include_top=False,
    weights='imagenet',
    input_shape=(224, 224, 3),
    pooling='avg')
```
```python:学習
K.clear_session()
tf.random.set_seed(0)

# パラメータ凍結
for layer in base_model.layers:
    layer.trainable = False

# モデルの構築
x = base_model.output
x = layers.Dense(2, activation='softmax')(x)
model = models.Model(base_model.input, x)

# コンパイル
optimizer = tf.keras.optimizers.Adam(lr=1e-4)
model.compile(loss='sparse_categorical_crossentropy',
              optimizer=optimizer,
              metrics=['accuracy'])

# 学習
history = model.fit(
    train_generator,
    epochs=20,
    verbose=1,
    validation_data=val_generator)
```
```python:評価
model.evaluate(test_generator)

>>> [loss, accuracy]
```

あとは 

- ImageDataGenerator  `preprocess_input`
- `base_model` のインスタンス化のモデル

を変えて実装していきます

## 1. DenseNet201
前処理後の画像
![スクリーンショット 2020-09-25 20.34.07.png](https://qiita-image-store.s3.ap-northeast-1.amazonaws.com/0/605917/c81b8da0-d769-2823-c3ae-01a3d678ef30.png)
結果

```python
[0.24632257223129272, 0.9449999928474426]
```

## 2. EfficientNetB7
前処理後の画像
![スクリーンショット 2020-09-25 20.40.30.png](https://qiita-image-store.s3.ap-northeast-1.amazonaws.com/0/605917/4bd586e4-c825-08f5-a4ba-cbf5278a857f.png)

結果

```python
[0.2533932626247406, 0.9900000095367432]
```

## 3. InceptionResNetV2
前処理後の画像
![スクリーンショット 2020-09-25 20.55.27.png](https://qiita-image-store.s3.ap-northeast-1.amazonaws.com/0/605917/81eb2dab-1e1b-b0c0-20cb-3a8dbcf4ae5a.png)

結果

```python
[0.11524547636508942, 0.9750000238418579]
```

## 4. InceptionV3
前処理後の画像
![スクリーンショット 2020-09-25 21.00.49.png](https://qiita-image-store.s3.ap-northeast-1.amazonaws.com/0/605917/0d1bce19-64ba-9714-d631-3a5ba7e541fc.png)

結果

```python
[0.2021591067314148, 0.9549999833106995]
```

## 5. MobileNet
前処理後の画像
![スクリーンショット 2020-09-25 21.04.52.png](https://qiita-image-store.s3.ap-northeast-1.amazonaws.com/0/605917/624044ba-ba03-1b8d-5121-ddbf3cdb85c9.png)
結果

```python
[0.25840920209884644, 0.9200000166893005]
```

## 6. MobileNetV2
前処理後の画像
![スクリーンショット 2020-09-25 21.07.50.png](https://qiita-image-store.s3.ap-northeast-1.amazonaws.com/0/605917/09794fdf-cf3d-afc5-b802-9907cf028684.png)
結果

```python
[0.3347647786140442, 0.9049999713897705]
```

## 7. NASNetMobile
前処理後の画像
![スクリーンショット 2020-09-25 21.11.21.png](https://qiita-image-store.s3.ap-northeast-1.amazonaws.com/0/605917/0a2e8a6b-cd70-56e1-0be7-e1ba67902b0e.png)
結果

```python
[0.20991472899913788, 0.9599999785423279]
```

## 8. ResNet101
前処理後の画像
![スクリーンショット 2020-09-25 21.16.43.png](https://qiita-image-store.s3.ap-northeast-1.amazonaws.com/0/605917/a2525bfd-bace-4cd7-58db-696a508db4ea.png)
結果

```python
[0.164097860455513, 0.949999988079071]
```

## 9. ResNet50
前処理語の画像
![スクリーンショット 2020-09-25 21.21.32.png](https://qiita-image-store.s3.ap-northeast-1.amazonaws.com/0/605917/74be15bc-d468-6b18-ca0d-87507063abf2.png)
結果

```python
[0.23483356833457947, 0.9150000214576721]
```

## 10. ResNet50V2
前処理後の画像
![スクリーンショット 2020-09-25 21.25.43.png](https://qiita-image-store.s3.ap-northeast-1.amazonaws.com/0/605917/9b00a808-a441-5673-263c-1ae30ef6f1b2.png)
結果

```python
[0.2073519378900528, 0.9100000262260437]
```

## 11. VGG16
前処理後の画像
![スクリーンショット 2020-09-25 21.28.25.png](https://qiita-image-store.s3.ap-northeast-1.amazonaws.com/0/605917/c12cd8aa-f956-c16c-c684-af4b202133ef.png)
結果

```python
[1.3361703157424927, 0.7450000047683716]
```

## 12. VGG19
前処理後の画像
![スクリーンショット 2020-09-25 21.30.34.png](https://qiita-image-store.s3.ap-northeast-1.amazonaws.com/0/605917/8f204028-866a-23aa-3723-88e4fbc4afe4.png)
結果

```python
[1.7525851726531982, 0.6399999856948853]
```

## 13. Xception
前処理後の画像
![スクリーンショット 2020-09-25 21.32.48.png](https://qiita-image-store.s3.ap-northeast-1.amazonaws.com/0/605917/1ff50c30-ff92-7c2d-86eb-56bd8edd9725.png)
結果

```python
[0.17928537726402283, 0.9750000238418579]
```
## 最後に
EfficientNet 前処理がよくわからないけれど強いですね
速さはダントツで MobileNet
VGG 系はもっとエポック数が必要ですね

とりあえず学習済みモデルを使う際も**前処理は大事**だと感じました
11
6
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
11
6

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?