2
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

【メモ】VAEにおける画像の再構成ロス

Posted at

NVAEやVery Deep VAEs(VDVAE)といったVAEベースの画像生成モデルに興味を持ったので,その損失関数についてメモ

VAE

◯(階層的でないVAEの)損失関数
$$
\mathcal L(x, z) = \mathbb E_{z\sim{q(z\mid x)}}\left[ \log p_\theta(x\mid z)\right] - \mathrm{KL}(q_\phi(z\mid x)||p(z))
$$

  • $x\in[0,1]^{C\times W\times H}$は観測値
  • $z\in\mathbb R^D$は潜在変数
  • $C,W,H,D\in\mathbb N$は,それぞれ,画像のチャネル,幅,高さ,潜在変数の次元

◯モデル
なんらかの分布を仮定し,そのパラメータを出力する。
(確率分布そのものはNNから出力できないため)

  • encoder|$q_\phi(z\mid x)$
    • 出力|正規分布のパラメータ(平均・分散)
  • decoder|$p_\theta(x\mid z)$
    • 出力|Mixture of Discretized Logisticsのパラメータ

Mixture of Discretized Logistics

多分初出であるPixelCNN++[2017]を参考にした。それ以前のPixelCNNでは,出力を255次元のカテゴリカルとしていたが,Mixture of Discretized Logisticsは,そのパラメータを減らすモチベで考案された。

◯Mixture of Discretized Logisticsの確率関数

$$
P(x \mid \pi, \mu, s) =\sum_{i=1}^{K}
\pi_{i}\left[\sigma\left(\left(x+0.5-\mu_{i}\right) / s_{i}\right)-\sigma\left(\left(x-0.5-\mu_{i}\right) / s_{i}\right)\right]
$$

  • $P(x=a)=P(x\geq a+0.5) - P(x\geq a-0.5)$というCDFとPDFの関係を利用

◯RGBの同時分布のモデル化(混合なしの場合)
$$
p\left(r_{i, j}, g_{i, j}, b_{i, j} \mid C_{i, j}\right)= P\left(r_{i, j} \mid \mu_{r}\left(C_{i, j}\right), s_{r}\left(C_{i, j}\right)\right) \times P\left(g_{i, j} \mid \mu_{g}\left(C_{i, j}, r_{i, j}\right), s_{g}\left(C_{i, j}\right)\right)
\times P\left(b_{i, j} \mid \mu_{b}\left(C_{i, j}, r_{i, j}, g_{i, j}\right), s_{b}\left(C_{i, j}\right)\right)
$$

$$
\mu_{g}\left(C_{i, j}, r_{i, j}\right)=\mu_{g}\left(C_{i, j}\right)+\alpha\left(C_{i, j}\right) r_{i, j}
$$

$$
\mu_{b}\left(C_{i, j}, r_{i, j}, g_{i, j}\right)=\mu_{b}\left(C_{i, j}\right)+\beta\left(C_{i, j}\right) r_{i, j}+\gamma\left(C_{i, j}\right) b_{i, j}
$$

  • R→G→Bの順で依存関係があると仮定(RGB各値は独立ではないためらしい)
  • $\mu_r, \mu_g, \mu_b, s_r, s_g, s_b, \alpha, \beta, \gamma$の9つのパラメータ per pixel

◯混合ありの場合
上記の9つのパラメータに加えて,混合比率である$\pi$が加わって,計$10K$個のチャネルが必要
$$
p(r_{i,j},g_{i,j},b_{i,j}\mid C_{i,j}) = \sum_{k=1}^K\pi_k p_i(r_{i,j},g_{i,j},b_{i,j}\mid C_{i,j})
$$

※NVAEやVDVAEでは,$K=10$が用いられているため,出力は100チャネルとなる。

再構成ロス

上記の確率分布に基づいて,nllを計算するだけ。各実装をコピペすると楽。

まとめ

  • VAEで画像を扱うときは,3チャネルではなく100チャネル出力
  • $100=10\times (9 + 1)$で,10が混合数・9がパラメタ数・1が混合比率
  • 再構成ロスはnll(著者実装が参考になる)
2
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?