1
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

誤差逆伝播法についてわからなくなった時に見るやつ

Last updated at Posted at 2024-03-26

誤差逆伝播法がわからなくなった時のための備忘録。わかりづらい箇所は今後編集していく予定です。

誤差逆伝播法って何?

ニューラルネットワークの学習アルゴリズムのひとつ。出力層で計算された誤差が、ネットワークを逆伝播していくことで、各層の重みを更新する。

前準備

本記事では、上図のような$L$層ニューラルネットワークを考える。
第$k$層の$i$番目のノードの入力値を$i_i^k$、第$k$層の$i$番目のノードの出力値を$o_i^k$とする。
また、第$k-1$層$i$番目のノードと、第$k$層$j$番目のノードとの間の重みを$w_{i \rightarrow j}^{k}$とする。

このとき、$i_i^k$は前層のノードの出力値の重み付き線形和となるから、
$$
i_i^k = \sum_m w_{m\rightarrow i}^k o_m^{k-1}
$$
である。また、$o_i^k$は入力に活性化関数を適用した値なので、
$$
o_i^k = h(i_i^k)
$$
である。
また、誤差逆伝播法では一般的に$E$は以下のような式で表される。

$$
E = \frac{1}{2} \sum_i( o_i - t_i )^2
$$

ここで、$o_i$は出力層の$i$番目のニューロンの出力値で、$t_i$はそれに対応する教師信号である。このような誤差関数を、二乗誤差関数という。

連鎖律

連鎖律の原理は、合成関数の微分によって説明される。例えば、次の式で表される関数があったとする。
$$
y = f(u) \quad u = g(x)
$$
このとき、$\frac{dy}{dx}$は、合成関数の微分より、以下のように求められる。
$$
\frac{dy}{dx} = \frac{dy}{du} \cdot \frac{du}{dx}
$$
これを多変数関数に拡張したものが、連鎖律である。次の式で表される関数があったとする。
$$
z = f(u,v) \quad u = g(x,y) \quad v = h(x,y)
$$
このとき、$\frac{\partial z}{\partial x}$および$\frac{\partial z}{\partial y}$は、以下のように求められる。
$$
\frac{\partial z}{\partial x} =
\frac{\partial z}{\partial u} \cdot \frac{\partial u}{\partial x} + \frac{\partial z}{\partial v} \cdot \frac{\partial v}{\partial x} \
\frac{\partial z}{\partial y} =
\frac{\partial z}{\partial u} \cdot \frac{\partial u}{\partial y} + \frac{\partial z}{\partial v} \cdot \frac{\partial v}{\partial y}
$$
さらに変数を多くした場合を見てみよう。次の式で表される関数があったとする。
$$
z = f(u_1,u_2,u_3,\cdots) \quad u_k = g_k(x_1,x_2,x_3,\cdots)
$$
このとき、$\frac{\partial z}{\partial x_k}$は、以下のように求められる。
$$
\frac{\partial z}{\partial x_k} =
\frac{\partial z}{\partial u_1} \cdot \frac{\partial u_1}{\partial x_k} +
\frac{\partial z}{\partial u_2} \cdot \frac{\partial u_2}{\partial x_k} + \cdots
= \sum_i \frac{\partial z}{\partial u_i} \cdot \frac{\partial u_i}{\partial x_k}
$$

つまり、ある合成関数の偏微分は、その合成関数を構成するそれぞれの偏微分の積によって表されるということである。

誤差逆伝播法の式

誤差逆伝播法では、重みを以下の式に従って更新する。
$$
w_{i \rightarrow j}^{k} \leftarrow - \eta \frac{\partial E}{\partial w_{i \rightarrow j}^k}
$$
上式について考えてみよう。重みの更新をするためには、誤差関数$E$を現在の重みで偏微分すればよい。
しかし、$E$は出力値$o_i^k$の関数なので、重みで偏微分するには工夫が必要である。

重みの計算

出力層における重み

はじめに、出力層における重み$w_{i \rightarrow j}^{L}$がどのように計算されるのかを求めていこう。

まず、出力層における重みの更新式は、以下のようになる。
$$
w_{i \rightarrow j}^{L} \leftarrow - \eta \frac{\partial E}{\partial w_{i \rightarrow j}^L}
$$

更新式の$\frac{\partial E}{\partial w_{i \rightarrow j}^{L}}$は、次のように変形できる。
$$
\frac{\partial E}{\partial w_{i \rightarrow j}^L} =
\frac{\partial E}{\partial i_j^L} \cdot
\frac{\partial i_j^L}{\partial w_{i \rightarrow j}^L}
$$
これはどのように変形しているのだろうか。
まず、$E$は、出力層の各ノードの出力値$o_1^L,o_2^L,\cdots$の関数だった、ということを思い出してほしい。
$$
E(o_1^L,o_2^L,\cdots o_j^L,\cdots )
$$
そして、ノードの出力値$o_j^L$は、入力値$i_j^L$に活性化関数を適用した値であるから、
$$
o_j^L = h(i_j^L)
$$
つまり、$E$は$i_j^L$の合成関数である。
$$
E(h(i_1^L),h(i_2^L),\cdots h(i_j^L),\cdots )
$$
よって、連鎖律より、
$$
\frac{\partial E}{\partial w_{i \rightarrow j}^L} =
\sum_k \frac{\partial E}{\partial i_k^L} \cdot
\frac{\partial i_k^L}{\partial w_{i \rightarrow j}^L}
$$
が成り立つ。
ここで、$i_k^L=\sum_m w_{m\rightarrow k}^L o_m^{L-1}$であることから、$w_{i \rightarrow j}^{L}$が出現するのは$i_j^L$のみであり、それ以外の偏微分の値は0になる。
$$
\frac{\partial E}{\partial w_{i \rightarrow j}^L} =
\frac{\partial E}{\partial i_j^L} \cdot
\frac{\partial i_j^L}{\partial w_{i \rightarrow j}^L}
$$
さらに、入力値は重みと前層の出力値の線形和で表されるから、
$$
\frac{\partial E}{\partial i_j^L} \cdot
\frac{\partial i_j^L}{\partial w_{i \rightarrow j}^L} =
\frac{\partial E}{\partial i_j^L} \cdot
\frac{\partial \sum_m w_{m\rightarrow j}^L o_m^{L-1}}{\partial w_{i \rightarrow j}^L} =
\frac{\partial E}{\partial i_j^L} \cdot o_i^{L-1}
$$
ここで、新たな変数$\delta_j^L$を次のように定義する。
$$
\delta_j^L = \frac{\partial E}{\partial i_j^L}
$$
では、次に$\delta_j^L$について考えてみよう。
連鎖律より、次のように変形できる。
$$
\delta_j^L = \frac{\partial E}{\partial i_j^L} =
\sum_k \frac{\partial E}{\partial o_k^L} \cdot
\frac{\partial o_k^L}{\partial i_j^L} =
\frac{\partial E}{\partial o_j^L} \cdot \frac{\partial o_j^L}{\partial i_j^L}
$$
ここで、$o_i^k$は入力に活性化関数を適用した値なので、$\frac{\partial o_j^L}{\partial i_j^L}$は、活性化関数の導関数であるから、$h'(i_j^L)=\frac{\partial o_j^L}{\partial i_j^L}$と表すと
$$
\delta_j^L = \frac{\partial E}{\partial o_j^L} \cdot h'(i_j^L)
$$
さらに、
$$
E = \frac{1}{2} \sum_i( o_i - t_i )^2
$$
であるから、
$$
\frac{\partial E}{\partial o_j^L} = o_j^L - t_j
$$
となる。
したがって、$\delta_j^L$は、以下のようになる。
$$
\delta_j^L = (o_j^L - t_j)h'(i_j^L)
$$
最終的に、$\frac{\partial E}{\partial w_{i \rightarrow j}^{L}}$は以下のように計算される。
$$
\frac{\partial E}{\partial w_{i \rightarrow j}^L} =
(o_j^L - t_j)h'(i_j^L)o_i^{L-1}
$$

中間層における重み

次に、中間層における重み$w_{i \rightarrow j}^{k}$がどのように計算されるのかを求めていこう。
まず、第$k$層$(0 \leq k < L)$における重みの更新式は、以下のようになる。
$$
w_{i \rightarrow j}^{k} \leftarrow - \eta \frac{\partial E}{\partial w_{i \rightarrow j}^{k}}
$$
最終層のときの式変形と同じように、連鎖律を用いて$\frac{\partial E}{\partial w_{i \rightarrow j}^{k}}$を変形すると、
$$
\frac{\partial E}{\partial w_{i \rightarrow j}^{k}} =
\frac{\partial E}{\partial i_j^k} \cdot
\frac{\partial i_j^k}{\partial w_{i \rightarrow j}^k} =
\delta_j^k o_i^{k-1}
$$
となる。
次に、$\delta_j^k$について考える。連鎖律を用いて次のように変形する。
$$
\delta_j^k = \frac{\partial E}{\partial i_j^k} =
\sum_m \frac{\partial E}{\partial i_m^{k+1}} \cdot
\frac{\partial i_m^{k+1}}{\partial i_j^k} =
\sum_m \delta_m^{k+1}
\frac{\partial i_m^{k+1}}{\partial i_j^k}
$$
さらに、出現した$\frac{\partial i_m^{k+1}}{\partial i_j^k}$に対して連鎖律を適用する。
$$
\frac{\partial i_m^{k+1}}{\partial i_j^k} =
\sum_n \frac{\partial i_m^{k+1}}{\partial o_n^k} \cdot
\frac{\partial o_n^k}{\partial i_j^k}
$$
が成り立つ。
ここで、$o_n^k=h(i_n^k)$であることから、$i_j^k$が出現するのは$o_j^k$のみであり、それ以外の偏微分の値は0になる。ゆえに
$$
\frac{\partial i_m^{k+1}}{\partial i_j^k} =
\frac{\partial i_m^{k+1}}{\partial o_j^k} \cdot
\frac{\partial o_j^k}{\partial i_j^k}
$$
そして、$\frac{\partial i_m^{k+1}}{\partial o_j^k}$は、
$$
i_m^{k+1} = \sum_l w_{l\rightarrow m}^{k+1} o_n^k
$$
であるから、
$$
\frac{\partial i_m^{k+1}}{\partial o_j^k} = w_{j\rightarrow m}^{k+1}
$$
となる。また、$\frac{\partial o_j^k}{\partial i_j^k}=h'(i_j^k)$であるから、$\delta_j^k$は次のように表される。
$$
\delta_j^k =
\sum_m \delta_m^{k+1}
\frac{\partial i_m^{k+1}}{\partial o_j^k} \cdot
\frac{\partial o_j^k}{\partial i_j^k} =
h'(i_j^k)\sum_m \delta_m^{k+1}w_{j\rightarrow m}^{k+1}
$$

参考にしたもの

1
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?