2
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

概要

TensorFlowプラットフォームの高レベルAPIであるKerasを用いて、転移学習によるCIFAR-10の画像識別を行います。

  • 動作環境
    Google Colaboratory

  • 対象者
    畳み込みニューラルネットワーク(Convolutional Neural Network、略称:CNN)の知識があり、TensorFlowを使ったCNNの実装に興味のある機械学習初心者

用語説明

CIFAR-10

CIFAR-10データセット(Canadian Institute For Advanced Research)は、機械学習やコンピュータビジョンアルゴリズムの学習によく使われる画像のコレクションです。CIFAR-10データセットは、10種類のクラスに分かれた60,000枚の32x32カラー画像を含んでいます。10種類のクラスは、飛行機、車、鳥、猫、鹿、犬、カエル、馬、船、トラックを表しています。各クラスには6,000枚の画像があります。

VGG16

VGG16はVisual Geometry Group(VGG)によって開発されたディープラーニングモデルで、画像認識タスクにおける卓越した性能で知られています。13層の畳み込み層と3層の全結合層を含む16層で構成されています。

ライブラリのインポート

import matplotlib.pyplot as plt
import numpy as np
from tensorflow.keras.utils import plot_model
from tensorflow.keras import optimizers
from tensorflow.keras.applications.vgg16 import VGG16
from tensorflow.keras.datasets import cifar10
from tensorflow.keras.layers import Dense, Dropout, Flatten, Input
from tensorflow.keras.models import Model, Sequential
from tensorflow.keras.utils import to_categorical

データの準備

CIFAR-10データセットをロードします。このデータセットには訓練データが50,000、検証データが10,000もあるので、今回の学習ではその1/10を使います。

ラベル(y_trainとy_test)はto_categoricalメソッドを用いて、One-Hotエンコーディングしています。

ラベルロードした直後のy_train[0]にはflogを示す[6]が格納されていますが、to_categoricalメソッドで[0. 0. 0. 0. 0. 0. 1. 0. 0. 0.]に変換します。こうすることで、出力層で正解ラベルとして使えるようになります。

(X_train, y_train), (X_test, y_test) = cifar10.load_data()

X_train = X_train[:5000]
X_test = X_test[:1000]
y_train = to_categorical(y_train)[:5000]
y_test = to_categorical(y_test)[:1000]

VGG16モデルのセットアップ

VGG16モデルには、ImageNetデータセットで事前に訓練された重みがロードされます。
入力テンソルは、今回使用するCIFAR-10の画像の解像度32x32の入力形状に一致するように定義します。

# vgg16のインスタンスを生成する
input_tensor = Input(shape=(32, 32, 3))
vgg16 = VGG16(include_top=False, weights='imagenet', input_tensor=input_tensor)

逐次モデルの構築:

新しい逐次モデル(top_model)を作成する。
VGG16モデルの出力は、この新しいモデルに接続されます。

vgg16.output_shapeの開始位置が1なのは、VGG16の出力形状は(batch_size, height, width, channels)であるため、shape=vgg16.output_shape[1:]としてバッチサイズを除く(height, width, channels)の3次元を指定しているためです。

# VGGの特徴抽出部分の出力と結合する全結合層のモデルを構築する
top_model = Sequential()
top_model.add(Flatten(input_shape=vgg16.output_shape[1:]))
top_model.add(Dense(256, activation='relu'))
top_model.add(Dropout(0.5))
top_model.add(Dense(10, activation='softmax'))

# 構築したモデルをVGG16の出力と連結する
model = Model(inputs=vgg16.input, outputs=top_model(vgg16.output))

重みの固定

結合モデルの最初の19層(平坦化層まで)は、事前に訓練された重みを保持します。

for layer in model.layers[:19]:
    layer.trainable = False

モデルのコンパイルと訓練

カテゴリカル交差エントロピー損失関数とSGDオプティマイザを使用してコンパイルします。
事前に訓練されたモデルの重みがロードされ、モデルは3エポック訓練します。

model.compile(loss='categorical_crossentropy',
              optimizer=optimizers.SGD(lr=1e-4, momentum=0.9),
              metrics=['accuracy'])

history = model.fit(X_train, y_train, batch_size=128, epochs=3)

学習進捗の可視化

学習の進捗をmatplotlibを使用して可視化しています。トレーニングおよび検証データでの精度がプロットされます。

plt.plot(history.history['accuracy'])
plt.plot(history.history['val_accuracy'])
plt.title('Model Accuracy')
plt.ylabel('Accuracy')
plt.xlabel('Epoch')
plt.legend(['Train', 'Validation'], loc='upper left')
plt.show()

model_accuracy_32.png

評価と可視化

モデルはテストセットで評価され、精度と損失が表示されます。
テスト画像のサブセットが視覚化され、予測値が表示されます。

scores = model.evaluate(X_test, y_test, verbose=1)
print('Test loss:', scores[0])
print('Test accuracy:', scores[1])

for i in range(10):
    plt.subplot(2, 5, i+1)
    plt.imshow(X_test[i])
    pred = np.argmax(model.predict(X_test[i].reshape(1,32,32,3)))
    labels = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
    plt.title(labels[pred])
plt.suptitle("10 images of test data", fontsize=16)
plt.show()

test_images_32.png

終わりに

ハイパーパラメータであるバッチサイズを実験的に32、64、96、128、256で試した結果、32で訓練データに高めの精度を得られました。
一方で、検証データの精度がエポック数15で高止まりしたことから、汎化性能の低いモデルとなってしまいました。
学習率、バッチサイズ、エポック数といった、ハイパーパラメータを調整し、それでも改善しなければVGG16に連結するモデルの構成に工夫が必要であるように感じました。時間を見つけては、実験的に最適なパラメータ・モデル構成を見つけていきたいと思います。

2
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?