Help us understand the problem. What is going on with this article?

Affineレイヤの逆伝播を地道に成分計算する

More than 1 year has passed since last update.

0. 背景

  • ゼロから作る Deep Learning - Pythonで学ぶディープラーニングの理論と実装」を手にニューラルネットワークの勉強をしていたのですが、5章・誤差逆伝播のAffilneレイヤの逆伝播の式変形の理解に時間がかかったので、小さな次元で計算してみました。その過程を自分の備忘録的に記述します。
  • 目標としては、低次元ながらも次の式が成分計算で求められることとします。($T$は転置行列を意味します)
\begin{align}
\frac{\partial L}{\partial \boldsymbol{X}} &= \frac{\partial L}{\partial \boldsymbol{Y}}\cdot \boldsymbol{W}^T  \\
\frac{\partial L}{\partial \boldsymbol{W}} &=
\boldsymbol{X}^T \cdot \frac{\partial L}{\partial \boldsymbol{Y}}
\end{align}
  • 初学者ですので、計算ミス、誤字などあると思います。優しくご教授いただけますと幸いです。

1. 低次元で活性化関数を考慮せず地道に成分計算をする

1
入力として$\boldsymbol{x} = (x_1\; x_2)$の2次元だとします。
この入力に対して、第1層目への出力を3つにしたい場合、$(2, 3)$の行列を右からかけます。

\boldsymbol{W} = \begin{pmatrix}
w_{11} & w_{21} & w_{31} \\
w_{12} & w_{22} & w_{32} 
\end{pmatrix}

出力$\boldsymbol{Y}$は

\begin{align}
\boldsymbol{Y} &= \boldsymbol{X} \cdot \boldsymbol{W} \\
&=
\begin{pmatrix}
x_1 & x_2
\end{pmatrix}
\begin{pmatrix}
w_{11} & w_{21} & w_{31} \\
w_{12} & w_{22} & w_{32} 
\end{pmatrix} \\
&= 
\begin{pmatrix}
w_{11}x_1+w_{12}x_2 & w_{21}x_1+w_{22}x_2 & w_{31}x_1+w_{32}x_2
\end{pmatrix} \\
&=
\begin{pmatrix}
y_1 & y_2 & y_3
\end{pmatrix} \tag{1.1}
\end{align}

となります。
損失関数$L$の入力$\boldsymbol{X}$による偏微分は$x_1, x_2$が$y_1, y_2, y_3$に出てくることに注意すると、

\begin{align}
\frac{\partial L}{\partial \boldsymbol{X}} &= 
\begin{pmatrix}
\frac{\partial L}{\partial x_1} & \frac{\partial L}{\partial x_2}
\end{pmatrix} \\
&=
\begin{pmatrix}
\frac{\partial L}{\partial \boldsymbol{Y}} \cdot \frac{\partial \boldsymbol{Y}}{\partial x_1} & \frac{\partial L}{\partial \boldsymbol{Y}} \cdot \frac{\partial \boldsymbol{Y}}{\partial x_2}
\end{pmatrix} 
\end{align}

ここで

\begin{align}
\frac{\partial L}{\partial \boldsymbol{Y}} \cdot \frac{\partial \boldsymbol{Y}}{\partial x_1} = 
\begin{pmatrix}
\frac{\partial L}{\partial y_1} & \frac{\partial L}{\partial y_2} & \frac{\partial L}{\partial y_3}
\end{pmatrix}
\cdot
\begin{pmatrix}
\frac{\partial y_1}{\partial x_1} \\
\frac{\partial y_2}{\partial x_1} \\
\frac{\partial y_3}{\partial x_1} 
\end{pmatrix}
\end{align}

なので

\begin{align}
\frac{\partial L}{\partial \boldsymbol{X}} &= 
\begin{pmatrix}
\frac{\partial L}{\partial y_1} \frac{\partial y_1}{\partial x_1} +
\frac{\partial L}{\partial y_2} \frac{\partial y_2}{\partial x_1} +
\frac{\partial L}{\partial y_3} \frac{\partial y_3}{\partial x_1} &
\frac{\partial L}{\partial y_1} \frac{\partial y_1}{\partial x_2} +
\frac{\partial L}{\partial y_2} \frac{\partial y_2}{\partial x_2} +
\frac{\partial L}{\partial y_3} \frac{\partial y_3}{\partial x_2} 
\end{pmatrix} \\
&=
\begin{pmatrix}
\frac{\partial L}{\partial y_1} w_{11} +
\frac{\partial L}{\partial y_2} w_{21} +
\frac{\partial L}{\partial y_3} w_{31} &
\frac{\partial L}{\partial y_1} w_{12} +
\frac{\partial L}{\partial y_2} w_{22} +
\frac{\partial L}{\partial y_3} w_{32}
\end{pmatrix} \\
&=
\begin{pmatrix}
\frac{\partial L}{\partial y_1} & \frac{\partial L}{\partial y_2} & \frac{\partial L}{\partial y_3}
\end{pmatrix} 
\begin{pmatrix}
w_{11} & w_{12} \\
w_{21} & w_{22} \\
w_{31} & w_{32}
\end{pmatrix} \\
&= \frac{\partial L}{\partial \boldsymbol{Y}}\cdot \boldsymbol{W}^T

\end{align}

一方、損失関数$L$の重み$\boldsymbol{W}$による偏微分は

\begin{align}
\frac{\partial L}{\partial \boldsymbol{W}} &= 
\begin{pmatrix}
\frac{\partial L}{\partial w_{11}} & \frac{\partial L}{\partial w_{21}} & \frac{\partial L}{\partial w_{31}} \\
\frac{\partial L}{\partial w_{12}} & \frac{\partial L}{\partial w_{22}} & \frac{\partial L}{\partial w_{32}} 
\end{pmatrix} \\
\end{align}

です。
ここで、$(1.1)$式において、$w_{11}$は$y_1$だけに、$w_{12}$は$y_1$だけに、・・・$w_{31}$は$y_3$だけに、$w_{32}$は$y_3$だけに出てくることに注意すると、

\begin{align}
\frac{\partial L}{\partial w_{11}} &= \frac{\partial L}{\partial y_1}\frac{\partial y_1}{\partial w_{11}} \\
\frac{\partial L}{\partial w_{12}} &= \frac{\partial L}{\partial y_1}\frac{\partial y_1}{\partial w_{12}} \\
\frac{\partial L}{\partial w_{21}} &= \frac{\partial L}{\partial y_2}\frac{\partial y_2}{\partial w_{21}} \\
\frac{\partial L}{\partial w_{22}} &= \frac{\partial L}{\partial y_2}\frac{\partial y_2}{\partial w_{22}} \\
\frac{\partial L}{\partial w_{31}} &= \frac{\partial L}{\partial y_3}\frac{\partial y_3}{\partial w_{31}} \\
\frac{\partial L}{\partial w_{32}} &= \frac{\partial L}{\partial y_3}\frac{\partial y_3}{\partial w_{32}}
\end{align}

となります。
したがって、$\frac{\partial L}{\partial \boldsymbol{W}}$は

\begin{align}
\frac{\partial L}{\partial \boldsymbol{W}} &= 
\begin{pmatrix}
\frac{\partial L}{\partial w_{11}} & \frac{\partial L}{\partial w_{21}} & \frac{\partial L}{\partial w_{31}} \\
\frac{\partial L}{\partial w_{12}} & \frac{\partial L}{\partial w_{22}} & \frac{\partial L}{\partial w_{32}} 
\end{pmatrix} \\
&= 
\begin{pmatrix}
\frac{\partial L}{\partial y_1}\frac{\partial y_1}{\partial w_{11}} &
\frac{\partial L}{\partial y_2}\frac{\partial y_2}{\partial w_{21}} &
\frac{\partial L}{\partial y_3}\frac{\partial y_3}{\partial w_{31}} \\
\frac{\partial L}{\partial y_1}\frac{\partial y_1}{\partial w_{12}} & \frac{\partial L}{\partial y_2}\frac{\partial y_2}{\partial w_{22}} &
\frac{\partial L}{\partial y_3}\frac{\partial y_3}{\partial w_{32}}
\end{pmatrix} \\
&=
\begin{pmatrix}
\frac{\partial L}{\partial y_1}x_1 &
\frac{\partial L}{\partial y_2}x_1 &
\frac{\partial L}{\partial y_3}x_1 \\
\frac{\partial L}{\partial y_1}x_2 &
\frac{\partial L}{\partial y_2}x_2 &
\frac{\partial L}{\partial y_3}x_2
\end{pmatrix} \\
&= \begin{pmatrix}
x_1 \\
x_2
\end{pmatrix}
\begin{pmatrix}
\frac{\partial L}{\partial y_1} &
\frac{\partial L}{\partial y_2} &
\frac{\partial L}{\partial y_3} 
\end{pmatrix} \\
&= \boldsymbol{X}^T \cdot \frac{\partial L}{\partial \boldsymbol{Y}}
\end{align}

これで、

\begin{align}
\frac{\partial L}{\partial \boldsymbol{W}} &=
\boldsymbol{X}^T \cdot \frac{\partial L}{\partial \boldsymbol{Y}}
\end{align}

の導出が(低次元で、活性化関数も考慮していませんが)できました。

2. 活性化関数と2層目を考慮してみる

1
1では、各層の活性化関数を無視したり、2層目以降を考慮していませんでした。ここでは、2層目を導入し、活性化関数も導入してみようと思います。
1層目の活性化関数を$h$、2層目(出力層)の活性化関数を$\sigma$とおきます。
最初に結果をみると、このようになります。

L = \sigma(h(\boldsymbol{X} \cdot \boldsymbol{W}) \cdot \boldsymbol{W}^{(2)})

一つずつ見ていきます。
入力$\boldsymbol{X}$

\boldsymbol{X} = (x_1\; x_2)

1層目の入力

\boldsymbol{Y} = \boldsymbol{X}\cdot \boldsymbol{W}

1層目の出力(活性化関数を作用させる)

\begin{align}
h(\boldsymbol{Y}) &= h(\boldsymbol{X}\cdot \boldsymbol{W}) \\
&= \begin{pmatrix}
h(y_1) & h(y_2) & h(y_3)
\end{pmatrix}
\end{align}

2層目(出力層)の入力

\begin{align}
z &= h(\boldsymbol{Y})\cdot \boldsymbol{W}^{(2)} \\
&=
\begin{pmatrix}
h(y_1) & h(y_2) & h(y_3)
\end{pmatrix}
\begin{pmatrix}
w_1^{(2)} \\
w_2^{(2)} \\
w_3^{(2)}
\end{pmatrix} \\
&= w_1^{(2)}h(y_1) + w_2^{(2)}h(y_2) + w_3^{(2)}h(y_3)

\end{align}

2層目(出力層)の出力

\begin{align}
L &= \sigma (z) \\
&= \sigma (w_1^{(2)}h(y_1) + w_2^{(2)}h(y_2) + w_3^{(2)}h(y_3))
\end{align}

$\boldsymbol{X}$と$\boldsymbol{W}$の偏微分は1と同じですが、再掲します。

\begin{align}
\frac{\partial \sigma}{\partial \boldsymbol{X}} &= 
\begin{pmatrix}
\frac{\partial \sigma}{\partial x_1} & \frac{\partial \sigma}{\partial x_2}
\end{pmatrix} \\
&=
\begin{pmatrix}
\frac{\partial \sigma}{\partial \boldsymbol{Y}} \cdot \frac{\partial \boldsymbol{Y}}{\partial x_1} & \frac{\partial \sigma}{\partial \boldsymbol{Y}} \cdot \frac{\partial \boldsymbol{Y}}{\partial x_2}
\end{pmatrix} \\
&=
\begin{pmatrix}
\frac{\partial \sigma}{\partial y_1} \frac{\partial y_1}{\partial x_1} +
\frac{\partial \sigma}{\partial y_2} \frac{\partial y_2}{\partial x_1} +
\frac{\partial \sigma}{\partial y_3} \frac{\partial y_3}{\partial x_1} &
\frac{\partial \sigma}{\partial y_1} \frac{\partial y_1}{\partial x_2} +
\frac{\partial \sigma}{\partial y_2} \frac{\partial y_2}{\partial x_2} +
\frac{\partial \sigma}{\partial y_3} \frac{\partial y_3}{\partial x_2} 
\end{pmatrix} \\
&=
\begin{pmatrix}
\frac{\partial \sigma}{\partial y_1} w_{11} +
\frac{\partial \sigma}{\partial y_2} w_{21} +
\frac{\partial \sigma}{\partial y_3} w_{31} &
\frac{\partial \sigma}{\partial y_1} w_{12} +
\frac{\partial \sigma}{\partial y_2} w_{22} +
\frac{\partial \sigma}{\partial y_3} w_{32}
\end{pmatrix} \\
&=
\begin{pmatrix}
\frac{\partial \sigma}{\partial y_1} & \frac{\partial \sigma}{\partial y_2} & \frac{\partial \sigma}{\partial y_3}
\end{pmatrix} 
\begin{pmatrix}
w_{11} & w_{12} \\
w_{21} & w_{22} \\
w_{31} & w_{32}
\end{pmatrix} \\
&= \frac{\partial \sigma}{\partial \boldsymbol{Y}}\cdot \boldsymbol{W}^T \\
\frac{\partial L}{\partial \boldsymbol{W}} &= 
\begin{pmatrix}
\frac{\partial \sigma}{\partial w_{11}} & \frac{\partial \sigma}{\partial w_{21}} & \frac{\partial \sigma}{\partial w_{31}} \\
\frac{\partial \sigma}{\partial w_{12}} & \frac{\partial \sigma}{\partial w_{22}} & \frac{\partial \sigma}{\partial w_{32}}
\end{pmatrix} \\
&=
\begin{pmatrix}
\frac{\partial \sigma}{\partial y_1}\frac{\partial y_1}{\partial w_{11}} &
\frac{\partial \sigma}{\partial y_2}\frac{\partial y_2}{\partial w_{21}} &
\frac{\partial \sigma}{\partial y_3}\frac{\partial y_3}{\partial w_{31}} \\
\frac{\partial \sigma}{\partial y_1}\frac{\partial y_1}{\partial w_{12}} &
\frac{\partial \sigma}{\partial y_2}\frac{\partial y_2}{\partial w_{22}} &
\frac{\partial \sigma}{\partial y_3}\frac{\partial y_3}{\partial w_{32}} 
\end{pmatrix} \\
&=
\begin{pmatrix}
\frac{\partial \sigma}{\partial y_1}x_1 &
\frac{\partial \sigma}{\partial y_2}x_1 &
\frac{\partial \sigma}{\partial y_3}x_1 \\
\frac{\partial \sigma}{\partial y_1}x_2 &
\frac{\partial \sigma}{\partial y_2}x_2 &
\frac{\partial \sigma}{\partial y_3}x_2 
\end{pmatrix} \\
&=
\begin{pmatrix}
x_1 \\
x_2 
\end{pmatrix}
\begin{pmatrix}
\frac{\partial \sigma}{\partial y_1} &
\frac{\partial \sigma}{\partial y_2} &
\frac{\partial \sigma}{\partial y_3}  
\end{pmatrix} \\
&= \boldsymbol{X}^T \cdot \frac{\partial \sigma}{\partial \boldsymbol{Y}}
\end{align}

蛇足ですが、成分表示をすると、

\frac{\partial L}{\partial w_{ji}} = 
\frac{\partial L}{\partial y_j} \frac{\partial y_j}{\partial w_{ji}}

となります。

3. より一般的な式にする

1, 2では、層の数も限られていましたし、各層の次元も少ないケースでしたが、より一般的に第$i$層、第$j$層、第$k$層に注目してみたいと思います。
ここでは、行列は成分表示で表します。
1

\begin{align}
a_j^{(j)} &= \sum_i w_{ji}^{(j)}z_i^{(i)} \\
z_j^{(j)} &= h(a_j^{(j)})
\end{align}

重み$w_{ji}^{(j)}$による微分は、$w_{ji}^{(j)}$が$a_j^{(j)}$のみに出現することから、

\begin{align}
\frac{\partial L}{\partial w_{ji}^{(j)}} &= \frac{\partial L}{\partial a_j^{(j)}}\frac{\partial a_j^{(j)}}{\partial w_{ji}^{(j)}} \\
&= \frac{\partial L}{\partial a_j^{(j)}}\frac{\partial}{\partial w_{ji}^{(j)}} ( \sum_i w_{ji}^{(j)}z_i^{(i)} ) \\
&= \frac{\partial L}{\partial a_j^{(j)}}z_i

\end{align}

また、$j$番目の入力$a_j^{(j)}$による微分は、$a_j^{(j)}$が$a_k^{(k)}$の変化を通じてしか誤差関数を変化させないことから、

\begin{align}
\frac{\partial L}{\partial a_j^{(j)}} &=
\sum_k \frac{\partial L}{\partial a_k^{(k)}}\frac{\partial a_k^{(k)}}{\partial a_j^{(j)}} \\
&= \sum_k \frac{\partial L}{\partial a_k^{(k)}}\frac{\partial}{\partial a_j^{(j)}} ( \sum_j w_{kj}^{(k)}z_j^{(j)} ) \\
&= \sum_k \frac{\partial L}{\partial a_k^{(k)}} w_{kj} \frac{\partial h(a_j^{(j)})}{\partial a_j^{(j)}} \\
&= \frac{\partial h(a_j^{(j)})}{\partial a_j^{(j)}}\sum_k w_{kj} \frac{\partial L}{\partial a_k^{(k)}} 
\end{align}

となる。

参考

  1. ゼロから作るDeep Learning ――Pythonで学ぶディープラーニングの理論と実装
  2. パターン認識と機械学習
  3. 実装ノート
Why do not you register as a user and use Qiita more conveniently?
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away
Comments
Sign up for free and join this conversation.
If you already have a Qiita account
Why do not you register as a user and use Qiita more conveniently?
You need to log in to use this function. Qiita can be used more conveniently after logging in.
You seem to be reading articles frequently this month. Qiita can be used more conveniently after logging in.
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away