LoginSignup
2
4

More than 3 years have passed since last update.

Kerasで始める機械学習(MNIST編)

Posted at

はじめに

機械学習におけるHelloWorld的位置にあるMNIST(手書き文字)の認識をKerasを用いて行います。
想定読者は機械学習、DeepLearningについて概念を学んだのでKerasで簡単な実装を行ないたい方です。

実行環境

macOS Mojave
Python 3.6.8
Keras 2.2.4

MNISTデータセットのLOAD

keasと必要なライブラリをImportし、MNISTデータセットをロードします。
MNISTデータはkerasに実装されている関数load_dataを利用し取得できます。
load_dataの詳細については公式サイトを参照ください。

main.py
import keras
import numpy

from keras.datasets import mnist

# MNISTデータのロード
train, test = mnist.load_data()
X_train, y_train = train
X_test, y_test = test

# サイズを変更(画像数、解像度(縦、横)、白黒)
X_train = X_train.reshape(60000,28,28,1)
X_test = X_test.reshape(10000,28,28,1)

MNISTデータを確認する

ロードしたMNISTデータをmatplotlibを使って確認してみます。
参考:matplotlib.pyplot.imshow

show_img
import matplotlib.pyplot as plt

def show_img(img, figsize=(2,2)):
    fig = plt.figure(figsize=figsize,dpi=100)
    plt.imshow(img, cmap = 'gray', interpolation = 'bicubic')
    plt.xticks([]), plt.yticks([])  # to hide tick values on X and Y axis
    plt.show()

# X_trainの一番目の画像を表示
show_img(X_train[0])

出力結果

数字の”5”と思われる画像が確認できます。

image.png

One-Hot encoding

教師ラベルをOne-Hot Encodingします。
KerasではOne-hot encodingする関数としてto_categoricalが用意されています。

One-hot とは1つだけHigh(1)であり、他はLow(0)であるようなビット列のことです。(wikipedia)
MNISTは0~9の教師ラベルを持つ手書き画像なので例えば教師ラベルの[5]は[0,0,0,0,0,1,0,0,0,0]のように6番目のみ1が立つようにEncodeされます。

to_categorical
from keras.utils.np_utils import to_categorical

y_trn_cgl = to_categorical(y_train)
y_test_cgl = to_categorical(y_test)

print(f"train_labels.shape:{y_train.shape}, y_trn.shape:{y_test.shape}")
print(f"y_trn_cgl:{y_trn_cgl}") 
print(f"y_test_cgl:{y_test_cgl}")

CNNモデルの構築

モデルの構造は (2D Conv -> Relu -> MaxPooling)*2 -> 全結合層 -> Softmax
となるシンプルなニューラルネットを構築します。
各レイヤの仕様について公式サイトのリンクを貼っておきます。

Conv2D:2次元畳み込み層
Activation:活性化関数(今回はReluを使用)
MaxPooling2D:Maxプーリング層
Dense:全結合層

Model
from keras.models import Model
from keras.layers import Input, Conv2D, MaxPooling2D, Dropout, Flatten

def create_cnn_model(input_size):
    # 入力データサイズを指定
    X_input = Input((input_size[0], input_size[1], 1))
    # 2次元畳み込み層。
    X = Conv2D(filters=64, kernel_size=(5,5), padding="valid")(X_input)
    # 活性化関数はReluを使用
    X = Activation('relu')(X)
    # MaxPooling層
    X = MaxPooling2D(pool_size = (2,2))(X)

    X = Conv2D(filters=128, kernel_size=(5,5), padding="valid")(X)
    X = Activation('relu')(X)
    X = MaxPooling2D(pool_size = (2,2))(X)

    X = Flatten()(X)
    # 全結合層。MNISTでは10クラスへ分類するため出力10で指定
    X = Dense(10, activation="softmax")(X)

    model = Model(inputs=X_input, outputs=X)

    return model

CNNモデルの学習

MNISTデータの準備とCNNモデルの準備ができましたので、
モデルの学習を行なっていきます。

from keras import optimizers
# セッションの初期化
K.clear_session() 

# 入力画像サイズを指定
input_size =  (28, 28) 
# モデルを作成
cnn_model = create_cnn_model(input_size) 

# 最適化アルゴリズムを指定(lr=学習率)
opt = optimizers.Adagrad(lr=1e-4)
# 多クラス分類用の損失関数の指定
loss = "categorical_crossentropy" 
metrics = ['accuracy']

# 設定を反映
cnn_model.compile(optimizer=opt, loss=loss, metrics=metrics) 

#学習を実行
# 学習経過をhistory_cnnに格納してます。
history_cnn = cnn_model.fit(X_train[0:6000], y_trn_cgl[0:6000], epochs=10) 

テストデータの予測

学習したモデルでテストデータを予測します。

predict = cnn_model.predict(X_test)

eval_loss, eval_acc = cnn_model.evaluate(X_test, y_test_cgl)
print("eval_loss:{0}\neval_acc{1}".format(eval_loss, eval_acc))

def show_history(history):
    fig, ax = plt.subplots(1, 2, figsize=(15,5))
    ax[0].set_title('loss')
    ax[0].plot(history.epoch, history.history["loss"], label="Train loss")
    #ax[0].plot(history.epoch, history.history["val_loss"], label="Validation loss")
    ax[1].set_title('categorical_accuracy')
    ax[1].plot(history.epoch, history.history["acc"], label="Train accuracy")
    #ax[1].plot(history.epoch, history.history["val_acc"], label="Validation accuracy")
    ax[0].legend()
    ax[1].legend()
    plt.show()

show_history(history_cnn)

実行結果

左の図がLossを、右の図が正答率(categorical accuracy)を示しています。
横軸はEpoch数、つまり学習回数です。

学習回数を重ねるにつれLossが下がり学習が進み、それに伴い正答率が向上していることがわかります。
MNISTは比較的簡単な画像分類のため10回の学習で85%近い正答率が出せることがわかります。
特に工夫はしていないCNNモデルでも学習回数を増やせば90%近い正答率は出せると思います。

image.png

参考リンク

Keras Documentaion

2
4
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
4