0
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

Deeplerning基礎知識おさらい活動 その2

Posted at

目的

Deeplerning基礎知識おさらい活動 第二回です!!
今回はニューラルネットワーク(3層)の推論処理に関してです!!
実際にニューラルネットがどのような手順で推論処理を行っているのかをコードを追いながら見ていきたいと思います!!
(今回のソースは「0から始まるDeeplearning」という著書を基にしています。)

ニューラルネットワークの推論処理

predict.py
def predict(network, x):
    W1, W2, W3 = network['W1'], network['W2'], network['W3']
    b1, b2, b3 = network['b1'], network['b2'], network['b3']
    
    a1 = np.dot(x, W1) + b1
    z1 = sigmoid(a1)
    a2 = np.dot(z1, W2) + b2
    z2 = sigmoid(a2)
    a3 = np.dot(z2, W3) + b3
    y = softmax(a3)

    return y

早速、上記のコードを追っていきたいと思います。
まず、networkには既にラベル毎に('W1','b1'など)既に学習済みの重みが入っています。
それを各重み(W)、バイアス(b)に代入しています。それから各層のそれぞれの計算を行なっており、最後にsoftmax関数を用いて値を0~1に変換し、確率として表します!!

続いて以下のコードを追って行きましょう!!

doing_predict.py
# データセットを代入(xにデータ、tにラベル)
x, t = get_data()
# networkに学習済みの重みパラメータを代入
network = init_network()

accuracy_cnt = 0
for i in range(len(x)):
    y = predict(network, x[i])
    #最も確率の高い要素のインデックスを取得
    p = np.argmax(y)

    if p == t[i]
        accuracy_cnt += 1

print("Accuracy:" + str(float(accuracy_cnt) / len(x))) 

get_data()はデータとラベルを返す関数です。
networkに学習済みのパラメータを保存します。

そしてfor文の中の処理を見て行きましょう。
iには一つのデータの番地が入っています。そのi番目のデータを先ほどのpredict関数にかますことによって確率としてyが得られます。
pにはyの最も確率の高かったインデックスが入りそれとラベルと比較をします。
正しければaccuracy_cntに1が加算されます。
この工程を繰り返しデータの個数で割ってあげることで正解率(推論結果)を出します。

まとめ

実際の推論処理の工程を見て行きましたがラベルとの差分→繰り返す→確率を出すという順番でしたね。
推論の説明が難しくコードを用いて説明して見ました。
なかなか難しかったです。。
いかがでしたでしょうか!!
またお願い致します!!

0
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?