課題
PRMLでは(13.12)をもとにして(13.17)の導出を行っているが、案外この導出が分かりにくい(個人的な感想)。
そのためこの導出をもう少し簡単に書き下してみようというのが今回のテーマ。
書き下し
\begin {align*}
Q\left(\mathbf{\theta}, \mathbf{\theta}^{\text {old }}\right)=\sum_{\mathbf{Z}} p\left(\mathbf{Z} \mid \mathbf{X}, \mathbf{\theta}^{\text {old }}\right) \ln p(\mathbf{X}, \mathbf{Z} \mid \mathbf{\theta})
\tag{13.12}
\end {align*}
まず、この式(13.12)は完全データ対数尤度の事後分布における期待値である。そのため
\begin {align*}
Q\left(\mathbf{\theta}, \mathbf{\theta}^{\text {old }}\right)=\mathbb{E}\left[\ln p(\mathbf{X}, \mathbf{Z} \mid \mathbf{\theta})\right]
\tag{13.12'}
\end {align*}
と書くことができる。
次に
\begin {align*}
p(\mathbf{X}, \mathbf{Z} \mid \theta)=p\left(\mathbf{z}_{1} \mid \pi\right)\left[\prod_{n=2}^{N} p\left(\mathbf{z}_{n} \mid \mathbf{z}_{n-1}, \mathbf{A}\right)\right] \prod_{m=1}^{N} p\left(\mathbf{x}_{m} \mid \mathbf{z}_{m}, \phi\right)
\tag{13.10}
\end {align*}
(13.10)を用いて$\ln p(\mathbf{X}, \mathbf{Z} \mid \mathbf{\theta})$について書き下してやると、
\begin {align*}
\ln p(\mathbf{X}, \mathbf{Z} \mid \mathbf{\theta}) = \ln p\left(\mathbf{z}_{1} \mid \pi\right) + \sum_{n=2}^{N} \ln p\left(\mathbf{z}_{n} \mid \mathbf{z}_{n-1}, \mathbf{A}\right) + \sum_{m=1}^{N} \ln p\left(\mathbf{x}_{m} \mid \mathbf{z}_{m}, \phi\right)
\tag{1}
\end {align*}
と表すことができる。この(1)を
\begin {align*}
p\left(\mathbf{z}_{1} \mid \pi\right)=\prod_{k=1}^{K} \pi_{k}^{z_{1} k}
\tag{13.8}
\end {align*}
\begin {align*}
p\left(\mathbf{z}_{n} \mid \mathbf{z}_{n-1}, \mathbf{A}\right)=\prod_{k=1}^{K} \prod_{j=1}^{K} A_{j k}^{z_{n-1 ,\ j}\ z_{n k}}
\tag{13.7}
\end {align*}
\begin {align*}
p\left(\mathbf{x}_{n} \mid \mathbf{z}_{n}, \phi\right)=\prod_{k=1}^{K} p\left(\mathbf{x}_{n} \mid \phi_{k}\right)^{z_{n k}}
\tag{13.9}
\end {align*}
(13.8)(13.7)(13.9)を用いて変形すると、
\begin {align*}
\begin{aligned}
\ln p(\mathbf{X}, \mathbf{Z} \mid \mathbf{\theta})=& \sum_{k=1}^{K} z_{1 k} \ln \pi_{k}+\sum_{n=2}^{N} \sum_{j=1}^{K} \sum_{k=1}^{K} z_{n-1, j} z_{n k} \ln A_{j k} \\
&+\sum_{n=1}^{N} \sum_{k=1}^{K} z_{n k} \ln p\left(\mathbf{x}_{n} \mid \phi_{k}\right)
\end{aligned}
\tag{2}
\end {align*}
と表すことができる。
この(2)の完全データ対数尤度に関して$p\left(\mathbf{Z} \mid \mathbf{X}, \mathbf{\theta}^{\text {old }}\right)$の下での期待値をとってやると、
\begin {align*}
\begin{aligned}
Q\left(\mathbf{\theta}, \mathbf{\theta}^{\text {old }}\right)=& \sum_{k=1}^{K} \mathbb{E}\left[z_{1 k}\right] \ln \pi_{k}+\sum_{n=2}^{N} \sum_{j=1}^{K} \sum_{k=1}^{K} \mathbb{E}\left[z_{n-1, j} z_{n k} \right] \ln A_{j k} \\
&+\sum_{n=1}^{N} \sum_{k=1}^{K} \mathbb{E}\left[z_{n k} \right] \ln p\left(\mathbf{x}_{n} \mid \phi_{k}\right)
\end{aligned}
\tag{3}
\end {align*}
となる。この時に
\begin {align*}
\gamma\left(z_{n k}\right)=\mathbb{E}\left[z_{n k}\right]
\tag{13.15}
\end {align*}
\begin {align*}
\xi\left(z_{n-1, j}, z_{n k}\right)=\mathbb{E}\left[z_{n-1, j} z_{n k}\right]
\tag{13.16}
\end {align*}
と本文にあるように定義すると、(3)は下のように変形することができる。
\begin {align*}
\begin{aligned}
Q\left(\theta, \theta^{\text {old }}\right)=& \sum_{k=1}^{K} \gamma\left(z_{1 k}\right) \ln \pi_{k}+\sum_{n=2}^{N} \sum_{j=1}^{K} \sum_{k=1}^{K} \xi\left(z_{n-1, j}, z_{n k}\right) \ln A_{j k} \\
&+\sum_{n=1}^{N} \sum_{k=1}^{K} \gamma\left(z_{n k}\right) \ln p\left(\mathbf{x}_{n} \mid \phi_{k}\right)
\end{aligned}
\tag{13.17}
\end {align*}
これによって(13.12)をもとに(13.17)を導出することができた。