はじめに
NNでは、学習を進めるために、勾配降下法を用います。
勾配降下法では、その名の通り、「勾配」を計算しなければなりません。
より具体的に述べれば、損失関数のパラメータ空間での勾配を求める必要があります。
安直に勾配を求める方法として、勾配の各成分を次のような数値微分で計算する方法があります。
L'(w) = \frac{L(w + h) - L (w - h)}{2h}
このように書くと大した計算ではないように見えますが、右辺の$L(w)$を計算するために、各層でどのような影響が$w$に伝わるか、出力層まで追跡しなければならず、DNNではその計算コストが莫迦になりません。
本記事では、そのような悩みを解決してくれる誤差逆伝播法というアルゴリムを、何番煎じかわかりませんが、紹介したいと思います。このアルゴリズムを使えば、DNNにおいて数値微分よりも計算コストをグッと抑えて、$L'(w)$を求めることができます。
その仕組みを説明するために、先んじて、「連鎖率」という概念を導入したいと思います。
連鎖率
損失関数は、層を重ねて様々な関数を入力やパラメータに作用させた結果とみなすことができ、その微分は連鎖率を用いて計算できます。
例えば、一変数に対して2つ、または3つの関数を作用させる場合は、
その微分を次のように表現することが可能で、
J = j(k(w)) \\
\ \\
\frac{d J}{d w}
= \frac{d j(k)}{d k} \frac{d k(w)}{d w} \\
\ \\
\ \\
H = h\left( j\left( k(w) \right) \right) \\
\ \\
\frac{d H}{d w}
= \frac{d h(j)}{d j} \frac{d j(k)}{d k} \frac{d k(w)}{d w} \\
上から2,4番目の式のように、微分を連ねた形を連鎖率と呼びます。
連鎖率は、微分係数($ \frac{d H}{d w} $)を求めるために、$w$が直接依存する微分係数($\frac{dk(w)}{dw}$)をどのように増幅すればよいか($\frac{d h(j)}{d j} \frac{d j(k)}{d k}$をかける)、教えてくれます。
3関数の場合をより一般化してみましょう。
微分したい関数には多変数関数を想定し、その引数となる関数も多変数関数だとします。
E = E(Z_1, Z_2, ..., Z_N) \\
Z_n = Z_n(U_1, U_2, ..., U_M) \\
U_m = U_m(w_1, w_2, ..., w_I) \\
あるパラメータ$w_i$による偏微分は、
\frac{\partial E}{\partial w_i}
= \Sigma_n \frac{\partial E}{\partial Z_n} \frac{\partial Z_n}{\partial w_i}
= \Sigma_n \frac{\partial E}{\partial Z_n} \left( \Sigma_m \frac{\partial Z_n}{\partial U_m} \frac{\partial U_m}{\partial w_i} \right) \\
$w_i$の影響の筋道をすべて考慮し、その大きさをすべて足し合わせて求まります。
添え字のややこしさはあるものの、考え方はシンプルです。
そして、一般化しても、$w_i$が直接依存する微分係数$\left( \frac{\partial U_m}{\partial w_i} \right)$に増幅率をかけているという見方ができることに変わりはありません。
以上をまとめてみます。
何層も関数に関数が重ねられた深い関数について、あるパラメータによる微分を求めたいときがあります。そのとき、連鎖率は、パラメータが直接依存する微分係数に増幅率をかけて、目的の微分を表現する方法だとわかりました。
誤差逆伝播法と連鎖率
さて、ここからは主題である誤差逆伝播法について、解説していきます。
誤差逆伝播法の肝は、連鎖率の説明で紹介した増幅率の使いまわしにある と、私は考えています。
というわけで、まずは勾配の各成分を連鎖率で表していきましょう。
損失関数や活性化関数やパラメータなどの表記方法から導入します。
入力層を$0$層目として、$l$層目の$j$番目のノードを考えます。
$j$番目のノードの値$Z^l_j$は、手前の層からの入力の重み付き和$U^l_j$に活性化関数$f^l$を作用させたもの$\left( Z^l_j = f^l(U^l_j) \right)$です。
重み付き和$U^l_j$は、パラメータ$w^l_{ij}$を使って、このように表記されます。
U^l_j = w^l_{0j} Z^{l-1}_0 + w^l_{1j} Z^{l-1}_1 + w^l_{2j} Z^{l-1}_2 + ... \\
= \Sigma_i w^l_{ij} Z^{l-1}_i \\
※ Z^{l-1}_0 = 1
損失関数$E$についても考えます。
$E$は、出力層($L$番目の層)の全ノード$Z^L_j$に対して、正解データ$d_j$との差を計るうまい関数ですので、$E(... , Z^L_{j-1}, d_{j-1}, Z^L_j, d_j, Z^L_{j+1}, d_{j+1}, ...)$と書けるでしょう。
ここまでに導入した表記を用いると、我々が知りたいのは、$\frac{\partial E}{\partial w^l_{ij}} $ですので、次はそれを計算してみましょう。
$w^l_{ij}$が直接依存するのは、$U^l_j$ただ一つですので、連鎖率は次のように書けます。
\frac{\partial E}{\partial w^l_{ij}}
= \frac{\partial E}{\partial U^l_{j}} \frac{\partial U^l_{j}}{\partial w^l_{ij}}
右辺だけ見ると、アインシュタインの縮約記号を使用しているように見えますが、$j$で和を取っていないことに注意してください。
$l\neq L$だとすると、$U^l_j$は、$E$の直接の入力になっていません。
そのため、上記の式の増幅率を偏微分とみて、そこに連鎖率の表現を適用することができます。
見通しをよくするために、上記の式の増幅率を$\delta ^l_j$と命名しますと、
\delta ^l_j = \frac{\partial E}{\partial U^l_{j}}
= \Sigma_k \frac{\partial E}{\partial U^{l+1}_{k}}
\frac{\partial U^{l+1}_{k}}{\partial Z^l_{j}}
\frac{\partial Z^{l}_{j}}{\partial U^l_{j}}
と連鎖率を計算できます。
お分かりになりますでしょうか。右辺において、左辺の増幅率$\delta ^l_j$に比べて、一つ層の次数が上の増幅率$\delta_j ^{l+1}$が含まれています。
この式とその一つ前の式を眺めると大変興味深い事実を得られます。
それは、ある層での微分$\frac{\partial E}{\partial w^l_{ij}} $の計算には、「ちょうど一層だけ」深い層で$\frac{\partial E}{\partial w^{l+1}_{ij}} $を計算する際に得た情報(増幅率)があれば十分だということです。これは、増幅率を記録していけば、毎度最下層から計算せずに済むことを意味しています。
増幅率についてメモを取り使いまわすこと、その結果、計算量を減らすこと、
誤差逆伝播法が単純な数値微分より優れている点は、まさしくここにあるのです。
勘所は、以上なのですが、層間での増幅率の関係をもう少し計算を進めて書くと、以下のようになります。
\delta ^l_j = \Sigma_k \frac{\partial E}{\partial U^{l+1}_{k}}
\frac{\partial U^{l+1}_{k}}{\partial Z^l_{j}}
\frac{\partial Z^{l}_{j}}{\partial U^l_{j}}
=\Sigma_k \delta ^{l+1}_k w^{l+1}_{jk} (f^l(U^l_j))'
計算には、表記方法の説明で記載した、$U^{l+1}_k$と$Z^l_j$の関係と、$Z^l_j$と$U^{l}_j$の関係を用いました。
$(f^l(U^l_j))'$を求めるためには、活性化関数の導関数を予め解析的に計算しておき、その後$l$層目の$j番目$のノードにおける重み付き和を代入する必要があります。つまり、順伝播で得た$U^l_j$の値を逆伝播中に使えるように記録しておかねばならないのです。
同様のことは、$w^l_{ij}$が直接依存する$\frac{\partial U^l_{j}}{\partial w^l_{ij}} = Z^{l-1}_{j}$にも言えます($Z^l_j$の保存が必要ということです)。
このように、誤差逆伝播法では、計算量を減らすためにメモリを多く必要とします。この点を実装の際に気を付けなければいけません。
最後に、増幅率の$\delta^{L}_j$を求めましょう。層間の増幅率の関係式を利用するにも、別の方法で求めた最下層の増幅率の値が必要となります。
\delta ^L_j = \frac{\partial E}{\partial U^L_{j}}
= \frac{\partial E}{\partial Z^{L}_{j}}
\frac{\partial Z^{L}_{j}}{\partial U^L_{j}}
= \frac{\partial E}{\partial Z^{L}_{j}} (f^L(U^L_j))'
例えば、
E = \frac{1}{2} \Sigma _i (Z_i - d_i)^2 \\
f^L(U^L_j) = U^L_j
という損失関数と出力層の活性化関数を考えてみると(回帰問題使われる関数です)、
\delta ^L_j = \frac{\partial E}{\partial U^L_{j}}
= \frac{\partial E}{\partial Z^{L}_{j}} (f^L(U^L_j))'
= (Z_i - d_i)
増幅率は出力値と正解データの誤差となり、層間の増幅率の関係はまさしく誤差の伝播を表していると言えます。
タイトル回収ができました。
おわりに
誤差逆伝播法は、本記事で言うところの増幅率(誤差)を使いまわす(記録して伝播する)方法だと、私は考えています!計算量を減らす代わりに、必要なメモリが増える点には注意しましょう!
本記事では、オーダー計算をまじめにやれたらよかったです!今後の課題ですね。。。
お読みいただきありがとうございました!
参考にさせていただいたもの
・ヨビノリさんの動画
・Project AIMさんの動画