LoginSignup
0
0

More than 3 years have passed since last update.

PRML 演習問題 5.22(標準) 解答

Last updated at Posted at 2020-07-01

問題

微分のチェーンルールを応用して、2層フィードフォワードネットワークのヘッセ行列の要素について(5.93), (5.94), および(5.95)の結果を導け。

参考

2層フィードフォワードネットワーク

20200702_Exercise_5-23-1.png
引用:@YusukeToda1984

ヘッセ行列

\begin {align*}
\delta_{k}=\frac{\partial E_{n}}{\partial a_{k}}, \quad M_{k k^{\prime}} \equiv \frac{\partial^{2} E_{n}}{\partial a_{k} \partial a_{k^{\prime}}}
\tag{5.92}
\end {align*}

1.両方の重みが第2層にある:

\begin {align*}
\frac{\partial^{2} E_{n}}{\partial w_{k j}^{(2)} \partial w_{k^{\prime} j^{\prime}}^{(2)}}=z_{j} z_{j^{\prime}} M_{k k^{\prime}}
\tag{5.93}
\end {align*}

2.両方の重みが第1層にある:

\begin {align*}
\frac{\partial^{2} E_{n}}{\partial w_{j i}^{(1)} \partial w_{j^{\prime} i^{\prime}}^{(1)}}=x_{i} x_{i^{\prime}} h^{\prime \prime}\left(a_{j^{\prime}}\right) I_{j j^{\prime}} \sum_{k} w_{k j^{\prime}}^{(2)} \delta_{k}\\
\quad+x_{i} x_{i^{\prime}} h^{\prime}\left(a_{j^{\prime}}\right) h^{\prime}\left(a_{j}\right) \sum_{k} \sum_{k^{\prime}} w_{k^{\prime} j^{\prime}}^{(2)} w_{k j}^{(2)} M_{k k^{\prime}}
\tag{5.94}
\end {align*}

3.重みは1つの層に1つずつある:

\begin {align*}
\frac{\partial^{2} E_{n}}{\partial w_{j i}^{(1)} \partial w_{k j^{\prime}}^{(2)}}=x_{i} h^{\prime}\left(a_{j^{\prime}}\right)\left\{\delta_{k} I_{j j^{\prime}}+z_{j^{\prime}} \sum_{k^{\prime}} w_{k^{\prime} j^{\prime}}^{(2)} M_{k k^{\prime}}\right\}
\tag{5.95}
\end {align*}

解答

解答としては問題文の通りに1.2.3.を順に確認していくことになる。ただし、同じ$a$でも添字によって意味合いが異なってしまうため、常に添字に注意しながら微分を行なっていくことになる。

1.両方の重みが第2層にある:

\begin {align*}
a_{j}=\sum_{i} w_{j i} z_{i}
\tag{5.48}
\end {align*}

より、

\begin {align*}
a_{k}=\sum_{j} w_{k j} z_{j}
\tag{5.48'}
\end {align*}

となる。この(5.48')を用いると、

\begin {align*}
\frac{\partial^{2} E_{n}}{\partial w_{k j}^{(2)} \partial w_{k^{\prime} j^{\prime}}^{(2)}} =& \frac{\partial^{2} E_{n}}{\partial a_{k} \partial a_{k^{\prime}}}\frac{\partial a_{k}}{\partial w_{k j}^{(2)}} \frac{\partial a_{k^{\prime}}}{\partial w_{k j}^{(2)}}
\\=& z_{j} z_{j^{\prime}} M_{k k^{\prime}}
\tag{5.93}
\end {align*}

よって(5.93)は示すことができた。

2.両方の重みが第1層にある:

\begin {align*}
\frac{\partial E_{n}}{\partial w_{j i}}=\delta_{j} x_{i}
\tag{5.53}
\end {align*}

本文では$x_i$ではなく$z_i$だが、今回は第1層を考えているため$x_i$となる。

\begin {align*}
\delta_{j}=h^{\prime}\left(a_{j}\right) \sum_{k} w_{k j} \delta_{k}
\tag{5.56}
\end {align*}

上の(5.53), (5.56)より、

\begin {align*}
\frac{\partial E_{n}}{\partial w_{j i}^{(1)}} = x_{i}h^{\prime}\left(a_{j}\right) \sum_{k} w_{k j}^{(2)} \delta_{k}
\tag{ex5.22.1}
\end {align*}

となる。このことと微分のチェーンルールを用いて、

\begin {align*}
\frac{\partial^{2} E_{n}}{\partial w_{j i}^{(1)} \partial w_{j^{\prime} i^{\prime}}^{(1)}} =  
\frac{\partial}{\partial a_{j^{\prime}}}\left(\frac{\partial E_{n}}{\partial w_{j i}^{(1)}}\right) \frac{\partial a_{j^{\prime}}}{\partial w_{j^{\prime} i^{\prime}}^{(1)}}
\tag{ex5.22.2}
\end {align*}

この(ex5.22.2)に関しては$h^{\prime}\left(a_{j}\right)$での微分が$j \neq j'$の場合と$j = j'$の場合で変わってくるため、$j \neq j'$の場合と$j = j'$の場合で分けて考えるのがよい。

  • $j \neq j'$の場合
\begin {align*}
\frac{\partial^{2} E_{n}}{\partial w_{j i}^{(1)} \partial w_{j^{\prime} i^{\prime}}^{(1)}} =  
\sum_{k^{\prime}} \frac{\partial}{\partial a_{k^{\prime}}}\left(\frac{\partial E_{n}}{\partial w_{j i}^{(1)}}\right) \frac{\partial a_{k^{\prime}}}{\partial a_{j^{\prime}}} x_{i^{\prime}}
\end {align*}
\begin {align*}
\frac{\partial a_{k}^{\prime}}{\partial a_{j}^{\prime}}=w_{k^{\prime} j^{\prime}} h^{\prime}\left(a_{j^{\prime}}\right)
\end {align*}
\begin {align*}
\frac{\partial^{2} E_{n}}{\partial w_{j i}^{(1)} \partial w_{j^{\prime} i^{\prime}}^{(1)}}=x_{i} x_{i^{\prime}} h^{\prime}\left(a_{j^{\prime}}\right) h^{\prime}\left(a_{j}\right) \sum_{k} \sum_{k^{\prime}} w_{k^{\prime} j^{\prime}}^{(2)} w_{k j}^{(2)} M_{k k^{\prime}}
\tag{5.94'}
\end {align*}
  • $j = j'$の場合
\begin {align*}
\frac{\partial^{2} E_{n}}{\partial w_{j i}^{(1)} \partial w_{j^{\prime} i^{\prime}}^{(1)}}=x_{i} x_{i^{\prime}}\sum_{k} w_{k j}\left\{\left(\frac{\partial}{\partial a_{j^{\prime}}} \frac{\partial E_{n}}{\partial a_{k}}\right) h^{\prime}\left(a_{j}\right)+\left(\frac{\partial}{\partial a_{j^{\prime}}} h^{\prime}\left(a_{j}\right)\right) \frac{\partial E_{n}}{\partial a_{k}}\right\}
\end {align*}
\begin {align*}
\frac{\partial^{2} E_{n}}{\partial w_{j i}^{(1)} \partial w_{j^{\prime} i^{\prime}}^{(1)}}=x_{i} x_{i^{\prime}} h^{\prime \prime}\left(a_{j^{\prime}}\right) \sum_{k} w_{k j^{\prime}}^{(2)} \delta_{k}\\
\quad+x_{i} x_{i^{\prime}} h^{\prime}\left(a_{j^{\prime}}\right) h^{\prime}\left(a_{j}\right) \sum_{k} \sum_{k^{\prime}} w_{k^{\prime} j^{\prime}}^{(2)} w_{k j}^{(2)} M_{k k^{\prime}}
\tag{5.94"}
\end {align*}

よって、(5.94')と(5.94")の結果を合わせると、

\begin {align*}
\frac{\partial^{2} E_{n}}{\partial w_{j i}^{(1)} \partial w_{j^{\prime} i^{\prime}}^{(1)}}=x_{i} x_{i^{\prime}} h^{\prime \prime}\left(a_{j^{\prime}}\right) I_{j j^{\prime}} \sum_{k} w_{k j^{\prime}}^{(2)} \delta_{k}\\
\quad+x_{i} x_{i^{\prime}} h^{\prime}\left(a_{j^{\prime}}\right) h^{\prime}\left(a_{j}\right) \sum_{k} \sum_{k^{\prime}} w_{k^{\prime} j^{\prime}}^{(2)} w_{k j}^{(2)} M_{k k^{\prime}}
\tag{5.94}
\end {align*}

が導かれる。

3.重みは1つの層に1つずつある:

\begin {align*}
\frac{\partial E_{n}}{\partial w_{j i}^{(1)}} = x_{i}h^{\prime}\left(a_{j}\right) \sum_{k} w_{k j}^{(2)} \delta_{k}
\tag{ex5.22.1}
\end {align*}

(ex5.22.1)を用いると、

\begin {align*}
\frac{\partial^{2} E_{n}}{\partial w_{j i}^{(1)} \partial w_{k j^{\prime}}^{(2)}} = 
\frac{\partial}{\partial w_{k j^{\prime}}^{(2)}}\left(\frac{\partial E_{n}}{\partial w_{j i}^{(1)}}\right)
\tag{ex5.22.3}
\end {align*}

今回も$w_{k j^{\prime}}^{(2)}$での微分が$j \neq j'$の場合と$j = j'$の場合で変わってくる。ただし今回は場合分けをせずに一気に行う。

\begin {align*}
\frac{\partial^{2} E_{n}}{\partial w_{j i}^{(1)} \partial w_{k j^{\prime}}^{(2)}} = 
x_{i} h^{\prime}\left(a_{j}\right)\left\{\left(\frac{\partial}{\partial w_{k j^{\prime}}^{(2)}} \sum_{k} w_{k j}^{(2)}\right) \delta_{k}+\left(\frac{\partial}{\partial w_{k j^{\prime}}^{(2)}} \delta_{k}\right) \sum_{k} w_{k j}^{(2)}\right\}
\end {align*}

これをまとめると、

\begin {align*}
\frac{\partial^{2} E_{n}}{\partial w_{j i}^{(1)} \partial w_{k j^{\prime}}^{(2)}}=x_{i} h^{\prime}\left(a_{j^{\prime}}\right)\left\{\delta_{k} I_{j j^{\prime}}+z_{j^{\prime}} \sum_{k^{\prime}} w_{k^{\prime} j^{\prime}}^{(2)} M_{k k^{\prime}}\right\}
\tag{5.95}
\end {align*}

以上で3つの結果を導くことができた。

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0