0
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

HUITAdvent Calendar 2024

Day 13

誤差逆伝搬法を理論から理解する

Last updated at Posted at 2024-12-12

前置き

こんにちは、misaizuです。
ずっとアドカレやるやる詐欺をして、めでたく2年が経過したので今年こそやります。

誤差逆伝搬法を理論から理解する

今回はいわゆる緑本 (機械学習スタートアップシリーズ これならわかる深層学習入門) を読んでいて第6章の誤差逆伝搬法のところで見事に嵌ったので、その備忘録的なのを残しておこうと思います。

筆者も勉強したて && 数学科の人間ではないので、一部正確でない点があるかもしれませんが、温かい目で見ていただけると幸いです。

数学的記法

本記事の数学的記法は基本的に上述の緑本を参考にしています。したがって、列ベクトルを太字の小文字で表記し、各成分は下付き文字で表記します。例えば、3次元ベクトルは以下のように表します。

\boldsymbol{v} = \begin{pmatrix}
    v_1 \\
    v_2 \\
    v_3 \\
\end{pmatrix},

\boldsymbol{v}^\top = \begin{pmatrix}
    v_1 & v_2 & v_3
\end{pmatrix}

また行列は太字の大文字で表記し、各成分も下付き文字で表記します。

\boldsymbol{W} = \begin{pmatrix}
    W_{ij}
\end{pmatrix}

ベクトルのノルムは以下のように表記します。

||\boldsymbol{v}||^2

ニューラルネットワークと連鎖律について

誤差逆伝搬法はニューラルネットワークで使われる手法なので、もちろん前提知識としてニューラルネットワークについての基本的な知識が必要になります。また偏微分を多用するので連鎖律(chain rule)についても知っておいてもらう必要があります。「わかってるよ!」って方はここは飛ばしていただいて構いません。

ニューラルネットワークと連鎖律について

ニューラルネットワーク

さて前置きと言いましたがニューラルネットワークについて一から解説するとめちゃめちゃ時間がかかるので、個人的に分かりやすいと思った記事を載せておきます(すみません)。

https://qiita.com/NekoAllergy/items/489a4158b15231936f11

連鎖律

こっちはちょっと真面目に解説します。

合成関数の偏微分を求めるときに使うテクニックです。例として二変数関数の連鎖律を挙げてみます。

関数 $f=f(u, v)$ が $C^1$ 級で、$u=u(x, y), v=v(x, y)$ がそれぞれ偏微分可能であるとする。このとき、

\frac{\partial f}{\partial x} = \frac{\partial f}{\partial u}\frac{\partial u}{\partial x} + \frac{\partial f}{\partial v}\frac{\partial v}{\partial x}

が成り立つ。

$C^1$ 級という聞きなれない単語が出現しましたが、ざっくり「$u, v$ の両方で微分可能」と覚えていただければ構いません。要は $f(u(x, y), v(x, y))$ を $x$ で偏微分したかったら、$x$ がかかわっている $u, v$ 両方について合成関数の微分をしようねというイメージです。詳しい証明とかは以下のサイトが参考になると思います。

https://mathlandscape.com/partial-derivative-composite/

数学記号の説明

今回は$L$層のニューラルネットワークを想定し、記号は緑本を参考に以下の表の感じで定義します。また簡単のためにバイアスはないものと仮定します。

記号 備考
入力 $\boldsymbol{x}$
各層の出力 $\boldsymbol{z}^{(l)}$ $l$層目の出力を$\boldsymbol{z}^{(l)}$と表記する
重み $w$ $l-1$層の$i$番目のユニットから$l$層目の$j$番目のユニットへの重みを $w_{ji}^{(l)}$と表記する
パラメータ $\boldsymbol{w}$ 全$w_{ji}$を縦に並べた列ベクトル。
ある時刻(エポック)$t$でのパラメータを$\boldsymbol{w}^{(t)}$と表記
出力層の出力 $\boldsymbol{\hat{y}}(\boldsymbol{x}; \boldsymbol{w})$ $\boldsymbol{w}$における推定値。
省略して$\boldsymbol{\hat{y}}$って書いたりします。
$\boldsymbol{z}^{(L)} = \boldsymbol{\hat{y}}$
教師データ $\boldsymbol{y}$ いわゆる正解ラベル
活性(重み付き線形和) $\boldsymbol{u}^{(l)}$ 活性化関数に通す前のやつ。
$u_j^{(l)} = w_{j1}^{(l)}z_1^{(l-1)} + w_{j2}^{(l)}z_2^{(l-1)} + w_{j3}^{(l)}z_3^{(l-1)} + \cdots$
活性化関数 $f^{(l)}$ $\boldsymbol{z}^{(l)} = f^{(l)}\left(\boldsymbol{u^{(l)}} \right)$
損失関数 $E(\boldsymbol{w})$ 出力(予測)$\boldsymbol{\hat{y}}$と教師データ$\boldsymbol{y}$の誤差。
$\boldsymbol{w}$での誤差を$E(\boldsymbol{w})$と表記する。
学習率 $\eta$

$\boldsymbol{z}^{(l)}$と$\boldsymbol{w}^{(t)}$で表記が似ててややこしいですが、前者はある$l$層目の中間層の話、後者はある時刻$t$での話で、全く異なるので気を付けてください。

本編

色々と前置きが終わったので、ここから本編に入ります。

勾配降下法

前置きでも触れられていると思いますが、ニューラルネットワークはユニット同士を結合する「重み」の値を推定値(最終層の出力)の誤差をもとに更新し、学習を進めます。つまり、ニューラルネットの学習は損失関数$E(\boldsymbol{w})$を最小化するパラメータ$\boldsymbol{w}$を見つけるという最適化問題に帰着できます。もっとも単純な考え方は損失関数上の適当な位置から下り続ける方法です。簡単に言うと損失関数上の適当な点にボールを置いて、転がして止まるまで待つ、という方法です。勾配降下法はまさにこの考え方を使った方法で、これは、その位置での勾配の逆方向にボールを動かす、ということと同じです。

\nabla E(\boldsymbol{w}) = \frac{\partial E(\boldsymbol{w})}{\partial \boldsymbol{w}} = 
\left(
    \frac{\partial E(\boldsymbol{w})}{\partial w_1},
    \frac{\partial E(\boldsymbol{w})}{\partial w_2},
    \cdots,
\right)^\top

この勾配を使って、次の時刻でのパラメータを以下のように更新します。

\begin{align}
\boldsymbol{w}^{(t+1)} = \boldsymbol{w}^{(t)} + \eta\Delta\boldsymbol{w}^{(t)}\\
\Delta\boldsymbol{w}^{(t)} = -\nabla E\left(\boldsymbol{w}^{(t)}\right)
\end{align}

$\eta$は学習率とか呼ばれるやつです。($\Delta$はラプラシアンじゃないので注意してください。)

ではさっき書いた勾配

\nabla E(\boldsymbol{w}) = \frac{\partial E(\boldsymbol{w})}{\partial \boldsymbol{w}}

を計算してみましょう。ただし$\boldsymbol{w}$(ベクトル)のまま考えると面倒くさいのである成分について考えることにします。$\boldsymbol{w}$は全重みを縦に並べた列ベクトルですから、ある$l$層目の重み$w_{ji}^{(l)}$の勾配は

\frac{\partial E(\boldsymbol{w})}{\partial w_{ji}^{(l)}}

になります。ここで、誤差関数$E(\boldsymbol{w})$は出力(推定値)$\boldsymbol{\hat{y}}$と教師データ$\boldsymbol{y}$との誤差を測る関数なので、$E(\boldsymbol{w})$は$\boldsymbol{\hat{y}}$に依存します。
例えば、誤差関数が平均二乗誤差なら

E(\boldsymbol{w}) = \frac{1}{2}||\boldsymbol{\hat{y}}-\boldsymbol{y}||^2

という具合です。また、$\boldsymbol{\hat{y}}(\boldsymbol{x}; \boldsymbol{w})$なのでもちろん$\boldsymbol{w}$に依存しますから、連鎖律を使えばうまいこと計算できそうです。
...と言いたいところですが、これそのままやろうとするとめちゃくちゃ時間がかかります。というのも$\boldsymbol{\hat{y}}$は

\boldsymbol{\hat{y}} = f^{(L)}\left(\boldsymbol{w}^{(L)}f^{(L-1)}\left(\boldsymbol{w}^{(L-1)}f^{(L-2)}\left(\cdots \boldsymbol{w}^{(l+1)}f^{(l)}\left(\boldsymbol{w}^{(l)}f^{(l-1)}\left(\cdots f(\boldsymbol{x}) \right) \right) \right) \right) \right)

となり、入力層に近い重みほど多くの合成関数を処理しないといけないことになります。これを何とか解決すべく生まれたのが誤差逆伝搬法です。

誤差逆伝搬法

任意の重みの勾配

$w_{ji}^{(l)}$ の勾配を $l$ 層目の活性 $\boldsymbol{u}^{(l)}$ の連鎖律で表すことを考えます。つまり、

\frac{\partial E(\boldsymbol{w})}{\partial w_{ji}^{(l)}}
    = \sum_k\frac{\partial E(\boldsymbol{w})}{\partial u_k^{(l)}}\frac{\partial u_k^{(l)}}{\partial w_{ji}^{(l)}}
    = \frac{\partial E(\boldsymbol{w})}{\partial u_j^{(l)}}\frac{\partial u_j^{(l)}}{\partial w_{ji}^{(l)}}

となります($\boldsymbol{u}^{(l)}$の$j$番目の成分以外は$w_{ji}^{(l)}$に依存しないので消えます)。後で使いやすいように

\delta_j^{(l)} = \frac{\partial E(\boldsymbol{w})}{\partial u_j^{(l)}}

とおいておきます。ここで、

u_j^{(l)} = w_{j1}^{(l)}z_1^{(l-1)} + w_{j2}^{(l)}z_2^{(l-1)} + \cdots + w_{ji}^{(l)}z_i^{(l-1)} + \cdots

より、

\frac{\partial u_j^{(l)}}{\partial w_{ji}^{(l)}} = z_i^{(l-1)}

となります。また $\delta_j^{(l)}$ の方は新しく $\boldsymbol{u}^{(l+1)}, \boldsymbol{z}^{(l)}$ を導入した連鎖律で書くと

\begin{align*}
    \delta_j^{(l)}
        &= \sum_k\frac{\partial E(\boldsymbol{w})}{\partial u_k^{(l+1)}}\frac{\partial u_k^{(l+1)}}{\partial u_j^{(l)}}\\
        &= \sum_k\frac{\partial E(\boldsymbol{w})}{\partial u_k^{(l+1)}}\frac{\partial u_k^{(l+1)}}{\partial z_j^{(l)}}\frac{\partial z_j^{(l)}}{\partial u_j^{(l)}}\\
        &= \sum_k\delta_k^{(l+1)}w_{kj}^{(l+1)}\frac{\partial f^{(l)}\left(u_j^{(l)} \right)}{\partial u_j^{(l)}}\\
\end{align*}
2行目から3行目の式変形
u_k^{(l+1)} = w_{k1}^{(l+1)}z_1^{(l)} + w_{k2}^{(l+1)}z_2^{(l)} + \cdots + w_{kj}^{(l+1)}z_j^{(l)} + \cdots\\

より、

\frac{\partial u_k^{(l+1)}}{\partial z_j^{(l)}} = w_{kj}^{(l+1)}

また

z_j^{(l)} = f^{(l)}\left(u_j^{(l)} \right)

より、

\frac{\partial z_j^{(l)}}{\partial u_j^{(l)}} = \frac{\partial f^{(l)}\left(u_j^{(l)} \right)}{\partial u_j^{(l)}}

こんな感じになりました。$\frac{\partial f^{(l)}\left(u_j^{(l)} \right)}{\partial u_j^{(l)}}$ は要は活性化関数の偏微分ですからすぐに計算できます。まとめると、

\frac{\partial E(\boldsymbol{w})}{\partial w_{ji}^{(l)}} = 
z_i^{(l-1)} \sum_k\delta_k^{(l+1)}w_{kj}^{(l+1)}\frac{\partial f^{(l)}\left(u_j^{(l)} \right)}{\partial u_j^{(l)}}

ということで、おびただしい長さの合成関数の計算が、1つ次の層のみに依存するただの総和計算に落ちました。

最終層の重みの勾配

上の結果から、出力層の重み $w_{ji}^{(L)}$ が分かればそこから順次任意の重みについて計算ができそうです。$w_{ji}^{(L)}$ での勾配は、

\frac{\partial E(\boldsymbol{w})}{\partial w_{ji}^{(L)}}
    = \frac{\partial E(\boldsymbol{w})}{\partial u_j^{(L)}}\frac{\partial u_j^{(L)}}{\partial w_{ji}^{(L)}}
    = \delta_j^{(L)}z_{i}^{(L-1)}

より、$\delta_j^{(L)}$ は

\begin{align*}
    \delta_j^{(L)} 
        &= \frac{\partial E(\boldsymbol{w})}{\partial u_j^{(L)}}\\
        &= \frac{\partial E(\boldsymbol{w})}{\partial z_j^{(L)}}\frac{\partial z_j^{(L)}}{\partial u_j^{(L)}}\\
        &= \frac{\partial E(\boldsymbol{w})}{\partial \hat{y}_j}\frac{\partial f^{(L)}\left(u_j^{(L)} \right)}{u_j^{(L)}}
\end{align*}

まとめると、

\frac{\partial E(\boldsymbol{w})}{\partial w_{ji}^{(L)}} = z_{i}^{(L-1)}\frac{\partial E(\boldsymbol{w})}{\partial \hat{y}_j}\frac{\partial f^{(L)}\left(u_j^{(L)} \right)}{u_j^{(L)}}

というとこで、あとはこれを元に計算を繰り返すことで、任意の $l$ 層目の重みについて順次計算できることが分かりました。

回帰問題なんかでよく使われる例として、損失関数として二乗誤差関数、最終層の活性化関数として恒等写像、中間層の活性化関数としてReLUを仮定してみます。定式化すると

\begin{align}
    E(\boldsymbol{w}) = \frac{1}{2}||\boldsymbol{\hat{y}}-\boldsymbol{y}||^2\\
    f^{(L)}(\boldsymbol{u}) = \boldsymbol{u}\\
    f^{(l)}(\boldsymbol{u}) = ReLU(\boldsymbol{u})
\end{align}

です。$E(\boldsymbol{w})$ に $\frac{1}{2}$ がついてるのは単純に微分して係数を消せるようにするためです。定数倍しても勾配の向きは変化しないので問題はありません。これで各勾配を計算してみると、

\begin{align*}
    \frac{\partial E(\boldsymbol{w})}{\partial w_{ji}^{(L)}}
        &= z_{i}^{(L-1)}\frac{\partial E(\boldsymbol{w})}{\partial \hat{y}_j}\frac{\partial f^{(L)}\left(u_j^{(L)} \right)}{u_j^{(L)}}\\
        &= z_i^{(L-1)}(\hat{y}_j - y_j)
\end{align*}
\begin{align*}
    \frac{\partial E(\boldsymbol{w})}{\partial w_{ji}^{(l)}}
        &= z_i^{(l-1)} \sum_k\delta_k^{(l+1)}w_{kj}^{(l+1)}\frac{\partial f^{(l)}\left(u_j^{(l)} \right)}{\partial u_j^{(l)}}\\
        &= \begin{cases}
            z_i^{(l-1)} \sum_k\delta_k^{(l+1)}w_{kj}^{(l+1)} & (u_j^{(l)} \geq 0)\\
            0 & (u_j^{(l)} \lt 0)
        \end{cases}
\end{align*}
    

ということで、非常に単純な加減算・乗算のみで表せることが分かりました。

まとめ

今回は誤差逆伝搬法について個人的にまとめてみました。緑本の方だと勾配の式が

\frac{\partial E(\boldsymbol{w})}{\partial w_{ji}^{(l)}} = \sum_{k=1}^{D_l}\frac{\partial E(\boldsymbol{w})}{\partial \hat{y}_k}\frac{\partial \hat{y}_k}{\partial w_{ji}^{(l)}}

と $\boldsymbol{\hat{y}}$ を使った式で書かれていて、これが非常に分かりにくかったです。総和が $1$ から $D_l$ ($l$ に依存)まで走っているのを見るにこれは $\boldsymbol{u}^{(l)}$ または $\boldsymbol{z}^{(l)}$ の意味で使っているのだと思われますが、直前で $\boldsymbol{\hat{y}}$ を最終層の出力の意味で使っているのでめちゃくちゃ不自然な気がして今回の記事を作成しました。どなたかここの正しい解釈を教えていただけると幸いです...

0
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?