1
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

「ゼロから作るDeep Learning」自習メモ(その5)4章 ニューラルネットワークの学習

Last updated at Posted at 2020-08-12

「ゼロから作るDeep Learning」(斎藤 康毅 著 オライリー・ジャパン刊)を読んでいる時に、参照したサイト等をメモしていきます。 その4← →その6

##4 ニューラルネットワークの学習

何をしているかの概略は、P112 からの説明で、なんとなくわかります。
損失関数を最小にするために、勾配を求めて、勾配の方向へ進むというのも、グラフのイメージで納得できるのですが、

微分、偏微分、2重和誤差、交差エントロピー誤差といろんな言葉と数式が並ぶと、わけがわからなくなりそうです。

とりあえず、4.1 データから学習する を読んだら、

P88
損失関数はニューラルネットワークの性能の“悪さ”を示す指標です。現在のニューラルネットワークが教師データに対してどれだけ適合していないか、教師データに対してどれだけ一致していないかということを表します。
 
P106
勾配が示す方向は、各場所において関数の値を最も減らす方向なのです。

勾配をうまく利用して関数の最小値(または、できるだけ小さな値)を探そう、というのが勾配法です。

ということを押さえて、P112 4.5 学習アルゴリズムの実装 に進んでしまってもいいかもしれません。
本の内容は、プログラム内で使っている関数の説明を順に積み上げてから実装、となっていますが、まず実装したプログラムを動かして見る。で、いろいろデータ内容を確認してみてから、関数の説明を熟読する、という順番のほうが、理解しやすいと思います。
数学の素養がない私などには、最初に関数の説明を読んでも、なんのことかよくわかりませんでしたが、プログラムを動かしてから、プログラムをトレースしていくと、本の説明がどのようにプログラムに組まれているか、理解できないまでも、納得はできました。

で、まず動かしてみようとなった場合には、P118 のプログラムの中に

# 勾配の計算
grad = network.numerical_gradient(x_batch, t_batch)
# grad = network.gradient(x_batch, t_batch) # 高速版!

となっているところがありますが、この通りにして実行すると、ものすごく時間がかかります。
ここは

# 勾配の計算
# grad = network.numerical_gradient(x_batch, t_batch)
grad = network.gradient(x_batch, t_batch) # 高速版!

と変えて実行したほうがいいでしょう。
P114 のプログラム例には gradientメソッドが書いてありませんが、ダウンロードしたファイルのch04/two_layer_net.py には、gradientメソッドも書いてあります。

損失関数が変化していく様子のグラフ

# print(train_loss_list)
import matplotlib.pylab as plt
x = np.arange(len(train_loss_list))
y = train_loss_list
plt.plot(x, y)
plt.show()

p119.jpg

学習した結果(重みとバイアス)をpickleで退避させてみた

# networkオブジェクトを、pickleで保存する。
import pickle
save_file = dataset_dir + '/gakusyuukekka_weight.pkl'    #拡張子は.pkl
with open(save_file, 'wb') as f:
    pickle.dump(network, f, -1)    

学習した結果(重みとバイアス)をセットし、テストデータを読ませて推論処理をしてみた。

import numpy as np
from common.functions import *

class TwoLayerNet:
    def __init__(self, input_size, hidden_size, output_size, weight_init_std=0.01):
        # 重みの初期化
        self.params = {}
        
    def predict(self, x):
        W1, W2 = self.params['W1'], self.params['W2']
        b1, b2 = self.params['b1'], self.params['b2']
        
        a1 = np.dot(x, W1) + b1
        z1 = sigmoid(a1)
        a2 = np.dot(z1, W2) + b2
        y = softmax(a2)
        return y

import pickle
import sys, os

dataset_dir = os.path.dirname(os.path.abspath('__file__'))+'/dataset'

mnist_file = dataset_dir + '/mnist.pkl'
with open(mnist_file, 'rb') as f:
    dataset = pickle.load(f)
dataset['test_img'] = dataset['test_img'].astype(np.float32)
dataset['test_img'] /= 255
x = dataset['test_img']
t = dataset['test_label']

# network = TwoLayerNet(input_size=784, hidden_size=100, output_size=10)
weight_file = dataset_dir + '/gakusyuukekka_weight.pkl'
with open(weight_file, 'rb') as f:
    network = pickle.load(f)

accuracy_cnt = 0
for i in range(len(x)):
    y = network.predict(x[i])
    p= np.argmax(y) 
    if p == t[i]:
        accuracy_cnt += 1

print("Accuracy:" + str(float(accuracy_cnt) / len(x))) 

結果は 

Accuracy:0.8652

と、えらく悪い認識精度に。

# 識別結果の内容を確認してみる
import matplotlib.pyplot as plt

def showImg(x):
    example = x.reshape((28, 28))
    plt.figure()
    plt.xticks([])
    plt.yticks([])
    plt.grid(False)
    plt.imshow(example)
    plt.show()
    return

for i in range(10):
    y = network.predict(x[i])
    p= np.argmax(y)
    print("正解 " + str(t[i]))
    print("[ " + str(p) + " ]")
    count = 0
    for v in y:
        print("["+str(count)+"] {:.2%}".format(v))
        count += 1
    showImg(x[i])

p120.jpg

正解ではあるが、確率95.63% で、3章でやった処理よりも確度が低い。

p121.jpg

これは3章の処理では正解していたもの。

まだ、本の途中だから精度は低いのだろうけど、何か間違っているのかどうかは不明。

その4← →その6

読めない用語集

1
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?