0. 背景
- 「ゼロから作る Deep Learning - Pythonで学ぶディープラーニングの理論と実装」を手にニューラルネットワークの勉強をしていたのですが、5章・誤差逆伝播のAffilneレイヤの逆伝播の式変形の理解に時間がかかったので、小さな次元で計算してみました。その過程を自分の備忘録的に記述します。
- 目標としては、低次元ながらも次の式が成分計算で求められることとします。($T$は転置行列を意味します)
\begin{align}
\frac{\partial L}{\partial \boldsymbol{X}} &= \frac{\partial L}{\partial \boldsymbol{Y}}\cdot \boldsymbol{W}^T \\
\frac{\partial L}{\partial \boldsymbol{W}} &=
\boldsymbol{X}^T \cdot \frac{\partial L}{\partial \boldsymbol{Y}}
\end{align}
- 初学者ですので、計算ミス、誤字などあると思います。優しくご教授いただけますと幸いです。
1. 低次元で活性化関数を考慮せず地道に成分計算をする
入力は、$\boldsymbol{x} = (x_1; x_2)$の2次元だとします。
この入力に対して、第1層目への出力を3つにしたい場合、$(2, 3)$の行列を右からかけます。
\boldsymbol{W} = \begin{pmatrix}
w_{11} & w_{21} & w_{31} \\
w_{12} & w_{22} & w_{32}
\end{pmatrix}
出力$\boldsymbol{Y}$は
\begin{align}
\boldsymbol{Y} &= \boldsymbol{X} \cdot \boldsymbol{W} \\
&=
\begin{pmatrix}
x_1 & x_2
\end{pmatrix}
\begin{pmatrix}
w_{11} & w_{21} & w_{31} \\
w_{12} & w_{22} & w_{32}
\end{pmatrix} \\
&=
\begin{pmatrix}
w_{11}x_1+w_{12}x_2 & w_{21}x_1+w_{22}x_2 & w_{31}x_1+w_{32}x_2
\end{pmatrix} \\
&=
\begin{pmatrix}
y_1 & y_2 & y_3
\end{pmatrix} \tag{1.1}
\end{align}
となります。
損失関数$L$の入力$\boldsymbol{X}$による偏微分は$x_1, x_2$が$y_1, y_2, y_3$に出てくることに注意すると、
\begin{align}
\frac{\partial L}{\partial \boldsymbol{X}} &=
\begin{pmatrix}
\frac{\partial L}{\partial x_1} & \frac{\partial L}{\partial x_2}
\end{pmatrix} \\
&=
\begin{pmatrix}
\frac{\partial L}{\partial \boldsymbol{Y}} \cdot \frac{\partial \boldsymbol{Y}}{\partial x_1} & \frac{\partial L}{\partial \boldsymbol{Y}} \cdot \frac{\partial \boldsymbol{Y}}{\partial x_2}
\end{pmatrix}
\end{align}
ここで
\begin{align}
\frac{\partial L}{\partial \boldsymbol{Y}} \cdot \frac{\partial \boldsymbol{Y}}{\partial x_1} =
\begin{pmatrix}
\frac{\partial L}{\partial y_1} & \frac{\partial L}{\partial y_2} & \frac{\partial L}{\partial y_3}
\end{pmatrix}
\cdot
\begin{pmatrix}
\frac{\partial y_1}{\partial x_1} \\
\frac{\partial y_2}{\partial x_1} \\
\frac{\partial y_3}{\partial x_1}
\end{pmatrix}
\end{align}
なので
\begin{align}
\frac{\partial L}{\partial \boldsymbol{X}} &=
\begin{pmatrix}
\frac{\partial L}{\partial y_1} \frac{\partial y_1}{\partial x_1} +
\frac{\partial L}{\partial y_2} \frac{\partial y_2}{\partial x_1} +
\frac{\partial L}{\partial y_3} \frac{\partial y_3}{\partial x_1} &
\frac{\partial L}{\partial y_1} \frac{\partial y_1}{\partial x_2} +
\frac{\partial L}{\partial y_2} \frac{\partial y_2}{\partial x_2} +
\frac{\partial L}{\partial y_3} \frac{\partial y_3}{\partial x_2}
\end{pmatrix} \\
&=
\begin{pmatrix}
\frac{\partial L}{\partial y_1} w_{11} +
\frac{\partial L}{\partial y_2} w_{21} +
\frac{\partial L}{\partial y_3} w_{31} &
\frac{\partial L}{\partial y_1} w_{12} +
\frac{\partial L}{\partial y_2} w_{22} +
\frac{\partial L}{\partial y_3} w_{32}
\end{pmatrix} \\
&=
\begin{pmatrix}
\frac{\partial L}{\partial y_1} & \frac{\partial L}{\partial y_2} & \frac{\partial L}{\partial y_3}
\end{pmatrix}
\begin{pmatrix}
w_{11} & w_{12} \\
w_{21} & w_{22} \\
w_{31} & w_{32}
\end{pmatrix} \\
&= \frac{\partial L}{\partial \boldsymbol{Y}}\cdot \boldsymbol{W}^T
\end{align}
一方、損失関数$L$の重み$\boldsymbol{W}$による偏微分は
\begin{align}
\frac{\partial L}{\partial \boldsymbol{W}} &=
\begin{pmatrix}
\frac{\partial L}{\partial w_{11}} & \frac{\partial L}{\partial w_{21}} & \frac{\partial L}{\partial w_{31}} \\
\frac{\partial L}{\partial w_{12}} & \frac{\partial L}{\partial w_{22}} & \frac{\partial L}{\partial w_{32}}
\end{pmatrix} \\
\end{align}
です。
ここで、$(1.1)$式において、$w_{11}$は$y_1$だけに、$w_{12}$は$y_1$だけに、・・・$w_{31}$は$y_3$だけに、$w_{32}$は$y_3$だけに出てくることに注意すると、
\begin{align}
\frac{\partial L}{\partial w_{11}} &= \frac{\partial L}{\partial y_1}\frac{\partial y_1}{\partial w_{11}} \\
\frac{\partial L}{\partial w_{12}} &= \frac{\partial L}{\partial y_1}\frac{\partial y_1}{\partial w_{12}} \\
\frac{\partial L}{\partial w_{21}} &= \frac{\partial L}{\partial y_2}\frac{\partial y_2}{\partial w_{21}} \\
\frac{\partial L}{\partial w_{22}} &= \frac{\partial L}{\partial y_2}\frac{\partial y_2}{\partial w_{22}} \\
\frac{\partial L}{\partial w_{31}} &= \frac{\partial L}{\partial y_3}\frac{\partial y_3}{\partial w_{31}} \\
\frac{\partial L}{\partial w_{32}} &= \frac{\partial L}{\partial y_3}\frac{\partial y_3}{\partial w_{32}}
\end{align}
となります。
したがって、$\frac{\partial L}{\partial \boldsymbol{W}}$は
\begin{align}
\frac{\partial L}{\partial \boldsymbol{W}} &=
\begin{pmatrix}
\frac{\partial L}{\partial w_{11}} & \frac{\partial L}{\partial w_{21}} & \frac{\partial L}{\partial w_{31}} \\
\frac{\partial L}{\partial w_{12}} & \frac{\partial L}{\partial w_{22}} & \frac{\partial L}{\partial w_{32}}
\end{pmatrix} \\
&=
\begin{pmatrix}
\frac{\partial L}{\partial y_1}\frac{\partial y_1}{\partial w_{11}} &
\frac{\partial L}{\partial y_2}\frac{\partial y_2}{\partial w_{21}} &
\frac{\partial L}{\partial y_3}\frac{\partial y_3}{\partial w_{31}} \\
\frac{\partial L}{\partial y_1}\frac{\partial y_1}{\partial w_{12}} & \frac{\partial L}{\partial y_2}\frac{\partial y_2}{\partial w_{22}} &
\frac{\partial L}{\partial y_3}\frac{\partial y_3}{\partial w_{32}}
\end{pmatrix} \\
&=
\begin{pmatrix}
\frac{\partial L}{\partial y_1}x_1 &
\frac{\partial L}{\partial y_2}x_1 &
\frac{\partial L}{\partial y_3}x_1 \\
\frac{\partial L}{\partial y_1}x_2 &
\frac{\partial L}{\partial y_2}x_2 &
\frac{\partial L}{\partial y_3}x_2
\end{pmatrix} \\
&= \begin{pmatrix}
x_1 \\
x_2
\end{pmatrix}
\begin{pmatrix}
\frac{\partial L}{\partial y_1} &
\frac{\partial L}{\partial y_2} &
\frac{\partial L}{\partial y_3}
\end{pmatrix} \\
&= \boldsymbol{X}^T \cdot \frac{\partial L}{\partial \boldsymbol{Y}}
\end{align}
これで、
\begin{align}
\frac{\partial L}{\partial \boldsymbol{W}} &=
\boldsymbol{X}^T \cdot \frac{\partial L}{\partial \boldsymbol{Y}}
\end{align}
の導出が(低次元で、活性化関数も考慮していませんが)できました。
2. 活性化関数と2層目を考慮してみる
1では、各層の活性化関数を無視したり、2層目以降を考慮していませんでした。ここでは、2層目を導入し、活性化関数も導入してみようと思います。 1層目の活性化関数を$h$、2層目(出力層)の活性化関数を$\sigma$とおきます。 最初に結果をみると、このようになります。L = \sigma(h(\boldsymbol{X} \cdot \boldsymbol{W}) \cdot \boldsymbol{W}^{(2)})
一つずつ見ていきます。
入力$\boldsymbol{X}$
\boldsymbol{X} = (x_1\; x_2)
1層目の入力
\boldsymbol{Y} = \boldsymbol{X}\cdot \boldsymbol{W}
1層目の出力(活性化関数を作用させる)
\begin{align}
h(\boldsymbol{Y}) &= h(\boldsymbol{X}\cdot \boldsymbol{W}) \\
&= \begin{pmatrix}
h(y_1) & h(y_2) & h(y_3)
\end{pmatrix}
\end{align}
2層目(出力層)の入力
\begin{align}
z &= h(\boldsymbol{Y})\cdot \boldsymbol{W}^{(2)} \\
&=
\begin{pmatrix}
h(y_1) & h(y_2) & h(y_3)
\end{pmatrix}
\begin{pmatrix}
w_1^{(2)} \\
w_2^{(2)} \\
w_3^{(2)}
\end{pmatrix} \\
&= w_1^{(2)}h(y_1) + w_2^{(2)}h(y_2) + w_3^{(2)}h(y_3)
\end{align}
2層目(出力層)の出力
\begin{align}
L &= \sigma (z) \\
&= \sigma (w_1^{(2)}h(y_1) + w_2^{(2)}h(y_2) + w_3^{(2)}h(y_3))
\end{align}
$\boldsymbol{X}$と$\boldsymbol{W}$の偏微分は1と同じですが、再掲します。
\begin{align}
\frac{\partial \sigma}{\partial \boldsymbol{X}} &=
\begin{pmatrix}
\frac{\partial \sigma}{\partial x_1} & \frac{\partial \sigma}{\partial x_2}
\end{pmatrix} \\
&=
\begin{pmatrix}
\frac{\partial \sigma}{\partial \boldsymbol{Y}} \cdot \frac{\partial \boldsymbol{Y}}{\partial x_1} & \frac{\partial \sigma}{\partial \boldsymbol{Y}} \cdot \frac{\partial \boldsymbol{Y}}{\partial x_2}
\end{pmatrix} \\
&=
\begin{pmatrix}
\frac{\partial \sigma}{\partial y_1} \frac{\partial y_1}{\partial x_1} +
\frac{\partial \sigma}{\partial y_2} \frac{\partial y_2}{\partial x_1} +
\frac{\partial \sigma}{\partial y_3} \frac{\partial y_3}{\partial x_1} &
\frac{\partial \sigma}{\partial y_1} \frac{\partial y_1}{\partial x_2} +
\frac{\partial \sigma}{\partial y_2} \frac{\partial y_2}{\partial x_2} +
\frac{\partial \sigma}{\partial y_3} \frac{\partial y_3}{\partial x_2}
\end{pmatrix} \\
&=
\begin{pmatrix}
\frac{\partial \sigma}{\partial y_1} w_{11} +
\frac{\partial \sigma}{\partial y_2} w_{21} +
\frac{\partial \sigma}{\partial y_3} w_{31} &
\frac{\partial \sigma}{\partial y_1} w_{12} +
\frac{\partial \sigma}{\partial y_2} w_{22} +
\frac{\partial \sigma}{\partial y_3} w_{32}
\end{pmatrix} \\
&=
\begin{pmatrix}
\frac{\partial \sigma}{\partial y_1} & \frac{\partial \sigma}{\partial y_2} & \frac{\partial \sigma}{\partial y_3}
\end{pmatrix}
\begin{pmatrix}
w_{11} & w_{12} \\
w_{21} & w_{22} \\
w_{31} & w_{32}
\end{pmatrix} \\
&= \frac{\partial \sigma}{\partial \boldsymbol{Y}}\cdot \boldsymbol{W}^T \\
\frac{\partial L}{\partial \boldsymbol{W}} &=
\begin{pmatrix}
\frac{\partial \sigma}{\partial w_{11}} & \frac{\partial \sigma}{\partial w_{21}} & \frac{\partial \sigma}{\partial w_{31}} \\
\frac{\partial \sigma}{\partial w_{12}} & \frac{\partial \sigma}{\partial w_{22}} & \frac{\partial \sigma}{\partial w_{32}}
\end{pmatrix} \\
&=
\begin{pmatrix}
\frac{\partial \sigma}{\partial y_1}\frac{\partial y_1}{\partial w_{11}} &
\frac{\partial \sigma}{\partial y_2}\frac{\partial y_2}{\partial w_{21}} &
\frac{\partial \sigma}{\partial y_3}\frac{\partial y_3}{\partial w_{31}} \\
\frac{\partial \sigma}{\partial y_1}\frac{\partial y_1}{\partial w_{12}} &
\frac{\partial \sigma}{\partial y_2}\frac{\partial y_2}{\partial w_{22}} &
\frac{\partial \sigma}{\partial y_3}\frac{\partial y_3}{\partial w_{32}}
\end{pmatrix} \\
&=
\begin{pmatrix}
\frac{\partial \sigma}{\partial y_1}x_1 &
\frac{\partial \sigma}{\partial y_2}x_1 &
\frac{\partial \sigma}{\partial y_3}x_1 \\
\frac{\partial \sigma}{\partial y_1}x_2 &
\frac{\partial \sigma}{\partial y_2}x_2 &
\frac{\partial \sigma}{\partial y_3}x_2
\end{pmatrix} \\
&=
\begin{pmatrix}
x_1 \\
x_2
\end{pmatrix}
\begin{pmatrix}
\frac{\partial \sigma}{\partial y_1} &
\frac{\partial \sigma}{\partial y_2} &
\frac{\partial \sigma}{\partial y_3}
\end{pmatrix} \\
&= \boldsymbol{X}^T \cdot \frac{\partial \sigma}{\partial \boldsymbol{Y}}
\end{align}
蛇足ですが、成分表示をすると、
\frac{\partial L}{\partial w_{ji}} =
\frac{\partial L}{\partial y_j} \frac{\partial y_j}{\partial w_{ji}}
となります。
3. より一般的な式にする
1, 2では、層の数も限られていましたし、各層の次元も少ないケースでしたが、より一般的に第$i$層、第$j$層、第$k$層に注目してみたいと思います。
ここでは、行列は成分表示で表します。
\begin{align}
a_j^{(j)} &= \sum_i w_{ji}^{(j)}z_i^{(i)} \\
z_j^{(j)} &= h(a_j^{(j)})
\end{align}
重み$w_{ji}^{(j)}$による微分は、$w_{ji}^{(j)}$が$a_j^{(j)}$のみに出現することから、
\begin{align}
\frac{\partial L}{\partial w_{ji}^{(j)}} &= \frac{\partial L}{\partial a_j^{(j)}}\frac{\partial a_j^{(j)}}{\partial w_{ji}^{(j)}} \\
&= \frac{\partial L}{\partial a_j^{(j)}}\frac{\partial}{\partial w_{ji}^{(j)}} ( \sum_i w_{ji}^{(j)}z_i^{(i)} ) \\
&= \frac{\partial L}{\partial a_j^{(j)}}z_i
\end{align}
また、$j$番目の入力$a_j^{(j)}$による微分は、$a_j^{(j)}$が$a_k^{(k)}$の変化を通じてしか誤差関数を変化させないことから、
\begin{align}
\frac{\partial L}{\partial a_j^{(j)}} &=
\sum_k \frac{\partial L}{\partial a_k^{(k)}}\frac{\partial a_k^{(k)}}{\partial a_j^{(j)}} \\
&= \sum_k \frac{\partial L}{\partial a_k^{(k)}}\frac{\partial}{\partial a_j^{(j)}} ( \sum_j w_{kj}^{(k)}z_j^{(j)} ) \\
&= \sum_k \frac{\partial L}{\partial a_k^{(k)}} w_{kj} \frac{\partial h(a_j^{(j)})}{\partial a_j^{(j)}} \\
&= \frac{\partial h(a_j^{(j)})}{\partial a_j^{(j)}}\sum_k w_{kj} \frac{\partial L}{\partial a_k^{(k)}}
\end{align}
となる。