NVAEやVery Deep VAEs(VDVAE)といったVAEベースの画像生成モデルに興味を持ったので,その損失関数についてメモ
VAE
◯(階層的でないVAEの)損失関数
$$
\mathcal L(x, z) = \mathbb E_{z\sim{q(z\mid x)}}\left[ \log p_\theta(x\mid z)\right] - \mathrm{KL}(q_\phi(z\mid x)||p(z))
$$
- $x\in[0,1]^{C\times W\times H}$は観測値
- $z\in\mathbb R^D$は潜在変数
- $C,W,H,D\in\mathbb N$は,それぞれ,画像のチャネル,幅,高さ,潜在変数の次元
◯モデル
なんらかの分布を仮定し,そのパラメータを出力する。
(確率分布そのものはNNから出力できないため)
- encoder|$q_\phi(z\mid x)$
- 出力|正規分布のパラメータ(平均・分散)
- decoder|$p_\theta(x\mid z)$
- 出力|Mixture of Discretized Logisticsのパラメータ
Mixture of Discretized Logistics
多分初出であるPixelCNN++[2017]を参考にした。それ以前のPixelCNNでは,出力を255次元のカテゴリカルとしていたが,Mixture of Discretized Logisticsは,そのパラメータを減らすモチベで考案された。
◯Mixture of Discretized Logisticsの確率関数
$$
P(x \mid \pi, \mu, s) =\sum_{i=1}^{K}
\pi_{i}\left[\sigma\left(\left(x+0.5-\mu_{i}\right) / s_{i}\right)-\sigma\left(\left(x-0.5-\mu_{i}\right) / s_{i}\right)\right]
$$
- $P(x=a)=P(x\geq a+0.5) - P(x\geq a-0.5)$というCDFとPDFの関係を利用
◯RGBの同時分布のモデル化(混合なしの場合)
$$
p\left(r_{i, j}, g_{i, j}, b_{i, j} \mid C_{i, j}\right)= P\left(r_{i, j} \mid \mu_{r}\left(C_{i, j}\right), s_{r}\left(C_{i, j}\right)\right) \times P\left(g_{i, j} \mid \mu_{g}\left(C_{i, j}, r_{i, j}\right), s_{g}\left(C_{i, j}\right)\right)
\times P\left(b_{i, j} \mid \mu_{b}\left(C_{i, j}, r_{i, j}, g_{i, j}\right), s_{b}\left(C_{i, j}\right)\right)
$$
$$
\mu_{g}\left(C_{i, j}, r_{i, j}\right)=\mu_{g}\left(C_{i, j}\right)+\alpha\left(C_{i, j}\right) r_{i, j}
$$
$$
\mu_{b}\left(C_{i, j}, r_{i, j}, g_{i, j}\right)=\mu_{b}\left(C_{i, j}\right)+\beta\left(C_{i, j}\right) r_{i, j}+\gamma\left(C_{i, j}\right) b_{i, j}
$$
- R→G→Bの順で依存関係があると仮定(RGB各値は独立ではないためらしい)
- $\mu_r, \mu_g, \mu_b, s_r, s_g, s_b, \alpha, \beta, \gamma$の9つのパラメータ per pixel
◯混合ありの場合
上記の9つのパラメータに加えて,混合比率である$\pi$が加わって,計$10K$個のチャネルが必要
$$
p(r_{i,j},g_{i,j},b_{i,j}\mid C_{i,j}) = \sum_{k=1}^K\pi_k p_i(r_{i,j},g_{i,j},b_{i,j}\mid C_{i,j})
$$
※NVAEやVDVAEでは,$K=10$が用いられているため,出力は100チャネルとなる。
再構成ロス
上記の確率分布に基づいて,nllを計算するだけ。各実装をコピペすると楽。
- tf(openai)|https://github.com/openai/pixel-cnn/blob/master/pixel_cnn_pp/nn.py
- torch(nvidia)|https://github.com/NVlabs/NVAE/blob/master/distributions.py
- torch(openai)|https://github.com/openai/vdvae/blob/main/vae_helpers.py
まとめ
- VAEで画像を扱うときは,3チャネルではなく100チャネル出力
- $100=10\times (9 + 1)$で,10が混合数・9がパラメタ数・1が混合比率
- 再構成ロスはnll(著者実装が参考になる)