0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

「ゼロから作るDeep Learning」自習メモ(その4)3.6.2 ニューラルネットワークの推論処理

Last updated at Posted at 2020-08-11

「ゼロから作るDeep Learning」(斎藤 康毅 著 オライリー・ジャパン刊)を読んでいる時に、参照したサイト等をメモしていきます。 その3← →その5

##3.6.2 ニューラルネットワークの推論処理

本来なら、訓練データで学習させてから、テストデータで推論させる、という順番なわけだが、この本では学習の前に、推論処理を説明している。

で、私の場合、テストデータの準備の部分が、本の例とは違っています。

# 学習済みのニューラルネットワークで、テストデータを識別してみる
import numpy as np
import pickle
import sys, os

dataset_dir = os.path.dirname(os.path.abspath('__file__'))+'/dataset'
mnist_file = dataset_dir + '/mnist.pkl'
with open(mnist_file, 'rb') as f:
    dataset = pickle.load(f)

def normalize(key):
    dataset[key] = dataset[key].astype(np.float32)
    dataset[key] /= 255
    return dataset[key]

def get_data():
    x_test = normalize('test_img')
    t_test = dataset['test_label']
    return x_test, t_test

def init_network():
    # 学習済みの重みパラメータを読み込む
    weight_file = dataset_dir + '/sample_weight.pkl'
    with open(weight_file, 'rb') as f:
        network = pickle.load(f)
    return network

def sigmoid(x):
    return 1 / (1 + np.exp(-x))

def softmax(a):
    c = np.max(a)
    exp_a = np.exp(a - c)
    sum_exp_a = np.sum(exp_a)
    y = exp_a / sum_exp_a
    return y

def predict(network, x):
    W1, W2, W3 = network['W1'], network['W2'], network['W3']
    b1, b2, b3 = network['b1'], network['b2'], network['b3']

    a1 = np.dot(x, W1) + b1
    z1 = sigmoid(a1)
    a2 = np.dot(z1, W2) + b2
    z2 = sigmoid(a2)
    a3 = np.dot(z2, W3) + b3
    y = softmax(a3)

    return y

x, t = get_data()
network = init_network() 

accuracy_cnt = 0
for i in range(len(x)):
    y = predict(network, x[i])
    p= np.argmax(y) # 最も確率の高い要素のインデックスを取得
    if p == t[i]:
        accuracy_cnt += 1

print("Accuracy:" + str(float(accuracy_cnt) / len(x))) # Accuracy:0.9352

これだけだと、どんなふうに推論しているのかよくわからないので、テスト画像と予測結果を表示してみた。

# 識別結果の内容を確認してみる
import matplotlib.pyplot as plt

def showImg(x):
    example = x.reshape((28, 28))
    plt.figure()
    plt.xticks([])
    plt.yticks([])
    plt.grid(False)
    plt.imshow(example)
    plt.show()
    return

for i in range(10):
    y = predict(network, x[i])
    p= np.argmax(y)
    print("[ " + str(p) + " ]")
    print(y)
    showImg(x[i])

結果はこうなりました。
p77.jpg

配列 y[7] の予測確率の数字が一番大きくなっています。その次が y[3]、y[9]となっています。

間違ったときは、どのように推論しているかも確認してみた。

# 間違ったものを表示
for i in range(200):
    y = predict(network, x[i])
    p= np.argmax(y)
    if p != t[i]:
        print("正解 " + str(t[i]))
        print("[ " + str(p) + " ]")
        print(y)
        showImg(x[i])

p78.jpg

4である確率が48%と予測している。正解の9については39%の確率と予測している。

その3← →その5

読めない用語集

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?