2
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

KerasでCIFAR-10の画像分類をしてみた

Last updated at Posted at 2022-05-28

はじめに

深層学習の勉強として取り組んだMNISTの画像分類の次のステップとして、CIFAR-10の画像分類に取り組んでみた。以前MNISTに取り組んだ際に使用したライブラリ、Kerasを使って同じようにCNNを実装してみる。

CIFAR-10

 CIFAR-10データセット(Canadian Institute For Advanced Research)は10種類の画像からなるデータセットで、MNISTと同様に画像認識を目的としたディープラーニング/機械学習の研究や初心者向けチュートリアルで使われている。データセットは5万枚の訓練データと1万枚のテストデータで構成され、中身の画像は24bitのRGBフルカラー画像で、0~255のピクセル値で表される。サイズは幅32×高さ32

CIFAR-10には、以下の10種類が用意されている。

ラベル「0」: airplane(飛行機)
ラベル「1」: automobile(自動車)
ラベル「2」: bird(鳥)
ラベル「3」: cat(猫)
ラベル「4」: deer(鹿)
ラベル「5」: dog(犬)
ラベル「6」: frog(カエル)
ラベル「7」: horse(馬)
ラベル「8」: ship(船)
ラベル「9」: truck(トラック)
image.png

環境

google colaboratory
Python 3.7.13
Keras 2.8.0

実装

1.ライブラリをインポート

必要なライブラリを読み込む。

import
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline 

from PIL import Image

from keras.models import Sequential, load_model
from keras.layers import Conv2D, MaxPooling2D, Dropout, Flatten, Dense
from keras.layers import Activation, BatchNormalization
from keras.utils import np_utils
from keras.callbacks import EarlyStopping
from keras.datasets import cifar10

2.画像データ読み込み

keras.datasetsからCIFAR-10の画像データを読み込む。

keras.datasets.cifar10
(X_train, y_train), (X_test, y_test) = cifar10.load_data()

# 確認
X_train.shape, y_train.shape, X_test.shape, y_test.shape
# (50000, 32, 32, 3) (50000, 1) (10000, 32, 32, 3) (10000, 1)

読み込んだ画像データはcolabのディレクトリ直下に保存する。

あとで使うので、クラスを格納したdictやlistを作っておく。

class
# ラベル「0」: airplane(飛行機)
# ラベル「1」: automobile(自動車)
# ラベル「2」: bird(鳥)
# ラベル「3」: cat(猫)
# ラベル「4」: deer(鹿)
# ラベル「5」: dog(犬)
# ラベル「6」: frog(カエル)
# ラベル「7」: horse(馬)
# ラベル「8」: ship(船)
# ラベル「9」: truck(トラック)

label_dict = {0:"飛行機", 1:"自動車", 2:"", 3:"", 4:"鹿", 5:"", 6:"カエル", 7:"", 8:"", 9:"トラック"}

# 正解ラベルの中身の種類(0~9)をlistに格納
class_list = np.unique(y_train).tolist()
num_class = len(class_list)

# 読み込んだデータを保存する
xy = (X_train, y_train, X_test, y_test)
np.save("./cifar10.npy", xy)

class_list
# [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]

3.データの読み込み

データの読み込みから学習、可視化までを行う関数をそれぞれ定義する。

load_dataでは、データを読み込んで整形する。

load_data
def load_data():
  X_train, y_train, X_test, y_test = np.load("./cifar10.npy", allow_pickle=True)
  # 正規化
  X_train = X_train.astype("float32")/255.0
  X_test = X_test.astype("float32")/255.0
  # yを2値配列に変換(one-hot)
  y_train = np_utils.to_categorical(y_train, num_class)
  y_test = np_utils.to_categorical(y_test, num_class)

  return X_train, y_train, X_test, y_test

4.学習

trainでは、CNNの中身を構築して学習させる。

今回は畳み込みとプーリング、ドロップアウトを行うブロックを3つ通した後、平滑化して全結合層に渡す。ユニット数やバッチサイズは最初テキトーに決めて、学習を回しながら値を決めていく。Pythonの風習として2の階乗の値にするらしいが、精度が出れば別になんでもいいらしい。

train
def train(X_train, y_train, X_test, y_test):
  model = Sequential()

  # ブロック1
  model.add(Conv2D(128, (3,3), padding="same", activation="relu"))
  model.add(Conv2D(128, (3,3), padding="same", activation="relu"))
  model.add(MaxPooling2D(pool_size=(2,2), strides=(2, 2)))
  model.add(Dropout(0.25))

  # ブロック2
  model.add(Conv2D(64, (3,3), padding="same", activation="relu"))
  model.add(Conv2D(64, (3,3), padding="same", activation="relu"))
  model.add(MaxPooling2D(pool_size=(2,2), strides=(2, 2)))
  model.add(Dropout(0.25))

  # ブロック3
  model.add(Conv2D(32, (3,3), padding="same", activation="relu"))
  model.add(Conv2D(32, (3,3), padding="same", activation="relu"))
  model.add(MaxPooling2D(pool_size=(2,2), strides=(2, 2)))
  model.add(Dropout(0.25))

  # 平滑化
  model.add(Flatten())
  
  # 全結合
  model.add(Dense(512, activation="relu"))
  model.add(Dropout(0.6))
  model.add(Dense(num_class, activation="softmax"))

  # 学習の設定
  model.compile(loss="categorical_crossentropy",
                optimizer="adam",
                metrics=["accuracy"])

  # 学習
  history = model.fit(X_train, y_train,
                      batch_size=1024,
                      epochs=50,
                      verbose=1,
                      validation_data=(X_test, y_test),
                      callbacks=[EarlyStopping(patience=10)]
                      )
  
  # モデルの構造と重みを保存
  model.save("./cnn1.h5")

  return model, history

5.テストデータで正解率を算出

test_accuracyでは、学習したモデルのtestデータに対するクラスごとの正解率を表示する。

「テストデータから引数とデータを一緒に取り出す→予測→正解なら1を、不正解なら0をlistのlabel番目に足す」をテストデータの数だけ行っている。

def test_accuracy(model, X_test, y_test):
  # 全正解数
  sum_correct = 0

  # クラスごとの正解率
  class_total = [0 for i in range(num_class)]
  class_correct = [0. for i in range(num_class)]

  for i, data in enumerate(X_test):
    pred = model.predict(np.array([data])) # np.array([np.array])で(32,32,3)→(1,32,32,3)に整形
    pred = pred.reshape(pred.shape[1])  # predを2次元配列で出てくるので、1次元配列に変換
    pred_index = np.argmax(pred)  # 一番確率が高い引数を取得
    label = np.argmax(y_test[i]) # yは二値配列にしているので、np.argmaxで中身を取り出す(0~9)
    sum_correct += (1 if pred_index==label else 0) # y_testと一致した個数を累積
    class_total[label] += 1 # label番目の個数を+1
    class_correct[label] += (1 if pred_index==label else 0) # 正解ならlabel番目の正解数を+1

  print("-"*100)
  print("正解数:", sum_correct)
  print("データ数:", len(X_test))
  print("正解率:", (sum_correct/len(X_test)*100))

  print("-"*100)
  for i in range(num_class):
    print("%5s クラスの正解率:%.1f %%" %(class_dict[i], class_correct[i]/class_total[i]*100))

6.学習過程を可視化

plot_figでは、学習の経過を可視化する。

historyには学習の経過が格納されているので、plt.plotに渡せば簡単に可視化できる。

def plot_fig(history):
  print("-"*100)
  print("BatchNormalizationなし")
  # 描画する領域を設定
  plt.figure(1, figsize=(13,4))
  plt.subplots_adjust(wspace=0.5)

  # 学習曲線
  plt.subplot(1, 2, 1)
  plt.plot(history.history["loss"], label="train")
  plt.plot(history.history["val_loss"], label="test")
  plt.title("train and valid loss")
  plt.xlabel("epoch")
  plt.ylabel("loss")
  plt.legend()
  plt.grid()

  # 精度表示
  plt.subplot(1, 2, 2)
  plt.plot(history.history["accuracy"], label="train")
  plt.plot(history.history["val_accuracy"], label="test")
  plt.title("train and valid accuracy")
  plt.xlabel("epoch")
  plt.ylabel("accuracy")
  plt.legend()
  plt.grid()

  plt.show()

7.実行

mainは、これまでに定義してきた関数に値を渡して実行するだけの関数である。

# メイン関数
def main():
  X_train, y_train, X_test, y_test = load_data()
  model, history = train(X_train, y_train, X_test, y_test)
  test_accuracy(model, X_test, y_test)
  plot_fig(history)

main()

main()を実行すると
image.png
このような出力結果が得られる。モデルに乱数が絡んでいる部分があるので、学習を回すたびに結果は若干変わる。正解率は80%ほどで、だいぶ過学習を抑えつつ学習を進められているが、動物に対する正解率が他に比べて低くなっている。

CNNは浅い層で汎化的な特徴を、深い層で学習データに合った特徴を抽出するため、ユニット数やドロップアウト、学習率をいじればもう少し上がるかもしれない。しかし色々試してみたものの、CIFAR-10自体が32×32の粗い画像なのでよくても正解率82~3%あたりが限界な気もする()

おまけ

CNNの精度を上げる手法の一つに、BatchNormalizationというものがあるらしい。重み更新のたびに各層からの出力が従う分布が変わってしまう影響を抑制する手法だが、今回のような画素数の低いデータでは特に良い結果は得られなかった。しかし、データやモデルの構造によっては飛躍的に精度が向上するらしいので、使い方だけでも載せておく。

  model = Sequential()

  model.add(Conv2D(128, (3,3), padding="same"))
  model.add(BatchNormalization()) # ←コレ
  model.add(Activation("relu"))
  model.add(Conv2D(128, (3,3), padding="same"))
  model.add(BatchNormalization()) # ←コレ
  model.add(Activation("relu"))
  model.add(MaxPooling2D(pool_size=(2,2), strides=(2, 2)))
  model.add(Dropout(0.25))

  model.add(Conv2D(64, (3,3), padding="same"))
  model.add(BatchNormalization()) # ←コレ
  model.add(Activation("relu"))
  model.add(Conv2D(64, (3,3), padding="same"))
  model.add(BatchNormalization()) # ←コレ
  model.add(Activation("relu"))
  model.add(MaxPooling2D(pool_size=(2,2), strides=(2, 2)))
  model.add(Dropout(0.25))

  model.add(Conv2D(32, (3,3), padding="same"))
  model.add(BatchNormalization()) # ←コレ
  model.add(Activation("relu"))
  model.add(Conv2D(32, (3,3), padding="same"))
  model.add(BatchNormalization()) # ←コレ
  model.add(Activation("relu"))
  model.add(MaxPooling2D(pool_size=(2,2), strides=(2, 2)))
  model.add(Dropout(0.25))

  model.add(Flatten())

  model.add(Dense(512))
  model.add(BatchNormalization()) # ←コレ
  model.add(Activation("relu"))
  model.add(Dropout(0.6))
  
  model.add(Dense(num_class, activation="softmax"))

Kerasでの使い方は、畳み込み層の後にBatchNormalization()を加えるだけだった。出力を活性化関数に流す前に加えるので、"relu"をActivation()で後ろに付け加える。今回の結果はこんな感じ。

image.png
学習が安定していない。今回は↑で作成したCNNに加えただけだが、BatchNormalizationがあればドロップアウトはいらないという論文もあったので、今後の実務の中でお世話になる日が来るかもしれない。

参考

CIFAR-10:物体カラー写真(乗り物や動物など)の画像データセット

2
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?