「ゼロから作るDeep Learning」(斎藤 康毅 著 オライリー・ジャパン刊)を読んでいる時に、参照したサイト等をメモしていきます。 その3← →その5
##3.6.2 ニューラルネットワークの推論処理
本来なら、訓練データで学習させてから、テストデータで推論させる、という順番なわけだが、この本では学習の前に、推論処理を説明している。
で、私の場合、テストデータの準備の部分が、本の例とは違っています。
# 学習済みのニューラルネットワークで、テストデータを識別してみる
import numpy as np
import pickle
import sys, os
dataset_dir = os.path.dirname(os.path.abspath('__file__'))+'/dataset'
mnist_file = dataset_dir + '/mnist.pkl'
with open(mnist_file, 'rb') as f:
dataset = pickle.load(f)
def normalize(key):
dataset[key] = dataset[key].astype(np.float32)
dataset[key] /= 255
return dataset[key]
def get_data():
x_test = normalize('test_img')
t_test = dataset['test_label']
return x_test, t_test
def init_network():
# 学習済みの重みパラメータを読み込む
weight_file = dataset_dir + '/sample_weight.pkl'
with open(weight_file, 'rb') as f:
network = pickle.load(f)
return network
def sigmoid(x):
return 1 / (1 + np.exp(-x))
def softmax(a):
c = np.max(a)
exp_a = np.exp(a - c)
sum_exp_a = np.sum(exp_a)
y = exp_a / sum_exp_a
return y
def predict(network, x):
W1, W2, W3 = network['W1'], network['W2'], network['W3']
b1, b2, b3 = network['b1'], network['b2'], network['b3']
a1 = np.dot(x, W1) + b1
z1 = sigmoid(a1)
a2 = np.dot(z1, W2) + b2
z2 = sigmoid(a2)
a3 = np.dot(z2, W3) + b3
y = softmax(a3)
return y
x, t = get_data()
network = init_network()
accuracy_cnt = 0
for i in range(len(x)):
y = predict(network, x[i])
p= np.argmax(y) # 最も確率の高い要素のインデックスを取得
if p == t[i]:
accuracy_cnt += 1
print("Accuracy:" + str(float(accuracy_cnt) / len(x))) # Accuracy:0.9352
これだけだと、どんなふうに推論しているのかよくわからないので、テスト画像と予測結果を表示してみた。
# 識別結果の内容を確認してみる
import matplotlib.pyplot as plt
def showImg(x):
example = x.reshape((28, 28))
plt.figure()
plt.xticks([])
plt.yticks([])
plt.grid(False)
plt.imshow(example)
plt.show()
return
for i in range(10):
y = predict(network, x[i])
p= np.argmax(y)
print("[ " + str(p) + " ]")
print(y)
showImg(x[i])
配列 y[7] の予測確率の数字が一番大きくなっています。その次が y[3]、y[9]となっています。
間違ったときは、どのように推論しているかも確認してみた。
# 間違ったものを表示
for i in range(200):
y = predict(network, x[i])
p= np.argmax(y)
if p != t[i]:
print("正解 " + str(t[i]))
print("[ " + str(p) + " ]")
print(y)
showImg(x[i])
4である確率が48%と予測している。正解の9については39%の確率と予測している。