自分なりに理解したことを見返せるようにまとめただけ
ニューラルネットワーク
行列を忘れかけていたので復習。
難しいことは特にしていない。
手書き文字識別の実装
重み関数のshapeからニューラルネットワークの構造を読み解いてみる
ソースコードはこんな感じ
import pickle
def get_data():
# 手書き文字を読み込む
(x_train, t_train), (x_test, t_test) = load_mnist(normalize=True, flatten=True, one_hot_label=False)
return x_test, t_test
def init_network():
# あらかじめ設定された重み関数を呼び出してくる
with open("sample_weight.pkl", 'rb') as f:
network = pickle.load(f)
return network
def predict(network, x):
W1, W2, W3 = network['W1'], network['W2'], network['W3']
b1, b2, b3 = network['b1'], network['b2'], network['b3']
a1 = np.dot(x, W1) + b1
z1 = sigmoid(a1)
a2 = np.dot(z1, W2) + b2
z2 = sigmoid(a2)
a3 = np.dot(z2, W3) + b3
y = softmax(a3)
return y
重み関数Wとかは自分で定義せず、元から用意されてるデータを読み込んでいるらしい。
shapeを調べてみると、
ソースコード:
x, t = get_data()
print("x.shape: ", x.shape)
print("t.shape: ", t.shape)
network = init_network()
print("w1.shape: ", network['W1'].shape)
print("w2.shape: ", network['W2'].shape)
print("w3.shape: ", network['W3'].shape)
print("b1.shape: ", network['b1'].shape)
print("b2.shape: ", network['b2'].shape)
print("b3.shape: ", network['b3'].shape)
結果:
x.shape: (10000, 784)
t.shape: (10000,)
w1.shape: (784, 50)
w2.shape: (50, 100)
w3.shape: (100, 10)
b1.shape: (50,)
b2.shape: (100,)
b3.shape: (10,)
手書き文字のテストデータ(x)が10000個用意されているらしい。
784という数字は、画像サイズが28×28で、これを1次元配列にしたから。
tは正解ラベルだから、xと同じ10000個。
前提として、これは
・隠れ層が2つ
・1つ目の隠れ層のノードが50個
・2つ目の隠れ層のノードが100個
らしい。
つまり、
・W1は784個の入力に対して、それぞれ重みが設定されており、次の50個のノードにつなげるから784×50の行列
・W2は次の50個の入力に対して、さらに次の100個のノードにつなげるから50×100の行列
・W3は最後の100個の入力に対して、10個の出力を持つから、100×10の行列
最後が10個の出力なのは、今回扱っている手書き文字が0~9の10種類の文字だから。
出力結果は、入力がどの文字に相当するかを表す確率になっている。
予想してみる
ソースコードはこんな感じ
x, t = get_data()
network = init_network()
accuracy_cnt = 0
for i in range(len(x)):
y = predict(network, x[i])
# yはどの文字である可能性が高いかを示した確率
# 最も確率の高い要素のインデックスを取得
p = np.argmax(y)
# tは正解ラベル。予想結果があっていた場合、cntをプラスする
if p == t[i]:
accuracy_cnt += 1
print("Accuracy: ", float(accuracy_cnt) / len(x))
結果
Accuracy: 0.9352
所感
久々に数学の行列を思い出しながらやっていたのでちょっと飲み込むのに時間がかかった。
機械学習の基礎は過去に一通りやったので、割とすっと入ってきたかな。
普段Java書いているので、久々にPython触ると書きやすいけど型とかすごく不安になる。
あと、紙の書籍ペラペラめくりながらソースコード打つのだるすぎるので、O'REILLYさんは早く電子書籍出してほしい。