はじめに
世間では生成AIのニュースで賑やかです。生成AIには、①GAN、②VAE、③Diffusion modelなどがあります。変分オートエンコーダ(Variational Autoencoder:VAE)を知らなかったので調査しました。本記事では、②VAEについて簡単に紹介します。
VAEの実装
変分下限(Evidence Lower Bound:ELBO)
対数尤度$\log p(x)$を直接最大化することは困難です。そこで、対数尤度$\log p(x)$のELBO$\mathcal{L}(x,z)$を最大化することで、間接的に対数尤度$\log p(x)$を最大化することを考えます。
\begin{eqnarray}
\log p(x)
&=&\log \int p(x,z)dz\\
&=&\log \int \dfrac{p(x,z)q_{\phi}(z|x)}{q_\phi(z|x)}dz\\
&=&\log \mathbb{E}_{q_\phi}(z|x)\Biggl\lbrack\dfrac{p(x,z)}{q_\phi(z|x)}\Biggr\rbrack\\
&\geq&\mathbb{E}_{q_\phi}(z|x)\Biggl\lbrack\log \dfrac{p(x,z)}{q_\phi(z|x)}\Biggr\rbrack \ \ (\because \text{Jensen's Inequality})\\
&=&\mathcal{L}(x,z)
\end{eqnarray}
VAEの学習
対数尤度$\log p(x)$を直接最大化することは困難です。そこで、対数尤度$\log p(x)$のELBO$\mathcal{L}(x,z)$を最大化することで、VAEのEncoder$q_\phi(z|x)$を構成するパラメータ$\phi$と、Decoder$p_\theta(x|z)$を構成するパラメータ$\theta$を最適化することを考えます。間接的に、対数尤度$\log p(x)$を最大化できるはずです。
\begin{eqnarray}
\mathcal{L}(x,z)
&=&\mathbb{E}_{q_\phi(z|x)}\Bigl\lbrack \log \dfrac{p(x,z)}{q_\phi(z|x)}\Bigr\rbrack\\
&=&\mathbb{E}_{q_\phi(z|x)}\Bigl\lbrack \log \dfrac{p_{\theta}(x|z)p(z)}{q_\phi(z|x)}\Bigr\rbrack\\
&=&\mathbb{E}_{q_\phi(z|x)}\Bigl\lbrack \log p_{\theta}(x|z)\Bigr\rbrack
+ \mathbb{E}_{q_\phi(z|x)}\Bigl\lbrack \log \dfrac{p(z)}{q_\phi(z|x)}\Bigr\rbrack\\
&=&\mathbb{E}_{q_\phi(z|x)}[\log p_{\theta}(x|z)]-\text{KL}(q_\phi(z|x)||p(z))
\end{eqnarray}
ELBOを最大化することを考えます。ELBOと目的関数の$\pm$の符号が違うことに注意すれば、式(1)の最小化問題を解けば、EncoderとDecoderのパラメータ$\theta,\phi$が求まります。
\min_{\theta,\phi} \dfrac{1}{N}\sum_{n=1}^{N}L_D(x_n)+L_E(x_n) \tag{1}
ただし、
\begin{eqnarray}
L_D(x)&=&-\mathbb{E}_{z\sim q_\phi}[\log p_\theta(x|z)]\\
L_E(x)&=&\text{KL}(q_\phi(z|x)||p(z))
\end{eqnarray}
ここで、$L_D(x)$は、Decoderが$x$をどれだけ再構成ができているかを評価しています。$x$が2値をとる場合は、モンテカルロ法から$L_D(x)$は交差エントロピーとなります。$L_E(x)$は、Encoderの$q_\phi(z|x)$が事前分布$p(z)$とどれだけ同じかを評価しています。事前分布$p(z)$は、人が勝手に決める設計情報で、よく正規分布などが用いられます。そのときは、$q_\phi(z|x)$も正規分布なので、$L_E(x)$は解析解をえることができます。
\begin{eqnarray}
q_\phi(z|x)&=&\mathcal{N}(z;\mu_{\phi}(x),\sigma^{2}_{\phi}(x)I)\\
p(z) &=& \mathcal{N}(z;0,I)
\end{eqnarray}
デコーダーの項$L_D(x)$は、モンテカルロ法から近似計算をします。
\begin{eqnarray}
&&\min_{\theta,\phi} \dfrac{1}{N}\sum_{n=1}^{N}L_D(x_n)+L_E(x_n)\\
&&\simeq \min_{\theta,\phi} -\sum_{l=1}^{L}\log p_{\theta}(x|z^{(l)})
+ \text{KL}(q_\phi(z|x)||p(z))
\end{eqnarray}
$\lbrace z^{(l)}\rbrace_{l=1}^{L}$は、$q_\phi(z|x)$からサンプリングしたものです。
しかし、$z$をサンプリングする操作は微分不可能であり、Encoder側に勾配を伝えることが出来ません。そこで、reparameterization trickを用いることで微分可能にします。これは、標準正規分布から$\varepsilon$をサンプリングし、$z$を再構成することを考えます。
z = \mu_{\phi}(x)+\varepsilon\circ\sigma_{\phi}(x) \ \ (\text{with} \ \ \varepsilon\sim \mathcal{N}(\varepsilon;0,I))
# Encoder側からサンプリングする
mean, log_var = self.encoder(x)
# reparameterization trick
z = self.sample_z(mean, log_var, device)
y = self.decoder(z)
# Lower Boundを計算する
L_D = -torch.sum(x * torch.log(y + self.eps) + (1 - x) * torch.log(1 - y + self.eps))
L_E = -0.5 * torch.sum(1 + log_var - mean**2 - torch.exp(log_var))
loss = L_E + L_D