2
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

深層学習/ゼロから作るDeep Learning 第3章メモ

Last updated at Posted at 2020-04-28

1.はじめに

 名著、**「ゼロから作るDeep Learning」**を読んでいます。今回は3章のメモ。
コードの実行はGithubからコード全体をダウンロードし、ch03の中で jupyter notebook にて行っています。

2.3層ニューラルネットワーク

 3層のニューラルネットワークを考えます。ニューロンの数は、入力層2, 第1層3, 第2層2, 出力層2とすると、順伝播は下記の様な行列演算で行えます。
スクリーンショット 2020-04-28 10.52.24.png
スクリーンショット 2020-04-28 10.52.38.png

import numpy as np

def sigmoid(x):
    return 1/(1+np.exp(-x))

def identity_function(x):
    return x

def init_network():
    network={}
    network['W1']=np.array([[0.1, 0.3, 0.5],[0.2, 0.4, 0.6]])
    network['b1']=np.array([0.1, 0.2, 0.3])
    network['W2']=np.array([[0.1, 0.4],[0.2, 0.5],[0.3, 0.6]])
    network['b2']=np.array([0.1, 0.2])
    network['W3']=np.array([[0.1, 0.3],[0.2, 0.4]])
    network['b3']=np.array([0.1, 0.2])    
    return network

def forward(network, x):
    W1, W2, W3 = network['W1'], network['W2'], network['W3']
    b1, b2, b3 = network['b1'], network['b2'], network['b3']

    a1 = np.dot(x, W1)+b1
    z1 = sigmoid(a1)
    a2 = np.dot(z1, W2)+b2
    z2 = sigmoid(a2)
    a3 = np.dot(z2, W3)+b3
    y = identity_function(a3)    
    return y
        
network=init_network()
x = np.array([1.0, 0.5])
y = forward(network, x)
print(y)

# 出力
[0.31682708 0.69627909]

3.学習済モデルでの推論

 テキストでは、数字の0〜9(MNIST)の識別をするネットワークの学習済みパラメータが保存されているので、これを使って推論を行います。

 ネットワークのニューロンの数は、入力層784、第1層50、第2層100、出力層10 で、先程同様に、順伝播は下記の様な行列演算で表すことが出来ます。
スクリーンショット 2020-04-28 11.18.37.png
スクリーンショット 2020-04-28 11.16.02.png

import sys, os
sys.path.append(os.pardir)  # 親ディレクトリのファイルをインポートするための設定
import numpy as np
import pickle
from dataset.mnist import load_mnist
from common.functions import sigmoid, softmax

def get_data():
    (x_train, t_train), (x_test, t_test) = load_mnist(normalize=True, flatten=True, one_hot_label=False)
    return x_test, t_test

def init_network():
    with open("sample_weight.pkl", 'rb') as f:
        network = pickle.load(f)
    return network

def predict(network, x):
    W1, W2, W3 = network['W1'], network['W2'], network['W3']
    b1, b2, b3 = network['b1'], network['b2'], network['b3']

    a1 = np.dot(x, W1) + b1
    z1 = sigmoid(a1)
    a2 = np.dot(z1, W2) + b2
    z2 = sigmoid(a2)
    a3 = np.dot(z2, W3) + b3
    y = softmax(a3)
    return y

x, t = get_data()
network = init_network()
accuracy_cnt = 0
pred = []  # 推論結果を保存するリストを用意
for i in range(len(x)):
    y = predict(network, x[i])
    p = np.argmax(y) # 最も確率の高い要素のインデックスを取得
    pred.append(p)  # 推論結果を保存
    if p == t[i]:
        accuracy_cnt += 1

print("Accuracy:" + str(float(accuracy_cnt) / len(x)))

# 出力
Accuracy:0.9352

 後から推論結果を個別に見たいので、推論結果を保存するリストの用意pred = [] と、推論の保存pred.append(p)をコードに追加しています。

 それでは、ついでに個別にどう予測しているかを見てみましょう。変数xに画像、変数predに推論結果が入っていますので、

import matplotlib.pyplot as plt

fig = plt.figure(figsize=(10, 12))
for i in range(50):
    ax = fig.add_subplot(10, 10, i+1, xticks=[], yticks=[])
    ax.imshow(x[i].reshape((28, 28)), cmap='gray')
    ax.set_xlabel('pred='+str(pred[i]))

スクリーンショット 2020-04-28 10.03.36.png
 画像データの先頭から50枚の推論結果を表示しています。間違えたのは、赤枠で表示した2枚のみで、結構優秀ですね。

4.バッチ処理

 さて、推論する時に1枚づつやらなくても、まとめて100枚とかバッチ処理すると効率的です。その場合は、先程のコードの推論部分のみ変更すればOK。

batch_size = 100 # バッチの数
accuracy_cnt = 0
pred = np.array([])  # 推論結果を保存する箱(numpy)を用意
for i in range(0, len(x), batch_size):
    x_batch = x[i:i+batch_size]
    y_batch = predict(network, x_batch)
    p = np.argmax(y_batch, axis=1)
    pred = np.append(pred, p) # 推論結果を保存
    accuracy_cnt += np.sum(p == t[i:i+batch_size])

print("Accuracy:" + str(float(accuracy_cnt) / len(x)))
pred = pred.astype(int)  # 推論結果をfloatからintへ変換

 先程同様、推論結果を保存するコードを追加しています。今回は、バッチ処理するため推論結果がnumpy形式で返って来ますので、保存する箱の準備は pred = np.array([])、保存は pred = np.append(pred, p)、このままだと「1.0」とかいう表示になるので、最後に整数に戻すため pred = pred.astype(int)としています。

 さて、ここでのポイントは y_batch = predict(network, x_batch)です。これは、ネットワークの入力に、100個分のデータを入れたので、出力も100個出るということです。先程の行列演算のイメージで言えば、

スクリーンショット 2020-04-28 11.14.38.png

2
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?