0
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

CIFAR-10を使ってkerasの学習モデルを作成し適当な写真を分類させてみる

Posted at

やりたいこと

  • CIFAR-10を使ってkerasのmodelを学習させる。
  • その学習済モデルを保存しておく。
  • 保存したモデルをロードして、ネットから拾ってきた適当な「自動車の写真」を「自動車」と判定できるか確かめる。

参考情報

環境構築

CNN

CIFAR-10について

ラベル「0」: airplane(飛行機)
ラベル「1」: automobile(自動車)
ラベル「2」: bird(鳥)
ラベル「3」: cat(猫)
ラベル「4」: deer(鹿)
ラベル「5」: dog(犬)
ラベル「6」: frog(カエル)
ラベル「7」: horse(馬)
ラベル「8」: ship(船)
ラベル「9」: truck(トラック)

image.png

参考
https://www.atmarkit.co.jp/ait/articles/2006/10/news021.html
https://ymgsapo.com/2019/01/18/cifar10-image-recognition/
https://www.mizuho-ir.co.jp/publication/column/2020/infocomm0318.html

CIFAR-10の読み込みからモデルのコンパイルと学習まで

import keras
from keras.datasets import cifar10
from keras.models import Sequential
from keras.layers import Dense, Dropout, Activation, Flatten
from keras.layers import Conv2D, MaxPooling2D

num_classes = 10
im_rows = 32
im_cols = 32
in_shape = (im_rows, im_cols, 3)

(X_train, y_train), (X_test, y_test) = cifar10.load_data()

# Images
X_train = X_train.astype('float32') / 255
X_test = X_test.astype('float32') / 255

# Labels
y_train = keras.utils.to_categorical(y_train, num_classes)
y_test = keras.utils.to_categorical(y_test, num_classes)

model = Sequential()
model.add(Conv2D(32, (3, 3), padding='same', input_shape=in_shape))
model.add(Activation('relu'))
model.add(Conv2D(32, (3, 3)))
model.add(Activation('relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Dropout(0.25))

model.add(Conv2D(64, (3, 3), padding='same'))
model.add(Activation('relu'))
model.add(Conv2D(64, (3, 3)))
model.add(Activation('relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Dropout(0.25))

model.add(Flatten())
model.add(Dense(512))
model.add(Activation('relu'))
model.add(Dropout(0.5))
model.add(Dense(num_classes))
model.add(Activation('softmax'))

model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
hist = model.fit(X_train, y_train, batch_size=32, epochs=50, verbose=1, validation_data=(X_test, y_test))

save_model_name = 'yamato_cifar10_model'
model.save(save_model_name)

score = model.evaluate(X_test, y_test, verbose=1)
print('正解率=', score[1], 'loss=', score[0])

解説

学習データのロード


(X_train, y_train), (X_test, y_test) = cifar10.load_data()

で学習データをロードしている。

モデルのコンパイルと学習

model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
hist = model.fit(X_train, y_train, batch_size=32, epochs=50, verbose=1, validation_data=(X_test, y_test))

学習済モデルの保存

save_model_name = 'yamato_cifar10_model'
model.save(save_model_name)

にて学習済モデルを保存している。
下記のようにディレクトリで保存される。

 ll yamato_cifar10_model/
total 244
drwxr-xr-x  4 root root     59 Nov 25 08:05 ./
drwxrwxrwx. 4 root root   4096 Nov 25 09:23 ../
drwxr-xr-x  2 root root      6 Nov 25 08:05 assets/
-rw-r--r--  1 root root 240530 Nov 25 08:05 saved_model.pb
drwxr-xr-x  2 root root     66 Nov 25 08:05 variables/

保存した学習済モデルをロードして写真を判定させる

from tensorflow import keras
from keras.models import Sequential
from keras.layers import Dense, Activation
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from PIL import Image

save_model_name = 'yamato_cifar10_model'
model = tf.keras.models.load_model(save_model_name)

LABELS = ["airplane","automobile","bird","cat","deer","dog","frog","horse","ship","truck"]

input_file_name = 'input.jpg'

im = np.array(Image.open(input_file_name))
im = im / 255.0
input_array = np.array([im])
pred = model.predict(input_array)
print("---------- pred  ----------")
print(pred)
print("---------- argmax ----------")
print(np.argmax(pred))

print("---------- result ----------")
for i, label in enumerate(LABELS):
    print('socre of {label:10}: {score:0.3f}' . format(label=label, score=pred[0,i]))

解説

モデルをロードする

save_model_name = 'yamato_cifar10_model'
model = tf.keras.models.load_model(save_model_name)

判定したい画像を指定する

input_file_name = 'input.jpg'
im = np.array(Image.open(input_file_name))

ここで、input.jpg は whdth:32px * height:32px の任意のRGB写真である。
たとえば、つぎのような4枚のjpgを用意する。

image.png

  • 消防車
  • チワワ
  • 戦闘機

それぞれを input.jpg として判定させた結果は下記の通り。
result の socre が最も高いカテゴリに着目する。

消防車

---------- pred  ----------
[[2.92324296e-21 1.14770096e-10 7.31506615e-24 1.61884898e-24
  8.36754993e-26 1.97956345e-23 3.54064243e-24 3.93338139e-22
  8.55078894e-21 1.00000000e+00]]
---------- argmax ----------
9
---------- result ----------
socre of airplane  : 0.000
socre of automobile: 0.000
socre of bird      : 0.000
socre of cat       : 0.000
socre of deer      : 0.000
socre of dog       : 0.000
socre of frog      : 0.000
socre of horse     : 0.000
socre of ship      : 0.000
socre of truck     : 1.000

「トラック」であると判定されている。

---------- pred  ----------
[[7.7227313e-08 7.8912338e-10 1.8037810e-16 4.0607299e-16 1.3727580e-15
  2.2511350e-17 6.2338941e-16 8.0815070e-19 9.9999988e-01 3.6370885e-11]]
---------- argmax ----------
8
---------- result ----------
socre of airplane  : 0.000
socre of automobile: 0.000
socre of bird      : 0.000
socre of cat       : 0.000
socre of deer      : 0.000
socre of dog       : 0.000
socre of frog      : 0.000
socre of horse     : 0.000
socre of ship      : 1.000
socre of truck     : 0.000

「船」であると判定されている。

チワワ

---------- pred  ----------
[[2.6867713e-11 9.4863310e-13 2.7659814e-06 1.2264352e-03 1.2048064e-04
  9.9857211e-01 4.3672014e-08 7.8147066e-05 5.7476631e-13 1.1202570e-10]]
---------- argmax ----------
5
---------- result ----------
socre of airplane  : 0.000
socre of automobile: 0.000
socre of bird      : 0.000
socre of cat       : 0.001
socre of deer      : 0.000
socre of dog       : 0.999
socre of frog      : 0.000
socre of horse     : 0.000
socre of ship      : 0.000
socre of truck     : 0.000

「犬」であると判定されている。

戦闘機

---------- pred  ----------
[[8.9003527e-01 7.4754309e-05 9.2483513e-02 2.0143399e-03 1.3583582e-03
  3.4489814e-04 9.9419327e-03 2.3275054e-04 3.4499075e-03 6.4223343e-05]]
---------- argmax ----------
0
---------- result ----------
socre of airplane  : 0.890
socre of automobile: 0.000
socre of bird      : 0.092
socre of cat       : 0.002
socre of deer      : 0.001
socre of dog       : 0.000
socre of frog      : 0.010
socre of horse     : 0.000
socre of ship      : 0.003
socre of truck     : 0.000

「飛行機」であると判定されている。

10種類に分類できない写真はどうなるか?

りんご

image.png

---------- pred  ----------
[[1.3199445e-03 1.9305049e-02 1.5677975e-01 2.6933338e-02 7.0682232e-05
  3.0774438e-01 4.6424314e-01 7.4237183e-04 1.7272641e-03 2.1134103e-02]]
---------- argmax ----------
6
---------- result ----------
socre of airplane  : 0.001
socre of automobile: 0.019
socre of bird      : 0.157
socre of cat       : 0.027
socre of deer      : 0.000
socre of dog       : 0.308
socre of frog      : 0.464
socre of horse     : 0.001
socre of ship      : 0.002
socre of truck     : 0.021

「frog」と判定された。ただしscoreが高くないので「分類できていない」と言える。

0
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?