2
5

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

第3回 橋本環奈分類機を作成してみた

Last updated at Posted at 2023-10-05

やったこと

画像分類の勉強のため、入力された画像が橋本環奈さんであるかを判定するAIを作ろうと考えた。
今回は、収集した橋本環奈さんの画像500枚と人気女優さんの画像500枚を使って、EfficientNet V2 B0で画像分類モデルを作成した。

結果

テストデータ100枚(クラスの内訳は1:1)での結果Accuracy: 0.92と予想よりいい精度が出た。

データフォルダの構成

dataset
├─train
│ ├─NotHashimoto
│ │ └─***.jpg
│ └─Hashimoto
├─test
│ ├─NotHashimoto
│ └─Hashimoto
└─val
  ├─NotHashimoto
  └─Hashimoto

ソースコード

今回は収集した画像を使ってEfficientNet V2をファインチューニングしてみた。また画像の水増しは行っていないが、将来的に水増しによる精度向上がどの程度あるかも確認したく、ImageDataGeneratorを使っている。

Classifier.py
import tensorflow as tf
from tensorflow.keras.applications.efficientnet_v2 import EfficientNetV2B0
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Dense, GlobalAveragePooling2D, Dropout, BatchNormalization, Activation
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from keras.callbacks import EarlyStopping, ModelCheckpoint

def build_model():
    """
    mobilenet v2を転移学習するためのモデルを作成する関数
    """
    inputs = Input(shape=(224, 224, 3))
    base_model = EfficientNetV2B0(weights='imagenet', include_top=False, input_tensor=inputs)
    x = base_model(inputs)
    x = GlobalAveragePooling2D()(x)
    x = BatchNormalization()(x)
    outputs = Dense(1, activation='sigmoid')(x)
    model = Model(inputs, outputs)
    return model

def setup_generator():
    """
    ImageDataGeneratorを使って、train_generator, val_generator, test_generatorを作成
    """
    # 画像があるフォルダ
    train_data_dir = './dataset/train'
    valid_data_dir = './dataset/val'
    test_data_dir = './dataset/test'

    # データ拡張の設定
    train_datagen = ImageDataGenerator(
        rescale=1.0
    )
    valid_datagen = ImageDataGenerator(rescale=1.0)
    test_datagen = ImageDataGenerator(rescale=1.0)

    # バッチサイズと画像サイズを指定します。
    batch_size = 32
    image_size = (224, 224)

    # トレーニングデータジェネレータとバリデーションデータジェネレータを作成します。
    train_generator = train_datagen.flow_from_directory(
        train_data_dir,
        target_size=image_size,
        batch_size=batch_size,
        class_mode='binary'
    )

    valid_generator = valid_datagen.flow_from_directory(
        valid_data_dir,
        target_size=image_size,
        batch_size=batch_size,
        class_mode='binary'
    )

    test_generator = test_datagen.flow_from_directory(
        test_data_dir,
        target_size=image_size,
        batch_size=batch_size,
        class_mode='binary'
    )

    return train_generator, valid_generator, test_generator

# generatorを作成
train_generator, valid_generator, test_generator = setup_generator()
# モデル作成
model = build_model()
# モデルのコンパイル
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
# callbackの設定
earlystop = EarlyStopping(patience=10)
savefolder = './BestModel'
checkpoint = ModelCheckpoint(savefolder, monitor='val_loss', verbose=1, save_best_only=True)

# モデルの学習
epochs = 100
history = model.fit(
    train_generator,
    steps_per_epoch=len(train_generator),
    epochs=epochs,
    validation_data=valid_generator,
    validation_steps=len(valid_generator),
    callbacks=[earlystop, checkpoint]
)
# 評価
model.evaluate(test_generator)

まとめ

そこそこの精度が出ているので、一旦このモデルでWebアプリを作成してみる。

関連記事

2
5
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
5

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?