LoginSignup
4
3

ゼロからわかるDiffusion Model

Last updated at Posted at 2023-10-14

拡散モデル

拡散モデル(DDPM[1])はForwad ProcessとReverse Processから構成されています。Forward Processでは、標準正規分布に従うノイズを入力画像に少しずつ加えていき、標準正規分布に従う潜在変数$\mathrm{\mathbf{z}}_T$を出力します。Reverse Processでは、ニューラルネットワークを使って、$\mathrm{\mathbf{z}}_T$からノイズを少しずつ取り除き、入力画像を再構成します。以下では、Forward ProcessとReverse Processについて説明し、拡散モデルを学習させるための損失関数の導出を解説していきます。(後半で計算を飛ばした部分は時間があったら書きます)
スクリーンショット 2023-10-14 23.34.18.png
図1. 拡散モデル([2]より引用)

Forward Process

入力画像を$\mathrm{\mathbf{x}}$とし、ノイズを加えていく過程における潜在変数を$\mathrm{\mathbf{z}}_1,\dots,\mathrm{\mathbf{z}}_T$とする。

\begin{align}
\mathrm{\mathbf{z}}_1&=\sqrt{1-\beta_1}\cdot\mathrm{\mathbf{x}}+\sqrt{\beta_1}\cdot\epsilon_1 \tag{1}\\
\mathrm{\mathbf{z}}_t&=\sqrt{1-\beta_{t}}\cdot\mathrm{\mathbf{z}}_{t-1}+\sqrt{\beta_t}\cdot\epsilon_t \quad \forall\,t \in\{2,\dots,T\} \tag{2}
\end{align}

である。ただし、$\epsilon_t\sim\mathcal{N}(0,\mathrm{\mathbf{I}})$であり、$\beta_t\in[0,1]$はノイズを加える速さを決めるハイパーパラメータである。(1)と(2)を確率分布の形に書き直すと

\begin{align}
q(\mathrm{\mathbf{z}}_1|\mathrm{\mathbf{x}})&=\mathcal{N}(\sqrt{1-\beta_1}\mathrm{\mathbf{x}},\beta_1\mathrm{\mathbf{I}})\tag{3}\\
q(\mathrm{\mathbf{z}}_t|\mathrm{\mathbf{z}}_{t-1})&=\mathcal{N}(\sqrt{1-\beta_t}\mathrm{\mathbf{z}}_{t-1},\beta_t\mathrm{\mathbf{I}})\quad \forall\,t\in\{2,\dots,T\} \tag{4}
\end{align}

となる。この時潜在変数$\mathrm{\mathbf{z}}_1,\dots,\mathrm{\mathbf{z}}_T$の同時確率分布は

\begin{align}
q(\mathrm{\mathbf{z}}_{1\dots T}|\mathrm{\mathbf{x}})&=q(\mathrm{\mathbf{z}}_1|\mathrm{\mathbf{x}})\cdot q(\mathrm{\mathbf{z}}_2|\mathrm{\mathbf{z}}_1)\cdot\cdots \cdot q(\mathrm{\mathbf{z}}_T|\mathrm{\mathbf{z}}_{T-1})\\&=q(\mathrm{\mathbf{z}}_{1}|\mathrm{\mathbf{x}})\prod_{t=2}^Tq(\mathrm{\mathbf{z}}_t|\mathrm{\mathbf{z}}_{t-1})\tag{5}
\end{align}

である。$\mathrm{\mathbf{z}}_t$を計算したい時に、(1)と(2)を使って、$\mathrm{\mathbf{x}}$から順番に計算しても良いが、正規分布に従うノイズを複数回発生させる必要があり、計算効率が良くない。ここでは、$\mathrm{\mathbf{x}}$で$\mathrm{\mathbf{z}}_t$を表すことを考える。

\mathrm{\mathbf{z}}_2=\sqrt{1-\beta_2}\cdot\mathrm{\mathbf{z}}_1+\sqrt{\beta_2}\cdot\epsilon_2\tag{6}

であるから、これに(1)を代入すると

\begin{align}
\mathrm{\mathbf{z}}_2&=\sqrt{1-\beta_2}\cdot \left(\sqrt{1-\beta_1}\cdot\mathrm{\mathbf{x}+\sqrt{\beta_1}\cdot \epsilon_1}\right)+\sqrt{\beta_2}\cdot \epsilon_2\\&=\sqrt{1-\beta_2}\left(\sqrt{1-\beta_1}\cdot\mathrm{\mathbf{x}+\sqrt{1-(1-\beta_1)}\cdot \epsilon_1}\right)+\sqrt{\beta_2}\cdot \epsilon_2\\&=\sqrt{(1-\beta_1)(1-\beta_2)}\cdot \mathrm{\mathbf{x}}+\sqrt{1-\beta_2-(1-\beta_1)(1-\beta_2)}\cdot\epsilon_1+\sqrt{\beta_2}\cdot\epsilon_2 \tag{7}
\end{align}

となる。ここで、正規分布の和の再生性より

\begin{align}
\sqrt{1-\beta_2-(1-\beta_1)(1-\beta_2)}\cdot\epsilon_1+\sqrt{\beta_2}\cdot\epsilon_2&=\mathcal{N}\left(0,\left(1-\beta_2-(1-\beta_1)(1-\beta_2)\right)\mathbf{I}\right)+\mathcal{N}\left(0,\beta_2\mathrm{\mathbf{I}}\right)\\&=\mathcal{N}\left(0,\left(1-(1-\beta_1)(1-\beta_2)\right)\mathrm{\mathbf{I}}\right)\tag{8}
\end{align}

であるので、(7)は

\mathrm{\mathbf{z}}_2=\sqrt{(1-\beta_1)(1-\beta_2)}\cdot\mathrm{\mathbf{x}}+\sqrt{1-(1-\beta_1)(1-\beta_2)}\cdot\epsilon\tag{9}

で表せる。ただし、$\epsilon\sim\mathcal{N}(0,\mathrm{\mathbf{I}})$である。この計算を$t$まで繰り返すと

\begin{align}
\mathrm{\mathbf{z}}_t&=\sqrt{(1-\beta_1)(1-\beta_2)\cdot \,\cdots\,\cdot(1-\beta_t)}\cdot\mathrm{\mathbf{x}}\\&\qquad+\sqrt{1-(1-\beta_1)(1-\beta_2)\cdot \,\cdots\,\cdot(1-\beta_t)}\cdot\epsilon\\&=\sqrt{\prod_{s=1}^t(1-\beta_s)}\cdot\mathrm{\mathbf{x}}+\sqrt{1-\prod_{s=1}^{t}(1-\beta_s)}\cdot \epsilon\\&=\sqrt{\alpha_t}\cdot\mathrm{\mathbf{x}}+\sqrt{1-\alpha_t}\cdot\epsilon\tag{10}
\end{align}

になる。ここでは、

\begin{align}
\alpha_t=\prod_{s=1}^T(1-\beta_s)\tag{11}
\end{align}

とおいた。(10)から、

\begin{align}
q(\mathrm{\mathbf{z}}_t|\mathrm{\mathbf{x}})=\mathcal{N}(\sqrt{\alpha_t}\mathrm{\mathbf{x},(1-\alpha_t)\mathrm{\mathbf{I}}})\tag{12}
\end{align}

であり、直接に$\mathrm{\mathbf{x}}$から$\mathrm{\mathbf{z}}_t$をサンプリングすることができる。
また、式(10)から、$t\rightarrow \infty$の時に、$\alpha_t\rightarrow 0$になるから、$q(\mathrm{\mathbf{z_t}}|\mathrm{\mathbf{x}})=\mathcal{N}(0,\mathrm{\mathbf{I}})$になり、入力画像$\mathrm{\mathbf{x}}$の痕跡は完全に消える。

Reverse Process

Reverse processでは、$\mathrm{\mathbf{z}}_T$から$\mathrm{\mathbf{x}}$を再構築したいが、$q(\mathrm{\mathbf{z}}_{t-1}|\mathrm{\mathbf{z}}_t)$は計算できない。なぜなら、ベイズの定理より

\begin{equation}
q(\mathrm{\mathbf{z}}_{t-1}|\mathrm{\mathbf{z}}_t)=\frac{q(\mathrm{\mathbf{z}}_t|\mathrm{\mathbf{z}}_{t-1})q(\mathrm{\mathbf{z}}_{t-1})}{q(\mathrm{\mathbf{z}}_t)}\tag{13}
\end{equation}

であり、$q(\mathrm{\mathbf{z}}_t)$と$q(\mathrm{\mathbf{z}}_{t-1})$は未知である。そこで、ニューラルネットワークを用いて、$q(\mathrm{\mathbf{z}}_{t-1}|\mathrm{\mathbf{z}}_t)$を近似する。

\begin{equation}
q(\mathrm{\mathbf{z}}_{t-1}|\mathrm{\mathbf{z}}_t)\thickapprox Pr(\mathrm{\mathbf{z}}_{t-1}|\mathrm{\mathbf{z}}_t,\phi_t)=\mathcal{N}(\mathrm{\mathbf{z}}_{t-1};f_t(\mathrm{\mathbf{z}}_t,\phi_t),\sigma^2_t\mathrm{\mathbf{I}})\tag{14}
\end{equation}

この時、Reverse Processの同時確率分布は

Pr(\mathrm{\mathbf{x}},\mathrm{\mathbf{z}}_{1\dots T}|\mathrm{\mathbf{\phi}}_{1\dots T})=Pr(\mathrm{x}|\mathrm{\mathbf{z}}_t,\phi_1)\prod_{t=2}^TPr(\mathrm{\mathbf{z}}_{t-1}|\mathrm{\mathbf{z}}_t,\phi_t)\cdot Pr(\mathrm{\mathbf{z}}_T)\tag{15}

である。

損失関数の導出

$\mathrm{\mathbf{z}}_{1\dots T}$について周辺化すると、$\mathrm{\mathbf{x}}$がモデル(Reverse Process)から生成される確率の確率分布は

Pr(\mathrm{\mathbf{x}}|\phi_{1\dots T})=\int Pr(\mathrm{\mathrm{\mathbf{x}}},\mathrm{\mathbf{z}}_{1\dots T}|\phi_{1\dots T})d\mathrm{\mathbf{z}}_{1\dots T}\tag{16}

である。ここで、学習データ$\lbrace x\rbrace_{i=1}^I$が与えられた時に、学習データと同じようなデータをモデルから生成できるようにモデルを学習したい。この学習データがモデルから観測される対数尤度は

\log\prod_{i=1}^IPr(\mathrm{\mathbf{x}}_i|\phi_{1\dots T})=\sum_{i=1}^{I}\log Pr(\mathrm{\mathbf{x}}_i|\phi_{1\dots T})\tag{17}

である。(17)を最大化する

\hat \phi_{1\dots T}=\arg\max_{\phi_1\dots T}\sum_{i=1}^{I}\log Pr(\mathrm{\mathbf{x}}_i|\phi_{1\dots T})\tag{18}

を求めれば良い。しかし、潜在変数$\mathrm{\mathbf{z}}_{1\dots T}$の次元は非常に高く、式(16)の積分は計算できない(intractable)。$\Pr(\mathrm{\mathbf{x}}|\phi_{1\dots T})$を計算する別の方法として、

Pr(\mathrm{\mathbf{x}}|\phi_{1\dots T})=\frac{Pr(\mathrm{\mathbf{x}},\mathrm{\mathbf{z}}_{1\dots T}|\phi_{1\dots T})}{Pr(\mathrm{\mathbf{z}}_{1\dots T}|\mathrm{\mathbf{x}},\phi_{1\dots T})}\tag{19}

があるが、分母の事後分布$Pr(\mathrm{\mathbf{z}}_{1\dots T}|\mathrm{x},\phi_{1\dots T})$はやはり計算できない。
ここで、変分推論[3]で近似的に計算する。変分推論では、簡単に計算できる分布$r(\mathrm{\mathbf{z}}_{1\dots T}|\mathrm{\mathbf{x}})$で、真の事後分布$Pr(\mathrm{\mathbf{z}}_{1\dots T}|\mathrm{x},\phi_{1\dots T})$を近似する。$r(\mathrm{\mathbf{z}}_{1\dots T}|\mathrm{\mathbf{x}})$はデータ$\mathrm{\mathbf{x}}$が与えられた時の潜在変数$\mathrm{\mathbf{z}}_{1\dots T}$の分布であるので、エンコーダとして考えることができる。拡散モデルにおけるエンコーダは上で説明したForward Processであり、その分布は$q(\mathrm{\mathbf{z}}_{1\dots T}|\mathrm{\mathbf{x}})$と定義されており、式(5)により計算できる。よって、

q(\mathrm{\mathbf{z}}_{1\dots T}|\mathrm{\mathbf{x}})\thickapprox Pr(\mathrm{\mathbf{z}}_{1\dots T}|\mathrm{\mathbf{x}},\phi_{1\dots T})\tag{20}

となるように、$\phi_{1\dots T}$を調整すれば良い。
ここまでまとめると、$q(\mathrm{\mathbf{z}}_{1\dots T}|\mathrm{\mathbf{x}})$を使って、近似的に計算した$\Pr(\mathrm{\mathbf{x}}|\phi_{1\dots T})$の最大化と式(20)の近似を同時に実現したい。これらを実現する損失関数を求めたい。ここで、$\log Pr(\mathrm{\mathbf{x}}|\phi_{1\dots T})$に対して少しトリッキな変形を行う。

\begin{align}
\log Pr(\mathrm{\mathbf{x}}|\phi_{1\dots T})&=\log Pr(\mathrm{\mathbf{x}}|\phi_{1\dots T})\int q(\mathrm{\mathbf{z}}_{1\dots T}|\mathrm{\mathbf{x}})d\mathrm{\mathbf{z}}_{1\dots T}\\&=\int q(\mathrm{\mathbf{z}}_{1\dots T}|\mathrm{\mathbf{x}})\log Pr(\mathrm{\mathbf{x}}|\phi_{1\dots T})d\mathrm{\mathbf{z}}_{1\dots T}\\&=\mathbb{E}_{q(\mathrm{\mathbf{z}}_{1\dots T}|\mathrm{\mathbf{x}})}\left[\log Pr(\mathrm{\mathbf{x}}|\phi_{1\dots T})\right]\\&=\mathbb{E}_{q(\mathrm{\mathbf{z}}_{1\dots T}|\mathrm{\mathbf{x}})}\left[\log \frac{Pr(\mathrm{\mathbf{x}},\mathrm{\mathbf{z}}_{1\dots T}|\phi_{1\dots T})}{Pr(\mathrm{\mathbf{z}}_{1\dots T}|\mathrm{\mathbf{x}},\phi_{1\dots T})}\right]\\&=\mathbb{E}_{q(\mathrm{\mathbf{z}}_{1\dots T}|\mathrm{\mathbf{x}})}\left[\log \frac{Pr(\mathrm{\mathbf{x}},\mathrm{\mathbf{z}}_{1\dots T}|\phi_{1\dots T})q(\mathrm{\mathbf{z}}_{1\dots T}|\mathrm{\mathbf{x}})}{Pr(\mathrm{\mathbf{z}}_{1\dots T}|\mathrm{\mathbf{x}},\phi_{1\dots T})q(\mathrm{\mathbf{z}}_{1\dots T}|\mathrm{\mathbf{x}})}\right]\\&=\mathbb{E}_{q(\mathrm{\mathbf{z}}_{1\dots T}|\mathrm{\mathbf{x}})}\left[\log \frac{Pr(\mathrm{\mathbf{x}},\mathrm{\mathbf{z}}_{1\dots T}|\phi_{1\dots T})}{q(\mathrm{\mathbf{z}}_{1\dots T}|\mathrm{\mathbf{x}})}\right]+\mathbb{E}_{q(\mathrm{\mathbf{z}}_{1\dots T}|\mathrm{\mathbf{x}})}\left[\log \frac{q(\mathrm{\mathbf{z}}_{1\dots T}|\mathrm{\mathbf{x}})}{Pr(\mathrm{\mathbf{z}}_{1\dots T}|\mathrm{\mathbf{x}},\phi_{1\dots T})}\right]\\&=\mathbb{E}_{q(\mathrm{\mathbf{z}}_{1\dots T}|\mathrm{\mathbf{x}})}\left[\log \frac{Pr(\mathrm{\mathbf{x}},\mathrm{\mathbf{z}}_{1\dots T}|\phi_{1\dots T})}{q(\mathrm{\mathbf{z}}_{1\dots T}|\mathrm{\mathbf{x}})}\right]+\mathcal{D}_{KL}(q(\mathrm{\mathbf{z}}_{1\dots T}|\mathrm{\mathbf{x}})\|Pr(\mathrm{\mathbf{z}}_{1\dots T}|\mathrm{\mathbf{x}},\phi_{1\dots T}))\\&\ge\mathbb{E}_{q(\mathrm{\mathbf{z}}_{1\dots T}|\mathrm{\mathbf{x}})}\left[\log \frac{Pr(\mathrm{\mathbf{x}},\mathrm{\mathbf{z}}_{1\dots T}|\phi_{1\dots T})}{q(\mathrm{\mathbf{z}}_{1\dots T}|\mathrm{\mathbf{x}})}\right]\tag{21}
\end{align}

式(21)における

\mathbb{E}_{q(\mathrm{\mathbf{z}}_{1\dots T}|\mathrm{\mathbf{x}})}\left[\log \frac{Pr(\mathrm{\mathbf{x}},\mathrm{\mathbf{z}}_{1\dots T}|\phi_{1\dots T})}{q(\mathrm{\mathbf{z}}_{1\dots T}|\mathrm{\mathbf{x}})}\right]\tag{22}

はEvidence Lower Bound(ELBO)と呼ぶ。式(21)の下から2行目の右辺の第2項$\mathcal{D}_{KL}(q(\mathrm{\mathbf{z}}_{1\dots T}|\mathrm{\mathbf{x}})|Pr(\mathrm{\mathbf{z}}_{1\dots T}|\mathrm{\mathbf{x}},\phi_{1\dots T}))$はKL Divergenceであり、二つの分布間の距離を示す指標である。その値は非負である。よって、式(22)は$\log Pr(\mathrm{\mathbf{x}}|\phi_{1\dots T})$の下限である。$\mathcal{D}_{KL}(q(\mathrm{\mathbf{z}}_{1\dots T}|\mathrm{\mathbf{x}})|Pr(\mathrm{\mathbf{z}}_{1\dots T}|\mathrm{\mathbf{x}},\phi_{1\dots T}))$が0の時に、ELBOが$\log Pr(\mathrm{\mathbf{x}}|\phi_{1\dots T})$に一致するので、ELBOを最大化すれば、近似的に計算した$\Pr(\mathrm{\mathbf{x}}|\phi_{1\dots T})$の最大化と式(20)の近似を同時に実現できる。
ここからは、ELBOを綺麗な形に整理する。式(5)と式(15)を代入すると

\mathrm{ELBO}=\mathbb{E}_{q(\mathrm{\mathbf{z}}_{1\dots T}|\mathrm{\mathbf{x}})}\left[\log \frac{Pr(\mathrm{x}|\mathrm{\mathbf{z}}_t,\phi_1)\prod_{t=2}^TPr(\mathrm{\mathbf{z}}_{t-1}|\mathrm{\mathbf{z}}_t,\phi_t)Pr(\mathrm{\mathbf{z}}_t)}{q(\mathrm{\mathbf{z}}_{1}|\mathrm{\mathbf{x}})\prod_{t=2}^Tq(\mathrm{\mathbf{z}}_t|\mathrm{\mathbf{z}}_{t-1})}\right]\tag{23}

ここで、$\mathrm{\mathbf{z}}_{t-1}$は$\mathrm{\mathbf{z}}_{t}$のみに依存し(マルコフ性)、余分な条件$\mathrm{\mathbf{x}}$があっても何も変化しないので

q(\mathrm{\mathbf{z}}_t|\mathrm{\mathbf{z}}_{t-1})=q(\mathrm{\mathbf{z}}_t|\mathrm{\mathbf{z}}_{t-1},\mathrm{\mathbf{x}})\tag{24}

が成り立つ。これをを利用すると

\begin{align}
\mathrm{ELBO}&=\mathbb{E}_{q(\mathrm{\mathbf{z}}_{1\dots T}|\mathrm{\mathbf{x}})}\left[\log \frac{Pr(\mathrm{x}|\mathrm{\mathbf{z}}_t,\phi_1)\prod_{t=2}^TPr(\mathrm{\mathbf{z}}_{t-1}|\mathrm{\mathbf{z}}_t,\phi_t)Pr(\mathrm{\mathbf{z}}_t)}{q(\mathrm{\mathbf{z}}_{1}|\mathrm{\mathbf{x}})\prod_{t=2}^Tq(\mathrm{\mathbf{z}}_t|\mathrm{\mathbf{z}}_{t-1},\mathrm{x})}\right]\\&=\mathbb{E}_{q(\mathrm{\mathbf{z}}_1|\mathrm{\mathbf{x}})}\left[\log Pr(\mathrm{\mathbf{x}}|\mathrm{\mathbf{z}}_1)\right]-\mathcal{D}_{KL}(q(\mathrm{\mathbf{z}}_T|\mathrm{\mathbf{x}})\|p(\mathrm{\mathbf{z}}_T))\\&\qquad-\sum_t^T\mathbb{E}_{q(\mathrm{\mathbf{z}}_t|\mathrm{\mathbf{x}})}\left[\mathcal{D}_{KL}(q(\mathrm{\mathbf{z}}_{t-1}|\mathrm{\mathbf{z}}_t,\mathrm{\mathbf{x}})\|p(\mathrm{\mathbf{z}}_{t-1}|\mathrm{\mathbf{z}}_t))\right]\tag{25}
\end{align}

となる。
式(25)の各項を一つずつ計算する。

第1項

式(14)を代入すると

\begin{align}
\log Pr(\mathrm{\mathbf{x}}|\mathrm{\mathbf{z}}_1)&=\log\frac{1}{\sqrt{2\pi}\sigma_1}\exp(-\frac{(\mathrm{\mathbf{x}}-f(\mathrm{\mathbf{z}}_1,\phi_1))^2}{2\sigma_1^2})\\&=-\frac{1}{2\sigma_1^2}\|\mathrm{\mathbf{x}}-f(\mathrm{\mathbf{z}}_1,\phi_1)\|^2+C\tag{26}
\end{align}

になる。ただし、$C$は定数。

第2項

$T$が十分大きい時に、$q(\mathrm{\mathbf{z}}_T|\mathrm{\mathbf{x}})$は標準正規分布になる。事前分布$Pr(\mathrm{\mathbf{z}}_T)$も標準正規分布として定義すると、

\mathcal{D}_{KL}(q(\mathrm{\mathbf{z}}_T|\mathrm{\mathbf{x}})\|Pr(\mathrm{\mathbf{z}}_T))\thickapprox0\tag{27}

第3項

まずは、$q(\mathrm{\mathbf{z}}_{t-1}|\mathrm{\mathbf{z}}_t,\mathbf{x})$を計算する。ベイズの定理と式(4)、(12)より

\begin{align}
q(\mathrm{\mathbf{z}}_{t-1}|\mathrm{\mathbf{z}}_t,\mathrm{\mathbf{x}})&=\frac{q(\mathbf{z}_{t}|\mathrm{\mathbf{z}}_{t-1},\mathrm{\mathbf{x}})q(\mathrm{\mathbf{z}}_{t-1}|\mathrm{\mathbf{x}})}{q(\mathrm{\mathbf{z}}_t|\mathrm{\mathbf{x}})}\\&=\frac{\mathcal{N}(\mathrm{\mathbf{z}}_{t};\sqrt{1-\beta_t}\mathrm{\mathbf{z}}_{t-1},\beta_t\mathrm{\mathbf{I}})\mathcal{N}(\mathrm{\mathbf{z}}_{t-1};\sqrt{\alpha_{t-1}}\mathrm{\mathbf{z}}_{t-1},(1-\alpha_{t-1})\mathrm{\mathbf{I}})}{\mathcal{N}(\mathrm{\mathbf{z}}_{};\sqrt{\alpha_{t}}\mathrm{\mathbf{z}}_{t},(1-\alpha_{t})\mathrm{\mathbf{I}})}\\&=\mathcal{N}\left(\mathrm{\mathbf{z}}_{t-1};\frac{(1-\alpha_{t-1})}{1-\alpha_t}\sqrt{1-\beta_t}\mathrm{\mathbf{z}}_t+\frac{\sqrt{\alpha_{t-1}}\beta_{t}}{1-\alpha_t}\mathrm{\mathbf{x}},\frac{\beta_t(1-\alpha_{t-1})}{1-\alpha_t}\mathrm{\mathbf{I}}\right)\tag{28}
\end{align}

となる(正規分布の式を展開して、整理するだけ)。そして、二つの正規分布間のKL Divergenceは解析的に計算することができ、

\begin{align}
\mathcal{D}_{KL}(q(\mathrm{\mathbf{z}}_{t-1}|&\mathrm{\mathbf{z}}_t,\mathrm{\mathbf{x}})\|Pr(\mathrm{\mathbf{z}}_{t-1}\|\mathrm{\mathbf{z}}_{t},\phi_t))=\\&\frac{1}{2\sigma_t^2}\left\|\frac{(1-\alpha_{t-1})}{1-\alpha_t}\sqrt{1-\beta_t}\mathrm{\mathbf{z}}_t+\frac{\sqrt{\alpha_{t-1}}\beta_{t}}{1-\alpha_t}\mathrm{\mathbf{x}}-f(\mathrm{\mathbf{z_t},\phi_t})\right\|^2+C\tag{29}
\end{align}

になる。ただし, $C$は定数であり、$\sigma_t^2=\frac{\beta_t(1-\alpha_{t-1})}{1-\alpha_t}$とする。
式(10)より

\mathrm{\mathbf{x}}=\frac{1}{\sqrt{\alpha}_t}\mathrm{\mathbf{z}}_t-\frac{\sqrt{1-\alpha_t}}{\sqrt{\alpha_t}}\epsilon\tag{30}

であり、式(29)に代入すると

\begin{align}
\mathcal{D}_{KL}(q(\mathrm{\mathbf{z}}_{t-1}|&\mathrm{\mathbf{z}}_t,\mathrm{\mathbf{x}})\|Pr(\mathrm{\mathbf{z}}_{t-1}\|\mathrm{\mathbf{z}}_{t},\phi_t))=\\&\frac{1}{2\sigma_t^2}\left\|\left(\frac{1}{\sqrt{1-\beta_t}}\mathrm{\mathbf{z}}_t-\frac{\beta_t}{\sqrt{1-\alpha_t}\sqrt{1-\beta_t}}\epsilon\right)-f(\mathrm{\mathbf{z}}_t,\phi_t)\right\|^2\tag{31}
\end{align}

Reparameterization

式(31)をさらに簡単にするために、

\begin{align}
f(\mathrm{\mathbf{z}}_t,\phi_t)=\frac{1}{\sqrt{1-\beta_t}}\mathrm{\mathbf{z}}_t-\frac{\beta_t}{\sqrt{1-\alpha_t}\sqrt{1-\beta_t}}g(\mathrm{\mathbf{z}}_t,\theta_t)\tag{32}
\end{align}

となるような新しいニューラルネットワーク$g(\mathrm{\mathbf{z}}_t,\theta_t)$を定義する。これを式(31)に代入し、$\mathrm{\mathbf{z}}_t$を式(10)で消去すると、

\begin{align}
\mathcal{D}_{KL}(q(\mathrm{\mathbf{z}}_{t-1}|&\mathrm{\mathbf{z}}_t,\mathrm{\mathbf{x}})\|Pr(\mathrm{\mathbf{z}}_{t-1}\|\mathrm{\mathbf{z}}_{t},\phi_t))=\\&\frac{\beta_t^2}{(1-\alpha_t)(1-\beta_t)\sigma_t^2}\left\|g(\sqrt{\alpha_t}\mathrm{\mathbf{x}}+\sqrt{1-\alpha_t}\epsilon,\theta_t)-\epsilon\right\|^2\tag{33}
\end{align}

そして、式(26)に式(30),(32)を代入すると

\begin{align}
\log (Pr(\mathrm{\mathbf{x}}|\mathrm{\mathbf{z}}_1,\phi_1))&=-\frac{1}{2\sigma^2_1}\|\frac{1}{\sqrt{\alpha}_1}\mathrm{\mathbf{z}}_1-\frac{\sqrt{1-\alpha_1}}{\sqrt{\alpha_1}}\epsilon-\frac{1}{\sqrt{1-\beta_1}}\mathrm{\mathbf{z}}_1\\&\qquad+\frac{\beta_1}{\sqrt{1-\alpha_1}\sqrt{1-\beta_1}}g(\mathrm{\mathbf{z}}_1,\theta_1)\|^2\\&=-\frac{1}{2\sigma^2_1}\|\frac{\beta_1}{\sqrt{1-\alpha_1}\sqrt{1-\beta_1}}g(\mathrm{\mathbf{z}}_1,\theta_1)-\frac{\sqrt{1-\alpha_1}}{\sqrt{\alpha_1}}\epsilon\|^2\\&=-\frac{\beta_1^2}{(1-\alpha_1)(1-\beta_1)2\sigma^2_1}\|g(\mathrm{\mathbf{z}}_1,\theta_1)-\epsilon\|^2\tag{34}
\end{align}

以上より

\begin{align}
\mathrm{ELBO}&=\mathbb{E}_{q(\mathrm{\mathbf{z}}_1|\mathrm{\mathbf{x}})}\left[-\frac{\beta_1^2}{(1-\alpha_1)(1-\beta_1)2\sigma^2_1}\|g(\sqrt{\alpha_1}\mathrm{\mathbf{x}}+\sqrt{1-\alpha_1}\epsilon,\theta_1)-\epsilon\|^2\right]-\\&\quad\sum_{t=2}^T\mathbb{E}_{q(\mathrm{\mathbf{z}}_t|\mathrm{\mathbf{x}})}\left[\frac{\beta_t^2}{(1-\alpha_t)(1-\beta_t)2\sigma_t^2}\left\|g(\sqrt{\alpha_t}\mathrm{\mathbf{x}}+\sqrt{1-\alpha_t}\epsilon,\theta_t)-\epsilon\right\|^2\right]\\&=-\sum_{t=1}^T\mathbb{E}_{q(\mathrm{\mathbf{z}}_t|\mathrm{\mathbf{x}})}\left[\frac{\beta_t^2}{(1-\alpha_t)(1-\beta_t)2\sigma_t^2}\left\|g(\sqrt{\alpha_t}\mathrm{\mathbf{x}}+\sqrt{1-\alpha_t}\epsilon,\theta_t)-\epsilon\right\|^2\right]\tag{35}
\end{align}

したがって、

\begin{align}
\hat \theta_{1\dots T}&=\arg\max_{\theta_{1\dots T }}-\sum_{t=1}^T\mathbb{E}_{q(\mathrm{\mathbf{z}}_t|\mathrm{\mathbf{x}})}\left[\frac{\beta_t^2}{(1-\alpha_t)(1-\beta_t)2\sigma_t^2}\left\|g(\sqrt{\alpha_t}\mathrm{\mathbf{x}}+\sqrt{1-\alpha_t}\epsilon,\theta_t)-\epsilon\right\|^2\right]\\&=\arg\min_{\theta_{1\dots T}}\mathbb{E}_{t\sim \mathcal{U}(1,T),q(\mathrm{\mathbf{z}}_t|\mathrm{\mathbf{x}})}\left[\frac{\beta_t^2}{(1-\alpha_t)(1-\beta_t)2\sigma_t^2}\left\|g(\sqrt{\alpha_t}\mathrm{\mathbf{x}}+\sqrt{1-\alpha_t}\epsilon,\theta_t)-\epsilon\right\|^2\right]\tag{36}
\end{align}

DDPM[2]の論文ではこれを簡単化した

\mathbb{E}_{t\sim \mathcal{U}(1,T),q(\mathrm{\mathbf{z}}_t|\mathrm{\mathbf{x}})}\left[\left\|g(\sqrt{\alpha_t}\mathrm{\mathbf{x}}+\sqrt{1-\alpha_t}\epsilon,\theta_t)-\epsilon\right\|^2\right]\tag{37}

を損失関数として採用している。式(37)の方が性能がより良いことを実験から確認した。

参考文献

[1] J. Ho, A. Jain, P. Abbel. Denoising Diffusion Probabilistic Models. NeurIPS, 2020.
[2] Simon J.D. Prince. Understanding Deep Learning. MIT Press, 2023.
[3] Diederik P. Kingma, M. Welling. An Introduction to Variational Autoencoders. Foundations and Trends in Machine Learning, 2019.

4
3
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
4
3