この記事は中国のNLP研究者Jianlin Su氏が運営するブログ「科学空間」で掲載された解説記事の日本語訳です。
苏剑林. (Jul. 19, 2022). 《生成扩散模型漫谈(三):DDPM = 贝叶斯 + 去噪 》[Blog post]. Retrieved from https://kexue.fm/archives/9164
これまで、筆者は拡散生成モデルDDPMの導出を二通り紹介した。「拡散生成モデル漫談(一):DDPM=解体+建設」で紹介した平易な喩え話に基づいた手法と、「拡散生成モデル漫談(二):DDPM=自己回帰型VAE」で紹介した変分自己符号化器による手法だ。二つの方法はそれぞれ特徴があり、前者はより端的で理解しやすいが、理論的な延長や定量的な理解の余地がない。後者は分析としてはより完全だが、少し形式ばっていて啓発性に欠ける。
この記事では、更にもう一つのDDPMの導出法を紹介したい。この方法はベイズの定理を利用して計算を簡潔化するもので、導出過程には「推敲」の妙味があり、啓発性に富んでいる。さらに、この方法は後に紹介するDDIMとも緊密な関係がある。
モデルのポートフォリオ
改めてDDPMのモデル化過程を振り返ろう。変換の過程は以下のとおりである。
x=x_0\rightleftarrows x_1\rightleftarrows x_2\rightleftarrows \cdots \rightleftarrows x_{T-1}\rightleftarrows x_T=z\qquad(1)
右方向はデータサンプル$x$をランダムノイズ$z$に変換する過程で、その逆方向がランダムノイズ$z$をデータサンプル$x$に変えていく過程である。後者が我々の求める「生成モデル」だ。
右方向の過程は単純だ。各ステップは以下の通りである。
x_t=\alpha_tx_{t-1}+\beta_t\varepsilon_t,\quad \varepsilon_t\sim\mathcal{N}(0,I)\qquad(2)
あるいは、$p(x_t|x_{t-1})=\mathcal{N}(x_t;\alpha_tx_{t-1},\beta_t^2I)$とも書ける。$\alpha_t^2+\beta_t^2=1$と定めると、以下の通りになる。
\begin{align}
x_t&=\alpha_tx_{t-1}+\beta_t\varepsilon_t\\
&=\alpha_t(\alpha_{t-1}x_{t-2}+\beta_{t-1}\varepsilon{t-1})+\beta_t\varepsilon_t\\
&=\cdots\\
&=(\alpha_t\cdots\alpha_1)x_0+\underbrace{(\alpha_t\cdots\alpha_2)\beta_1\varepsilon_1
+(alpha_t\cdots\alpha_3)\beta_2\varepsilon_2+\cdots+
\alpha_t\beta_{t-1}\varepsilon_{t-1}+\beta_t\varepsilon_t}
_{\sim\mathcal{N}(0,(1-\alpha_t^2\cdots\alpha_1^2)I)}
\end{align}
\qquad(3)
ゆえに$p(x_t|x_0)=\mathcal{N}(x_t;\overline\alpha_tx_0,\overline\beta_t^2I)$であることが分かる。ここで$\overline\alpha_t=\alpha_1\cdots\alpha_t$ $\overline\beta_t=\sqrt{1-\overline\alpha_t^2}$。
DDPMの目標は、上述の情報から逆過程$p(x_{t-1}|x_t)$を求めることである。それができれば、任意の$x_T=z$から逐次的に$x_{T-1},x_{T-2},\cdots,x_1$をサンプリングすることで、最終的にデータサンプル$x_0=x$を得ることができる。
ベイズご登場
ここで偉大なるベイズの定理にご登場いただこう。ベイズの定理を直接適用すると、
p(x_{t-1}|x_t)=\frac{p(x_t|x_{t-1})p(x_{t-1})}{p(x_t)}\qquad(4)
しかし、我々は$p(x_{t-1})$と$p(x_{t})$の中身が分からないので、これ以上は何もできない。ただ、一歩引いて$x_0$を条件付けた上で、ベイズの定理を使用することはできる。
p(x_{t-1}|x_t,x_0)=\frac{p(x_t|x_{t-1})p(x_{t-1}|x_0)}{p(x_t|x_0)}\qquad(5)
$p(x_t|x_{t-1})$、$p(x_{t-1}|x_0)$と$p(x_t|x_0)$は既知なので、この式なら計算ができる。それぞれの式を代入すると、以下の結果になる。
p(x_{t-1}|x_t,x_0)=\mathcal{N}\left(x_{t-1};\frac{\alpha_t\overline\beta_{t-1}^2}{\overline\beta_t^2}x_t\frac{\alpha_{t-1}\beta_{t}^2}{\overline\beta_t^2}x_0,\frac{\overline\beta_{t-1}^2\beta_t^2}{\overline\beta_t^2}I\right)\qquad(6)
上の式の導出は普通に整理すればいいので難しくはないが、少しテクニックを使って計算を早めることもできる。まず、それぞれの式を代入すると、指数部分は$-\frac{1}{2}$を除くと、以下の通りになる。
\frac{||x_t-\alpha_tx_{t-1}||^2}{\beta_t^2}+\frac{||x_{t-1}-\overline\alpha_{t-1}x_0||^2}{\overline\beta_{t-1}^2}-\frac{||x_t-\overline\alpha_tx_0||^2}{\overline\beta_t^2}\qquad(7)
この式は$x_{t-1}$に関する二次関数なので、分布は依然正規分布であることが分かる。なのでその平均値と共分散を求めればいい。式を展開すると、$||x_{t-1}||^2$の項の係数は
\frac{a_t^2}{\beta_t^2}+\frac{1}{\overline\beta_{t-1}^2}=\frac{a_t^2\overline\beta_{t-1}^2+\beta_t^2}{\overline\beta_{t-1}^2\beta_t^2}=\frac{\alpha_t^2(1-\overline\alpha_{t-1}^2)+\beta_t^2}{\overline\beta_{t-1}^2\beta_t^2}=\frac{1-\overline\alpha_t^2}{\overline\beta_{t-1}^2\beta_t^2}=\frac{\overline\beta_t^2}{\overline\beta_{t-1}^2\beta_t^2}\qquad(8)
つまり、式を整理すると$\frac{\overline\beta_t^2}{\overline\beta_{t-1}^2\beta_t^2}||x_{t-1}+\tilde\mu(x_t,x_0)||^2$のような形になる。これで共分散行列は$\frac{\overline\beta_t^2}{\overline\beta_{t-1}^2\beta_t^2}I$であることが分かった。一方で、1次項は$-2\left(\frac{\alpha_t}{\beta_t^2}x_t+\frac{\overline\alpha_{t-1}\beta_t^2}{\overline\beta_t^2}x_0\right)$なので、$\frac{-2\overline\beta_t^2}{\overline\beta_{t-1}^2\beta_t^2}$を割ることで
\tilde\mu(x_t,x_0)=\frac{\alpha_t\overline\beta_{t-1}^2}{\overline\beta_t^2}x_t
+\frac{\overline\alpha_{t-1}\beta_t^2}{\overline\beta_t^2}x_0\qquad(9)
となる。これで式$p(x_{t-1}|x_t,x_0)$は式$(6)$の通りになる。
デノイジング過程
$p(x_{t-1}|x_t,x_0)$の明示的な式が判明したが、これが知りたいわけではない。我々は$x_t$から$x_{t-1}$を予測したいのだから、最終的な生成結果である$x_0$に依存してはいけないのだ。ここで、いかにも「飛躍的」な発想だが、こう考えてみてはどうだろうか。
「$x_t$から$x_0$を推測できるようにすれば、$p(x_{t-1}|x_t,x_0)$の$x_0$を消去して、$x_t$のみに依存する式に書けるのでは?」
早速試してみよう。$x_0$を推測する関数を$\overline\mu(x_t)$とし、損失関数は$||x_0-\overline\mu(x_t)||^2$とする。学習が完了した時点で、以下の通りになると考えることにしよう。
p(x_{t-1}|x_t)\approx p(x_{t-1}|x_t, x_0=\overline\mu(x_t))=\mathcal{N}\left(x_{t-1};\frac{\alpha_t\overline\beta_{t-1}^2}{\overline\beta_t^2}x_t\frac{\alpha_{t-1}\beta_{t}^2}{\overline\beta_t^2}\overline\mu(x_t),\frac{\overline\beta_{t-1}^2\beta_t^2}{\overline\beta_t^2}I\right)
$||x_0-\overline\mu(x_t)||^2$の$x_0$は元のデータ、$x_t$はノイズ付きデータを表すので、これは実質的にデノイジングモデルを学習していることになる。DDPMの一つ目のD(Denoising)はここから来ていると言える。
具体的に書くと、$p(x_t|x_0)=\mathcal{N}(x_t;\overline\alpha_tx_0,\overline\beta_t^2I)$すなわち$x_t=\alpha_tx_0+\overline\beta_t\varepsilon, \varepsilon\sim \mathcal{N}(0,I)$、あるいは$x_0=\frac{1}{\overline\alpha_t}(x_t-\overline\beta_t\varepsilon)$と書けるので、これに倣って$\overline\mu(x_t)$を以下の形にパラメーター化する。
\overline\mu(x_t)=\frac{1}{\overline\alpha_t}\left(x_t-\overline\beta_t\epsilon_\theta(x_t,t)\right)\qquad(11)
このときの損失関数は以下のとおりである。
||x_0-\overline\mu(x_t)||^2=\frac{\overline\beta_t^2}{\overline\alpha_t^2}||\varepsilon-\epsilon_\theta(\alpha_tx_0+\overline\beta_t\varepsilon,t)||^2\qquad(12)
先頭の係数を省けば、DDPM論文が採用した損失関数になる。ここで注目すべきは、これまでのように$x_t$から$x_{t-1}$のノイズ除去過程の積分変換から導出するアプローチと違い、今回は$x_t$から$x_0$へのノイズ除去過程に直接たどり着いた。今回の導出は、これまでよりも更に直接的なアプローチであるといえる。
一方、式$(11)$を式$(10)$に代入すると、
p(x_{t-1}|x_t)\approx p(x_{t-1}|x_t, x_0=\overline\mu(x_t))=\mathcal{N}\left(x_{t-1};\frac{1}{\overline\alpha_t}\left(x_t-\overline\beta_t\epsilon_\theta(x_t,t)\right), \frac{\overline\beta_{t-1}^2\beta_t^2}{\overline\beta_t^2}I\right)\qquad(13)
これが逆過程のサンプリングで使われる分布である。これでDDPMの導出が完了した。
導出の読みやすさを考慮して、本記事の$\epsilon_\theta$は前の記事とは異なっているが、逆にDDPM論文とは一致する。
予想と修正
ここで気になった読者もいるかもしれない。我々は$x_T$を徐々に$x_0$に変化させようとしているのに、$p(x_{t-1}|x_t,x_0)$で$p(x_{t-1}|x_t)$を近似するときに、「$\overline\mu(x_t)$で$x_0$を推定する」という事をした。仮に正確に推定できるのだとすれば、そのまま$x_0$を求めれば良いわけで、反復的な推定も要らないのではないか?
現実は、$\overline\mu(x_t)$から推定された$x_0$は当然正確な推定ではない。少なくとも、初期状態から相当経っても不正確なままである。この推定の役割はあくまで事前的な予想であり、そのあとに$p(x_{t-1}|x_t)$で少しだけ前進させる、ということをしている。まずは最終的な「雑な予想」を求めてから、その「雑な予想」の方向にちょっとだけ進めることで、徐々に精度の高い結果を得るという、「予想-修正」の思想は多くのアルゴリズムで見られるものである。
Hintonが三年前に提案した「Lookahead Optimizer: k steps forward, 1 step back」も、予想(k steps forward)と修正(1 step back)からなる手法である。論文ではこの手法を「Fast-Slow」な重みの組み合わせであると説明した。すなわち、Fastの重みは予想から得られた結果で、Slowの重みは予想に基づいて修正した結果である。DDPMの「予想-修正」過程も同じような解釈をすることができる。
残りの課題
ベイズの定理を使用したとき、$(4)$が使えない原因は$p(x_{t-1})$と$p(x_t)$が分からないからである、と説明した。これは、$p(x_t)$の定義を見ればわかる。
p(x_t)=\int p(x_t|x_0)\tilde p(x_0)dx_0\qquad(15)
ここで$p(x_t|x_0)$は既知だが、データの分布$\tilde p(x_0)$は未知なので、計算できない。ただ、両方とも計算できる特殊なケースは存在するので、ここで紹介しよう。ここで紹介する結論は、前回の解説で残された分散の選択問題の答えにもなる。
一つ目のケースは、データサンプルが一つしか存在しない場合である。一般性を損なわずに、このサンプルを$0$とすると、$\tilde p(x_0)$はディラック分布$\delta(x_0)$になり、$p(x_t)=p(x_t|0)$となる。これを式$(4)$に代入すると、ちょうど$p(x_{t-1}|x_t, x_0)$の$x_0=0$の特例になる。
p(x_{t-1}|x_t)=p(x_{t-1}|x_t, x_0=0)=\mathcal{N}\left(x_{t-1};\frac{\alpha_t\overline\beta_{t-1}^2}{\overline\beta_t^2},\frac{\overline\beta_{t-1}^2\beta_t^2}{\beta_t^2}I\right)\qquad(16)
我々が主に知りたいのは分散$\frac{\overline\beta_{t-1}^2\beta_t^2}{\beta_t^2}$である。これがサンプリング時に使う分散の一つである。
二つ目のケースは、データが標準正規分布に従う場合、すなわち$\tilde p(x_0)=\mathcal{N}(x_0;0,I)$のときである。先ほど、$p(x_t|x_0)=\mathcal{N}(x_t;\overline\alpha_tx_0,\overline\beta_t^2I)$は$x_t=\alpha_tx_0+\overline\beta_t\varepsilon, \varepsilon\sim \mathcal{N}(0,I)$と書けることを示したが、ここに$x_0\sim\mathcal{N}(0,I)$を代入すると、正規分布の再生性により$x_t$も標準正規分布に従うことが分かる。標準正規分布の確率密度関数を式$(4)$に代入すると、$-\frac{1}{2}$を除いて結果は以下の通りになる。
\frac{||x_t-\alpha_tx_{t-1}||^2}{\beta_t^2}+||x_{t-1}||^2-||x_t||^2\qquad(17)
$p(x_{t-1}|x_t,x_0)$の導出と同じような方法で、以下の結論が導き出される。
p(x_{t-1}|x_t)=\mathcal{N}(x_{t-1};\alpha_tx_t,\beta_t^2I)\qquad(18)
同じく、我々が知りたい分散は$\beta_t^2$である。これが分散のもう一つのチョイスだ。
まとめ
本記事ではDDPMの「推敲」的な導出を紹介した。この導出はベイズの定理を利用して逆方向の生成過程を直接求めるもので、最初に紹介した「解体-建築」や変分推定に比べて、さらに直接的な方法である。また、この導出は啓発性にも富んでおり、この後紹介するDDIMとも密接な関係がある。