LoginSignup
1
2

More than 3 years have passed since last update.

非情報系大学院生が一から機械学習を勉強してみた #3:MNIST手書き数字認識

Posted at

はじめに

非情報系大学院生が一から機械学習を勉強してみました。勉強したことを記録として残すために記事に書きます。
進め方はやりながら決めますがとりあえずは有名な「ゼロから作るDeep-Learning」をなぞりながら基礎から徐々にステップアップしていこうと思います。環境はGoogle Colabで動かしていきます。第3回はMNIST手書き数字認識をニューラルネットワークで行います。

目次

  1. MNIST手書き数字認識問題とは
  2. ニューラルネットワークの実装
  3. バッチ処理

1. MNIST手書き数字認識問題とは

MNIST1とは手書きの0~9の数字画像で構成されたデータセットで、機械学習のHello world!てきなやつでしょうか。MNISTの画像データは28×28のグレー画像で各ピクセルは0~255の256段階で表現されます。各画像データに対して対応する正解ラベルが与えられています。実際に見てみましょう。

データセットの取得
!git clone https://github.com/oreilly-japan/deep-learning-from-scratch

import sys, os
sys.path.append("/content/deep-learning-from-scratch")      #setting path
from dataset.mnist import load_mnist

(x_train, t_train), (x_test, t_test) = load_mnist(flatten=True, normalize=False)
#(訓練画像, 訓練ラベル),(テスト画像, テストラベル)

print(x_train.shape)    #(60000, 784)
print(t_train.shape)    #(60000,)
print(x_test.shape)    #(10000, 784)
print(t_test.shape)    #(60000,)

事前準備としてデータセットを取得します。今回は参考にしている「ゼロから作るDeep-Learning」のgithubから取得します。データの形状から訓練画像は60000枚、テスト画像は10000枚あることが分かります。また、サイズ784というのは28×28の値で、load_mnist()関数の引数でflattenがTrueになっているので1×28×28の画像が1次元配列として格納されていることを表します。では確認がてらMNIST画像を表示してみます。

画像データ表示
img = x_train[0].reshape(28,28)    #28×28に変換
plt.imshow(img)
plt.show()
print(t_train[0])    #5

MNIST.png
データセットの0番目画像を見てみると手書きの「5」っぽい数字が表示されました。本当に5でなのか確認するために訓練ラベルの0番目を参照するとやはり「5」と出力されるので合っていることが確認できました。

2. ニューラルネットワークの実装

データセットの正体が分かったので手書き数字認識を行うニューラルネットワークを実装します。

MNIST手書き数字認識を行うニューラルネットワーク
#MNISTデータセット取得
def get_data():
    (x_train, t_train), (x_test, t_test) = load_mnist(normalize=True, flatten=True, one_hot_label=False)
    return x_test, t_test

#ニューラルネットワークの重み、バイアスを読み込み
def init_network():
    with open("/content/deep-learning-from-scratch/ch03/sample_weight.pkl", 'rb') as f:
        network = pickle.load(f)
    return network

#推論処理関数
def predict(network, x):
    W1, W2, W3 = network['W1'], network['W2'], network['W3']
    b1, b2, b3 = network['b1'], network['b2'], network['b3']

    a1 = np.dot(x, W1) + b1
    z1 = sigmoid(a1)
    a2 = np.dot(z1, W2) + b2
    z2 = sigmoid(a2)
    a3 = np.dot(z2, W3) + b3
    y = softmax(a3)

    return y

get_data()関数は上で説明したものと同じ、predict()関数は前回紹介したニューラルネットワークの構造と同じです。init_network()でニューラルネットワーク各層の重み、バイアスを定義していますが、ここでは学習済みのパラメータを使用するので与えられた値を読み込んでいます。学習済みパラメータは入力層が28×28=784、出力層が0~9の10分類をしなければいけないので10、隠れ層が2つの3層ニューラルネットワークでそれぞれのニューロン数は50, 100です。

推論処理
#定義
x, t = get_data()
network = init_network()
accuracy_cnt = 0

# for文で画像1枚ずつ推論処理
for i in range(len(x)):
    y = predict(network, x[i])
    p= np.argmax(y) # 最も確率の高い要素のインデックスを取得
    if p == t[i]:
        accuracy_cnt += 1

print("Accuracy:" + str(float(accuracy_cnt) / len(x)))    #Accuracy:0.9352

#行列サイズ確認
W1, W2, W3 = network['W1'], network['W2'], network['W3']
print(x.shape)       #(10000, 784)
print(x[0].shape)    #(784,)
print(W1.shape)      #(784, 50)
print(W2.shape)      #(50, 100)
print(W3.shape)      #(100, 10)

for文を回してテスト画像1枚1枚に推論処理を行っていきます。predict()関数の出力がSoftmaxを通った各数字に認識される確率なので、その値が一番大きいものを認識結果として取得します。その結果をテストラベルと比較して精度を計算します。与えられた学習済みパラメータに基づいて分類を行うと93.52%の認識精度が得られました。また、上で述べた通りのニューロン数の計算になっていることも確認できました。

3. バッチ処理

一般に計算機による数値計算ではfor文をぐるぐる回すより行列演算にするなどひとまとまりを大きくした方が計算が楽になります。先ほどのコードではfor文をぐるぐる回していましたが、ニューラルネットワークでも一定サイズのまとまりごとに推論を行うように変更します。これをバッチ処理と言います。

バッチ処理を用いた推論処理
batch_size = 100
accuracy_cnt = 0
for i in range(0, len(x), batch_size):
    x_batch = x[i:i+batch_size]     #i番目バッチデータを取得
    y_batch = predict(network, x_batch)
    p= np.argmax(y_batch, axis=1)       #index of max value
    accuracy_cnt += np.sum(p == t[i:i+batch_size])

print("Accuracy:" + str(float(accuracy_cnt) / len(x)))    #Accuracy:0.9352

上のプログラムでは推論処理のpredict()関数を一枚ずつ行うのではなくx_batch枚ずつまとめておこなうことでforループをlen(x)/batch_size回に削減しています。つまり1枚ずつ演算していたときは
(1, 784)→(784, 50)→(50, 100)→(100, 10)→(1, 10)
とサイズが変化していたのが
(100, 784)→(784, 50)→(50, 100)→(100, 10)→(100, 10)
というように100入力100出力になっています。

ニューラルネットワークの構造を確認するのに前回第2回で書いたようなネットワークを毎回書くのはしんどいので通常はこのように略記するようです。
structure.png

次回はニューラルネットワークで学習を行う準備として最急降下法を導入します。

参考文献

ゼロから作るDeep-Learning
ゼロから作るDeep-Learning GitHub
深層学習 (機械学習プロフェッショナルシリーズ)


  1. Mixed National Institute of Standards and Technology databaseの略らしい 

1
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
2