0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

技術書一冊やり込もうAdvent Calendar 2024

Day 22

ゼロから作るDeepLearning #8 -誤差逆伝播 (グラフを除いて…) -

Last updated at Posted at 2024-12-21

※以下の企画です

前回に続いてゼロつくやっていきます。
今回からは5章の誤差逆伝播の内容に入っていきます。
いつものことながら図は用いないアウトプットになるので、書籍上だと説明の要になっている「計算グラフ」はすっ飛ばされます。ごめんなさい。

それでは頑張っていきます〜

5章 誤差逆伝播法

概要

誤差逆伝播法は、パラメータの効果的な更新を目指して、誤差の減少を速く実現する主要な手段となる。
微分って時間も手間もかかるみたいなので、もっと早くて効率の良いものを考えようってことみたい。

誤差逆伝播法の核心は、順伝播で計算した出力誤差を入力層まで逆向きに伝播させながら、各層のパラメータに対する誤差の勾配を効率的に計算すること
この手法は、数値微分を使った勾配計算よりもはるかに計算コストが低く、ニューラルネットワークの学習における標準技術となっているらしい。

計算の流れ

誤差逆伝播法の計算は以下の手順で行われる

  1. 順伝播

    • 入力データを使って、ニューラルネットワークの出力(予測値)を計算する
    • 各層での中間値(活性化関数の出力)も記録する
  2. 出力層での誤差計算

    • 出力と正解ラベルの差を基に誤差を計算し、損失関数を求める
    • この誤差が逆伝播の起点となる
  3. 逆伝播

    • 出力層から入力層に向かって、各層でのパラメータ(重みとバイアス)に関する勾配を計算する
    • 勾配は連鎖律を使って効率的に計算される
  4. パラメータの更新

    • 勾配降下法などの最適化アルゴリズムを使って、計算した勾配を基にパラメータを更新する

連鎖律(GPTにも聞いてみてる)

勾配は連鎖律を使って効率的に計算される。
連鎖律(Chain Rule)は、合成関数の微分法則であり、誤差逆伝播法の核となる仕組みである
例えば、関数$z = f(g(x))$の微分を求める場合、連鎖律を用いると$\frac{dz}{dx} = (\frac{df}{dg}) (\frac{dg}{dx})$という形で計算できる。これにより、複数の層で構成されるニューラルネットワークの勾配を効率的に計算できる。
※一つのノードの出力結果=ある関数の出力結果とみなすことができる

まぁ数式だけ見てもようわからんってなると思う(少なくとも僕はなる)。こればっかりは本書の計算グラフの内容を是非とも見てみてほしい(放棄)

Pythonでの実装

誤差逆伝播法をPythonで実装する代表的な二層ネットワークを使う。
この実装では、数値微分を使う代わりに、逆伝播で効率的に勾配を計算する。

誤差逆伝播法
class TwoLayerNet:
    def __init__(self, input_size, hidden_size, output_size, weight_init_std=0.01):
        # パラメータの初期化
        self.params = {}
        self.params['W1'] = weight_init_std * np.random.randn(input_size, hidden_size)
        self.params['b1'] = np.zeros(hidden_size)
        self.params['W2'] = weight_init_std * np.random.randn(hidden_size, output_size)
        self.params['b2'] = np.zeros(output_size)
        
        self.grads = {}  # 勾配を格納する辞書

    def predict(self, x):
        W1, b1 = self.params['W1'], self.params['b1']
        W2, b2 = self.params['W2'], self.params['b2']

        # 順伝播を実行
        self.a1 = np.dot(x, W1) + b1
        self.z1 = sigmoid(self.a1)
        self.a2 = np.dot(self.z1, W2) + b2
        self.y = softmax(self.a2)
        return self.y

    def loss(self, x, t):
        y = self.predict(x)
        return cross_entropy_error(y, t)

    def backward(self, x, t):
        # 誤差逆伝播法による勾配の計算
        batch_size = x.shape[0]

        # 出力層での誤差
        dy = (self.y - t) / batch_size
        self.grads['W2'] = np.dot(self.z1.T, dy)
        self.grads['b2'] = np.sum(dy, axis=0)

        # 隠れ層での誤差
        dz1 = np.dot(dy, self.params['W2'].T)
        da1 = dz1 * sigmoid_grad(self.z1)
        self.grads['W1'] = np.dot(x.T, da1)
        self.grads['b1'] = np.sum(da1, axis=0)
        
        return self.grads

def sigmoid_grad(x):
    return x * (1 - x)

# 実行例
x = np.random.rand(3, 2)  # ダミーデータ
t = np.array([[0, 1, 0], [0, 0, 1], [1, 0, 0]])  # 正解ラベル

network = TwoLayerNet(input_size=2, hidden_size=3, output_size=3)

# 順伝播の誤差計算
loss = network.loss(x, t)
print("Loss:", loss)

# 誤差逆伝播法による勾配計算
network.backward(x, t)
print("Gradients:", network.grads)
実行結果
Loss: 2.1972294090761904

Gradients: {'W2': array([[-0.11087135, -0.11135754, -0.11141672],
       [-0.11083958, -0.11147447, -0.11173506],
       [-0.11090013, -0.11120184, -0.1114807 ]]), 'b2': array([-0.22180814, -0.22244025, -0.22241828]), 'W1': array([[-1.11597240e-04, -2.06919280e-04,  2.20979662e-04],
       [-9.33578570e-05, -3.65867914e-04,  4.27203605e-04]]), 'b1': array([ 0.00078198, -0.00192981,  0.00023119])}

softmaxなどの関数はこれまでの記事で作成したものと同じ。
実行結果を見ると、ちゃんと計算されてるっぽいのがわかる。わかるだけで検算的なことはできていない

まとめ

5章では、誤差逆伝播法によってニューラルネットワークの効率的な学習が可能になる事を学んだ。
計算グラフが理解の要になっているものの、すべて自分で作成してアウトプットしてもあまり自分の学習にとって効率的ではなかったので端折っている。
故にあんまり面白みのない記事になっているが、目を瞑っている。すみません。
次回も頑張ります!

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?