7
Help us understand the problem. What are the problem?

posted at

転移学習で「呪術廻戦」のキャラの分類やってみた!

今回は、今流行りの呪術廻戦で機械学習やってみました!

画像検索で拾った画像を約90%の精度で見分けることができました!

VGG16を転移学習させたモデルで予測した画像の一覧
jyujyutu1.JPG
jyujyutu2.JPG

ちなみにキャラクターは、

itadori.png

itadori(虎杖悠仁)
呪術廻戦の主人公

hushiguro.png

hushiguro(伏黒恵)
呪術高専の1年生、十種影法術を使う

nobara.png

nobara(釘崎野薔薇)
呪術高専の1年生、芻霊呪法を使う

gojo.png

gojo(五条悟)
呪術高専に所属する教師で、特級の名を冠する現代最強の呪術師

todo.png

todo(東堂葵)
呪術高専京都校に所属する3年生、手を叩いて術式範囲内にある物の位置を入れ替える

Kerasで実装(データセット準備)

今回は、Google Colaboratoryを使って誰でも簡単に実装できるようにしています。

スクリプトを実装する前にデータセットを用意します。

GitHubからデータセットはダウンロードできます。GitHub

今回は、jyujyutu_VGGというフォルダの中にdisplay , test , train , validationを作っています。
jyujyutu_file.JPG

画像は train = 250枚、validation = 50枚、test = 50枚
の内訳で入っています。

displayは冒頭のような画像を表示させるときに使う用で、
中身はtestと同じです!

displayは用意しなくても大丈夫です。

このデータセットは、Googleの画像検索で自作しました。

ライブラリのインポートとモデル構築

jyujyutu_VGG.zipをダウンロードしたら、
Google Driveにアップロードしておきましょう。

# Googleドライブをマウント
from google.colab import drive
drive.mount('/content/drive')

.zipのままGoogle Driveにアップロードしてください!

# Google ColaboratoryでZipファイルを解凍
from zipfile import ZipFile
file_name = '/content/drive/My Drive/jyujyutu_VGG.zip'

with ZipFile(file_name, 'r') as zip:
zip.extractall()

zipファイルのままアップロードして、ここで解凍した方が時間短縮になります。

from keras.models import Model
from keras.layers import Dense, GlobalAveragePooling2D,Input
from keras.applications.vgg16 import VGG16
from keras.preprocessing.image import ImageDataGenerator
from keras.optimizers import SGD
from keras.callbacks import CSVLogger

必要なライブラリをimportします。

n_categories=5
batch_size=32
train_dir='/content/jyujyutu_VGG/train'
validation_dir='/content/jyujyutu_VGG/validation'
file_name='vgg16_jyujyutu_file'

base_model=VGG16(weights='imagenet',include_top=False,
                 input_tensor=Input(shape=(224,224,3)))

今回は、5人のキャラクターを分類するので、n_categories = 5とします。

x=base_model.output
x=GlobalAveragePooling2D()(x)
x=Dense(1024,activation='relu')(x)
prediction=Dense(n_categories,activation='softmax')(x)
model=Model(inputs=base_model.input,outputs=prediction)
#fix weights before VGG16 14layers
for layer in base_model.layers[:15]:
    layer.trainable=False

model.compile(optimizer=SGD(lr=0.0001,momentum=0.9),
              loss='categorical_crossentropy',
              metrics=['accuracy'])

model.summary()

今回の実装では、15層以降のみを学習させています。

最終的な出力は1次元(COVID-19かどうか確率)なので、新たな全結合層を追加しています。

転移学習、VGG16について詳しくは
VGG16を転移学習させて「まどか☆マギカ」のキャラを見分ける
を見てください。

#save model
json_string=model.to_json()
open(file_name+'.json','w').write(json_string)

画像の前処理をして学習

train_datagen=ImageDataGenerator(
    rescale=1.0/255,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True)

validation_datagen=ImageDataGenerator(rescale=1.0/255)

train_generator=train_datagen.flow_from_directory(
    train_dir,
    target_size=(224,224),
    batch_size=batch_size,
    class_mode='categorical',
    shuffle=True
)

validation_generator=validation_datagen.flow_from_directory(
    validation_dir,
    target_size=(224,224),
    batch_size=batch_size,
    class_mode='categorical',
    shuffle=True
)

hist=model.fit_generator(train_generator,
                         epochs=100,
                         verbose=1,
                         validation_data=validation_generator,
                         callbacks=[CSVLogger(file_name+'.csv')])

#save weights
model.save(file_name+'.h5')

ImageDataGeneratorは画像を整形したり、水増しするのに便利です。
読み込むフォルダを与えれば、自動的にそのフォルダの名前をラベルにしてくれます。

ImageDataGeneratorは、デフォルトの機能だけでなく、
関数として定義してあげれば他のData Augmentationも出来るのでめっちゃ便利です!

結果

test画像でモデルを評価します

from keras.models import model_from_json
import matplotlib.pyplot as plt
import numpy as np
import os,random
from keras.preprocessing.image import img_to_array, load_img
from keras.preprocessing.image import ImageDataGenerator
from keras.optimizers import SGD

batch_size=32
file_name='vgg16_jyujyutu_file'
test_dir='jyujyutu_VGG/test'
display_dir='jyujyutu_VGG/display'
label=['gojo','hushiguro','itadori','nobara','todo']

#load model and weights
json_string=open(file_name+'.json').read()
model=model_from_json(json_string)
model.load_weights(file_name+'.h5')

model.compile(optimizer=SGD(lr=0.0001,momentum=0.9),
              loss='categorical_crossentropy',
              metrics=['accuracy'])

#data generate
test_datagen=ImageDataGenerator(rescale=1.0/255)

test_generator=test_datagen.flow_from_directory(
    test_dir,
    target_size=(224,224),
    batch_size=batch_size,
    class_mode='categorical',
    shuffle=True
)

#evaluate model
score=model.evaluate_generator(test_generator)
print('\n test loss:',score[0])
print('\n test_acc:',score[1])

jyujytukekka.png
jyujyutukekka2.png

validation_accuracyは最終的に約90%になりました。
test accuracyが約96%になりました。

訓練画像250枚(各キャラ50枚ずつ)で、Google ColaboratoryのGPUを使えば
かなり短時間で学習できました。

今後の検討

今回は、VGG16を転移学習させて「まどか☆マギカ」のキャラを見分けるの記事を参考に、

簡単に好きなデータセットで実装できることを試し、さらに今の流行りにも乗ってみました。

今後、Mixup や Random Erasing , Cutmixなども試して、精度が向上するかどうか

やってみたいですね!

参考文献

1.VGG16を転移学習させて「まどか☆マギカ」のキャラを見分ける

Why not register and get more from Qiita?
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away
7
Help us understand the problem. What are the problem?