2
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

Xceptionを転移学習させてセーラームーンのキャラを分類する

Posted at

はじめに

CNNを勉強して自分も何か画像認識させてみようと思い,セーラームーンのキャラの画像認識を行いました. 結果として精度が80%のモデルを作成することができました. 初めてQiitaに投稿しました.

実行環境

Windows10 Anaconda 環境
CPU AMD Ryzen 5 3600
メモリ 16GB
GPU NVIDIA GeForce RTX 2060
Python 3.8.5
Numpy 1.19.5
Keras 2.4.3
Tensorflow-gpu 2.4.1
OpenCV2 4.5.1

セーラームーンとは

美少女戦士セーラームーンは武内直子先生の漫画を原作とする作品です. 今回の分類ではセーラー戦士10人の分類を行います. キャラクターは次の通りです. なおタキシード仮面は分類対象に含みません!

セーラームーン/月野うさぎ (Sailormoon)

セーラーマーキュリー/水野亜美 (Sailormercury)

セーラーマーズ/火野レイ (Sailormars)

セーラージュピター/木野まこと (Sailorjupiter)
Sailorjupiter16.jpg

セーラーヴィーナス/愛野美奈子(Sailorvenus)

セーラーちびムーン/ちびうさ(Sailortibimoon)

セーラープルート/冥王せつな(Sailorpluto)

セーラーウラヌス/天王はるか(Sailoruranus)

セーラーネプチューン/海王みちる(Sailorneptune)

セーラーサターン/土萠ほたる(Sailorsaturn)

Xceptionとは?

XceptionはImage Net dataset(ILSVRCのデータセット)でエラー率5.5%のモデルです. 原論文はChollet FrancoisさんのXception: Deep Learning with Depthwise Separable Convolutionsです. このモデルはGoogLeNetと呼ばれるモデルの進化系で, GoogLeNetで採用されたInceptionという手法を改良したモデルです. GoogLeNet, Xceptionの説明はAIsia Solid先生の動画がわかりやすいです.

Xceptionのアーキテクチャは次のようになっています(論文から引用). 特徴はSeparable Convolution layerです. このlayerは空間方向の情報とチャネル方向の情報を完全に分離して畳み込みを行います. 今回は事前にImage Net datasetで学習したXceptionモデルを転移学習させてセーラー戦士の予測を行いたいと思います. 転移学習とはすでに,学習済みのモデルを使用して少ない画像で,短時間でモデルを構築することです.
Xception-Architecture.png

ソースコード

コーディングをして学習を行います. データはGoogle画像検索をスクレイピングをしてとってきました. 訓練用画像が800枚, テスト用画像が200枚あります. 訓練用, テスト用画像はセーラー戦士が同じ割合で入っています. ソースコードとデータはGitHubに置いておきました. GitHub

ライブラリ読み込み

まずは必要なライブラリを読み込みます. ランダムシードもrandom_state=623で固定にしておきます.

ライブラリ読み込み
import numpy as np
import matplotlib.pyplot as plt
import glob
import cv2
import os
import random
import seaborn as sns

random_state = 623

データ読み込み

データを読み込みます. データを読み込むためにload_file関数を定義します. load_file関数は引数として読み込むファイル一覧filesとセーラー戦士のリストSoldiersを受け取り, 読み込んだ画像のnumpy配列と正解ラベルを返します. 元々の画像サイズはバラバラですが, すべて128x128に統一しています.

load_file関数
def load_file(files,Soldiers):
    """ファイルを読み込んでnumpy.arrayに変換する関数
    
    Args:
    files : 読み込むファイルのリスト
    Soldiers : セーラー戦士のリスト
    
    Returns:
    np.array(file_list) : 読み込んだ画像のnumpy.array
    y : 正解ラベル
    """
    file_list = []
    y=[]
    for file in files:
        # 正解ラベルをリストに代入
        for i,soldier in enumerate(Soldiers):
            if soldier in file:
                y.append(i)
                
        # load img
        f = cv2.imread(file)
        f = cv2.resize(f, dsize=(128, 128))
        f = cv2.cvtColor(f, cv2.COLOR_BGR2RGB)
        
        file_list.append(f)
    return np.array(file_list),y

作成したload_file関数を使用して画像を読み込みます. 読み込んだ画像は255で割ることで数値を0~1に収まるようにしています. また正解ラベルはOne-hotエンコーディングを行っています. そして訓練用データを訓練用(train)と検証用(valid)に分割します. train:valid=8:2になるように分割をしました.

画像読み込み
from tensorflow.keras.utils import to_categorical
from sklearn.model_selection import train_test_split

Soldiers = ["Sailorjupiter","Sailormars","Sailormercury","Sailormoon","Sailorneptune",
            "Sailorpluto","Sailorsaturn","Sailortibimoon","Sailoruranus","Sailorvenus"]

train_path = "./train/"
test_path = "./test/"
# load data
train,y_train = load_file(glob.glob(train_path+"/*"),Soldiers)
test,y_test = load_file(glob.glob(test_path+"/*"),Soldiers)

train = train.reshape((train.shape[0], 128, 128, 3)) / 255
x_test = test.reshape((test.shape[0], 128, 128, 3)) / 255
y_train = to_categorical(y_train)
y_test = to_categorical(y_test)

x_train, x_valid, y_train, y_valid = train_test_split(
    train, y_train,test_size=0.2,random_state=random_state)

train,test,validのshapeを確認してここまでの実行がうまくいっているか確認します. 正しく実行できていれば次のようになります.

データが読み込めたか確認
# train,valid,testのshapeを確認
print(x_train.shape)
print(y_train.shape)
print(x_valid.shape)
print(y_valid.shape)
print(x_test.shape)
print(y_test.shape)
実行結果
(640, 128, 128, 3)
(640, 10)
(160, 128, 128, 3)
(160, 10)
(200, 128, 128, 3)
(200, 10)

転移学習モデルの作成

Xceptionをベースモデルとして学習モデルを作成します. 今回は出力層に1024ユニットの全結合層を追加しました. 損失関数はカテゴリカルクロスエントロピー, 最適化アルゴリズムはAdamを採用しました.

転移学習モデルの作成
import keras
import tensorflow as tf
from IPython.display import SVG
from tensorflow.python.keras.utils.vis_utils import model_to_dot
from keras.models import Model
from keras.layers import Dense, GlobalAveragePooling2D,Input

# base model(Xception)を読み込み
# include_top : 出力層側を含むかどうか
base_model =  tf.keras.applications.xception.Xception(weights='imagenet',include_top=False,input_tensor=Input(shape=(224,224,3)))

# base modelに出力層を追加
x = base_model.output
x = GlobalAveragePooling2D()(x)
x = Dense(1024,activation="relu")(x)
prediction=Dense(10,activation='softmax')(x)
model=Model(inputs=base_model.input,outputs=prediction)
    
model.compile(
    loss=keras.losses.categorical_crossentropy,
    optimizer="adam",
    metrics=["accuracy"]
)

画像の前処理と学習

画像の前処理と学習を行います. 画像の前処理にはImageDataGeneratorを使用して画像の整形と水増しを行いました. ImageDataGeneratorは次を参考にしました. ImageDataGenerator
学習はEarlyStoppingを70epochとして300epoch実行します.

画像の前処理と学習
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.callbacks import EarlyStopping

train_datagen=ImageDataGenerator(
    width_shift_range=0.4, # 左にシフトする割合
    height_shift_range=0.4, # 垂直方向にシフトする割合
    rotation_range=30, # ランダムに30度の範囲で回転
    shear_range=0.2, # 反時計回りの回転
    zoom_range=0.2, # 拡大
    horizontal_flip=True) # 水平方向に反転

validation_datagen=ImageDataGenerator()

train_datagen.fit(x_train)
validation_datagen.fit(x_valid)

train_generator=train_datagen.flow(
    x_train,y_train,
    batch_size=70,
    shuffle=True
)

validation_generator=validation_datagen.flow(
    x_valid,y_valid,
    batch_size=70,
    shuffle=True
)

early_stopping = EarlyStopping(patience=70, verbose=1)

hist=model.fit_generator(train_generator,
                         epochs=300,
                         verbose=1,
                         validation_data=validation_generator,
                         callbacks=[early_stopping])

実行結果

損失関数と正解率のプロット

損失関数と正解率のプロットは次のようになりました. 正解率は最終的に70~80%で安定しました. 学習は222epochで早期打ち切りました.

損失関数
loss.png

正解率
acc.png

テストデータによる評価

テストデータによる評価を行います. 0~1の実数で各クラスに属する確率を予測するpredictionと, 予測したクラスの番号0~9を格納するprediction_labelを出力するようにしました. 実行結果は正解率が0.8, つまり80%になりました.

テストデータによる評価
# modelの評価
test_datagen=ImageDataGenerator()
test_datagen.fit(x_test)
test_generator=test_datagen.flow(
    x_test,y_test,
    batch_size=100,
    shuffle=False
)

scores = model.evaluate(test_generator)
prediction = model.predict(test_generator)
prediction_label = np.argmax(model.predict(test_generator),axis=1)
実行結果
2/2 [==============================] - 0s 96ms/step - loss: 0.8476 - accuracy: 0.8000

分類結果の例

5x5の分類結果の例を表示してみます. 実行結果からほとんどの予測は正しいことがわかります.

分類結果の例を表示
files=glob.glob(test_path+"/*")
file_list = []
y_label = []
for file in files:
    # 正解ラベルをリストに代入
    for i,soldier in enumerate(Soldiers):
        if soldier in file:
            y_label.append(i)
    # load img
    f = cv2.imread(file)
    f = cv2.cvtColor(f, cv2.COLOR_BGR2RGB)
    file_list.append(f)
test = np.array(file_list)

r =np.array(random.sample(range(0,200),25)).reshape(5,5)

fig = plt.figure(figsize=(15, 15))
k=1
for i in range(5):
    for j in range(5):
        plt.subplot(5,5,k)
        plt.imshow(test[r[i][j]])
        k+=1
        plt.xticks([]),plt.yticks([])
        plt.title("正解ラベル "+Soldiers[y_label[r[i][j]]]+"\n予測ラベル "+Soldiers[prediction_label[r[i][j]]])
        
plt.subplots_adjust(left=0.2, right=0.95, bottom=0.1, top=0.95)
plt.show()

exsample.png

Confusion Matrix

Confusion Matrix(混同行列)を表示してみます. Confusion Matrixから分類の間違いの傾向を確認することができます. 視覚的に分かりやすいように, ヒートマップでConfusion Matrixを表示するようにしました.

Confusion_Matrix
from sklearn.metrics import confusion_matrix

cm = confusion_matrix(y_label, prediction_label)
sns.heatmap(cm)
plt.title("confusion matrix")

実行結果です. 縦軸が正解ラベル, 横軸が予測ラベルを表しており, 対角線上の要素が明るい色になっているほど予測が正しいことを意味しています. 実行結果の番号はSoldiersリストのインデックスと対応しています. セーラー戦士と番号の対応の表を載せておきます. 実行結果を見るとほとんどの要素は対角線上のみ明るい色になっています. 対角線上以外で明るい色になっているところを見ると(セーラーマーズ,セーラーサターン),(セーラーマーズ,セーラーネプチューン), (セーラームーン,セーラーヴィーナス), (セーラージュピター,セーラーウラヌス)のペアであることがわかります. (セーラーマーズ,セーラーサターン),(セーラーマーズ,セーラーネプチューン), (セーラームーン,セーラーヴィーナス)のペアは髪の色が似ていることから分類が間違っていることが考えられます. (セーラージュピター,セーラーウラヌス)のペアを間違える理由はいまいちわかりません. しいて言えば男の子っぽいということでしょうか...?

cm.png

番号 セーラー戦士
0 セーラージュピター
1 セーラーマーズ
2 セーラーマーキュリー
3 セーラームーン
4 セーラーネプチューン
5 セーラープルート
6 セーラーサターン
7 セーラーちびムーン
8 セーラーウラヌス
9 セーラーヴィーナス

まとめ

転移学習モデルを利用して精度80%のモデルを構築することができました. 最初のイメージではここまでの精度がでるとは考えていませんでした. また分類結果の分析もわかりやすいもので勉強になりました.

参考

2
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?