LoginSignup
1
0

【翻訳転載】拡散生成モデル漫談(二):DDPM=自己回帰型VAE

Last updated at Posted at 2024-03-19

この記事は中国のNLP研究者Jianlin Su氏が運営するブログ「科学空間」で掲載された解説記事の日本語訳です。
原文の掲載日は2022/7/16です。

苏剑林. (Jul. 06, 2022). 《生成扩散模型漫谈(二):DDPM = 自回归式VAE 》[Blog post]. Retrieved from https://kexue.fm/archives/9152

生成拡散モデル漫談(二):DDPM=自己回帰型VAE

拡散生成モデル漫談(一):DDPM=解体+建設」では、生成拡散モデルDDPMを説明するために「解体-建設」の喩えを使いながら、DDPMの理論を完全に導出した。さらに、DDPMは本質的に伝統的な拡散モデルではないことも指摘した。DDPMはどちらかというと変分自己符号化器(VAE)であり、実際にDDPMの論文でもVAEの考え方に従って導出を行っている。

今回はVAEの角度から改めてDDPMを紹介しよう。ついでに自身のKeras実装と実践的な経験を共有したいと思う。

多重突破

一般的なVAEでは、エンコードと生成は1回の計算で完結する。

エンコード:x\rightarrow z, 生成:z\rightarrow x\qquad(1)

ここで導入される分布は3つのみ。すなわちエンコーダー分布$p(z|x)$、生成分布$q(x|z)$および事前分布$q(z)$である。この手法の長所は形式的な単純さと、$x$と$z$の投影関係が固まっているので、エンコードモデルと生成モデルを同時に獲得でき、潜在変数の編集といったタスクを実現することができることである。一方で欠点も明らかで、確率分布をモデル化する能力に限界がある。上述の3つの分布はいずれも正規分布でモデル化するしかなく、それがモデル表現力の制約になり、結果的にぼやけた生成結果を出力してしまいがちになる。

この制約を打破するため、DDPMはエンコード過程と生成過程を$T$ステップに切り分けた。

\begin{align}
エンコード&:x=x_0\rightarrow x_1\rightarrow x_2\rightarrow\cdots\rightarrow x_{T-1}\rightarrow x_T=z\\
生成&:z=x_T\rightarrow z_{T-1}\rightarrow z_{T-2}\rightarrow\cdots\rightarrow x_1\rightarrow x_0=x
\end{align}\qquad(2)

こうすることで、$p(x_t|x_{t-1})$と$q(x_{t-1}|x_t)$はそれぞれ微小な変化をモデル化すればよくなる。ただしモデル化の対象は依然正規分布である。

同じ正規分布であるなら、なぜ複数ステップに切り分けるほうが良いのだろうか?十分微小な変化であれば、正規分布で十分正確に近似できるからである。曲線を局所的に直線で近似するようなものだ。複数ステップに切り分けることは、局所的な線形関数の集合で複雑な曲線を近似することに似たようなアプローチで、理論上は1回完結型の伝統的VAEの表現力の制約を突破することができるはずである。

同時分布の距離

よって今回の作戦は、再帰的に$(2)$のような分解を行うことで伝統的なVAEの能力を増強させることだ。エンコード過程の1ステップは$p(x_t|x_{t-1})$に、生成過程の1ステップは$q(x_{t-1}|x_t)$でモデル化される。同時分布は以下の通りになる。

\begin{align}
&p(x_0,x_1,x_2,\cdots,x_T)=p(x_T|x_{T-1})\cdots p(x_2|x_1)p(x_1|x_0)\tilde{p}(x_0)\\
&q(x_0,x_1,x_2,\cdots,x_T)=q(x_0|x_1)\cdots q(x_{T-2}|x_{T-1})q(x_{T-1}|x_T)q(x_T)
\end{align}\qquad(3)

$x_0$はデータサンプルで、$\tilde{p}(x_0)$はデータの分布である。$x_T$はエンコードされた情報で、$q(x_T)$は事前分布になる。残りの$p(x_t|x_{t-1})$や$q(x_{t-1}|x_t)$がエンコードと生成の1ステップを表している。

ここでは原作者ブログがVAEを説明する際に採用している記号慣習を踏襲し、エンコーダー分布を$p$、生成分布を$q$と記している。DDPM論文は$p$と$q$の意味が逆であることに注意。

「変分自己符号化器(二):ベイズの観点から」で筆者が指摘したように、VAEを理解する最も簡潔なアプローチは、同時分布のKL距離最小化問題として考えることである。DDPMも同様で、先ほど書いた同時分布のKL距離を最小化することがDDPMのゴールになる。

KL(p||q)=\int p(x_T|x_{T-1})\cdots p(x_1|x_0)\tilde{p}(x_0)\log\frac{p(x_T|x_{T-1})\cdots 
p(x_1|x_0)\tilde{p}(x_0)}{q(x_0|x_1)\cdots q(x_{T-1}|x_T)q(x_T)}dx_0dx_1\cdots dx_T\quad(4)

これがDDPMの最適化目標だ。ここまではDDPM論文の結論と一致しており(記号は異なるが)、より原初の論文である「Deep Unsupervised Learning using Nonequilibrium Thermodynamics」とも一致している。続いて、$p(x_t|x_{t-1})$と$q(x_{t-1}|x_t)$の具体的な形式を決め、最適化目標(4)の式を簡潔な形式に書き換えることを目指す。

分割統治

まず確認しておきたいのは、DDPMが作りたいのは生成モデル「のみ」である。エンコード過程に関しては、$p(x_t|x_{t-1})=\mathcal{N}(x_t;\alpha_tx_{t-1}, \beta_t^2I)$という極めて単純な正規分布として定義している。平均値ベクトルは$x_{t-1}$にスカラー$\alpha_t$を乗算しただけである。対して、伝統的なVAEでは平均と分散は学習されたニューラルネットから得られている。生成過程$q(x_{t-1}|x_t)$については、平均値ベクトルが学習可能な正規分布$\mathcal{N}(x_{t-1};\mu(x_t),\sigma_t^2I)$と定義されている。つまりDDPMは伝統VAEが有していたエンコード能力を捨てて、純粋に生成モデルだけを作ろうとしていることになる。上の記号のうち、$\alpha_t, \beta_t, \sigma_t$はいずれも事前に設定した値で、学習可能なパラメーターを持つのは$\mu(x_t)$のみである。

本記事における$\alpha_t,\beta_t$の定義はDDPM論文と異なる。

いまのところ分布$p$は学習可能なパラメーターを含まないので、学習目標$(4)$の$p$に関する積分は定数として無視できる。よって学習目標$(4)$は以下の式と等価である。

\begin{align}
&-\int p(x_T|x_{T-1})\cdots p(x_1|x_0)\tilde{p}(x_0)\log[{q(x_0|x_1)\cdots q(x_{T-1}|x_T)q(x_T)}]dx_0dx_1\cdots dx_T\\
=&-\int p(x_T|x_{T-1})\cdots p(x_1|x_0)\tilde{p}(x_0)[\log{q(x_T)}+\sum_{t=1}^{T}\log{q(x_{t-1}|x_t)}]dx_0dx_1\cdots dx_T
\end{align}\quad(5)

このうち事前分布$q(x_T)$も通常は標準正規分布、つまりパラメーターを持たないので、この項も定数として取り出せる。なので計算すべき部分は下の式になる。

\begin{align}
&-\int p(x_T|x_{T-1})\cdots p(x_1|x_0)\tilde{p}(x_0)\log{q(x_0|x_1)}dx_0dx_1\cdots dx_T\\
=&-\int p(x_t|x_{t-1})\cdots p(x_1|x_0)\tilde{p}(x_0)\log{q(x_0|x_1)}dx_0dx_1\cdots dx_t\\
=&-\int p(x_t|x_{t-1})p(x_{t-1}|x_0)\tilde{p}(x_0)\log{q(x_0|x_1)}dx_0dx_{t-1}dx_t
\end{align}\qquad(6)

一つ目の等号は、$q(x_{t-1}|x_t)$は$x_t$までしか依存しないことを考慮すると、$t+1$から$T$までの分布の積分は$1$になるからである。二つ目の等号は、$q(x_{t-1}|x_t)$は$x_1,\cdots, x_{t-2}$にも依存しないからである。結果的に、$p(x_{t-1}|x_0)=\mathcal{N}(x_{t-1};\overline\alpha_{t-1}x_0,\overline\beta_{t-1}^2I)$となるのだが、ここは次の章の式$(9)$を参照されたい。

建築作業再び

続きは前回の「建築の仕方」の章と基本的に同じである。

  1. 最適化と無関係な定数を除くと、$-\log q(x_{t-1}|x_t)$と関連する項は$\frac{1}{2\sigma_t^2}||x_{t-1}-\mu(x_t)||^2$である
  2. $p(x_{t-1}|x_0)$は$x_{t-1}=\overline\alpha_{t-1}x_0+\overline\beta_{t-1}\overline\varepsilon_{t-1}$、 $p(x_t|x_{t-1})$は$x_t=\alpha_tx_{t-1}+\beta_t\varepsilon_t$にあたる。$\overline\varepsilon_{t-1},\varepsilon_t\sim\mathcal{N}(0,I)$
  3. $x_{t-1}=\frac{1}{\alpha_t}(x_t-\beta_t\varepsilon_t)$からヒントを得て、$\mu(x_t)$を$\mu(x_t)=\frac{1}{\alpha_t}(x_t-\beta_t\epsilon_\theta(x_t,t))$に書き換える

以上の変換を経て、最適化目標は以下の形になる。

\frac{\beta_t^2}{\alpha_t^2\sigma_t^2}\mathbb{E}_{\overline\varepsilon_{t-1},\varepsilon_t\sim\mathcal{N}(0,I), x_0\sim\tilde{p}(x_0)}[||\varepsilon_t-\epsilon_\theta(\overline\alpha_tx_0+\alpha_t\varepsilon\beta_{t-1}\overline\varepsilon_{t-1}+\beta_t\varepsilon_t,t)||^2]\qquad(7)

これを「分散を抑える」の章に従って変数を合併すると、

\frac{\beta_t^4}{\overline\beta_t^2\alpha_t^2\sigma_t^2}\mathbb{E}_{\overline\varepsilon_{t-1},\varepsilon_t\sim\mathcal{N}(0,I), x_0\sim\tilde{p}(x_0)}[||\varepsilon-\frac{\overline\beta_t}{\beta_t}\epsilon_\theta(\overline\alpha_tx_0+\overline\beta_t\varepsilon,t)||^2]\qquad(8)

これでDDPMの学習目標ができた(元論文では、$(8)$の前の係数を消したほうが良い結果になると指摘している)。この式はVAEの最適化目標を出発点として、徐々に積分式を簡潔化して得られたものである。少し長いが、どのステップもちゃんと根拠があるので、計算は難しめだが考え方は難しくない。

DDPM論文では、脈絡もなく$q(x_{t-1}|x_t,x_0)$(元論文の記号)を導入して項を打ち消すことで、正規分布のKL距離の形式に変換するアプローチをとった。この方法はトリッキー過ぎて、筆者にとっては相当受け入れがたいやり方である。

ハイパラ設定

改めて$\alpha_t, \beta_t, \sigma_t$の選択について考えてみよう。

$p(x_t|x_{t-1})$に関しては、一般的に$\alpha_t^2+\beta_t^2=1$とすることで、ハイパラの数は半減され、形式的にも簡単になる。前回のブログで述べた通り、正規分布の再生性により、

p(x_t|x_0)=\int p(x_t|x_{t-1})\cdots p(x_1|x_0)dx_1\cdots dx_{t-1}=\mathcal{N}(x_t;\overline\alpha_tx_0,\overline\beta_t^2I)\qquad(9)

ここで$\overline\alpha_t=\alpha_1\cdots\alpha_t, \overline\beta_t=\sqrt{1-\overline\alpha_t^2}$。こうすることで$p(x_t|x_0)$は簡単な形式になる。

なぜ$\alpha_t^2+\beta_t^2=1$という制約を思いついたのか、と気になった読者もいるかもしれない。$\mathcal{N}(x_t;\alpha_tx_0,\beta_t^2I)$とはつまり、$x_t=\alpha_t x_{t-1}+\beta_t\varepsilon_t,\varepsilon_t\sim\mathcal{N}(0,I)$である。ここで仮に$x_{t-1}\sim\mathcal{N}(0,I)$である場合、$x_t$についても$x_t\sim\mathcal{N}(0,I)$となるべきだろう。それを満たすための条件が$\alpha_t^2+\beta_t^2=1$である。

先ほど述べた通り、$q(x_T)$は一般的に標準正規分布$\mathcal{N}(x_T;0,I)$に従うことにしている。学習目標は同時分布のKL距離の最小化、つまり$p=q$であるので、周辺分布も当然等しくなって欲しい。

つまりこうなってほしいのだ。

q(x_T)=\int p(x_T|x_{T-1})\cdots p(x_1|x_0)\tilde p(x_0)dx_0dx_1\cdots dx_{T-1}=\int p(x_T|x_0)\tilde p(x_0)dx_0\qquad(10)

$\tilde p(x_0)$は任意の分布なので、上の等号を成立させるためには、$p(x_T|x_0)=q(x_T)$、つまり$x_0$と関係ない標準正規分布に退化させるしかない。これはつまり、$\alpha_t$を適切な値に設定して、$\overline\alpha_T\approx 0$を満たすことを意味する。ついでに、$p(x_T|x_0)$が入力$x_0$に無関係であるという事実は、(VAEと違い)DDPMはエンコード能力を持たないことを改めて示している。前回の「解体-建設」の喩えでいうと、元の高層ビルが完全に建築資材に解体された後、この建築資材で再びビルを建設するときは、任意の姿に建設することができ、解体する前の姿になるとは限らないのだ。DDPMでは$\alpha_t=\sqrt{1-\frac{0.02t}{T}}$としたが、この選択については前回ブログの「ハイパラ設定」の章で既に議論した。

$\sigma_t$に関しては、理論上は分布$\tilde p(x_0)$が異なれば最適な$\sigma_t$も異なるが、$\sigma_t$を学習可能なパラメーターにしたくもない。なので、いくつかの特殊な$\tilde p(x_0)$から最適な$\sigma_t$を導出し、これらの特例から得られた$\sigma_t$は一般的な分布にも汎化できるだろう、と考えるしかない。簡単な例を2つ挙げてみよう。

  1. 仮に訓練セットは一つのサンプル$x_*$しか含まない、つまり$\tilde p(x_0)$はディラック分布$\delta(x_0-x_*)$に従う場合、最適値は$\sigma_t=\frac{\overline\beta_{t-1}}{\overline\beta_t}\beta_t$である
  2. 仮に$\tilde p(x_0)$は標準正規分布に従う場合、最適値は$\sigma_t=\beta_t$である

実験の結果、2つの値を選択しても結果に大差はない。なのでどちらかの値を選んでサンプリングを行うことができる。この2つの結論の導出は少し長いので、機会があればまた議論したい。

参考実装

こんな素晴らしいモデルはぜひともKerasで実装しておきたい。ここで筆者の実装例を共有しておく。

https://github.com/bojone/Keras-DDPM

筆者の実装はDDPM論文のソースコードに厳密に従ったわけではない。筆者自身で単純化したU-Net構造を設計し(例えば特徴連結を加算に変えたり、Attentionを無くした)、すぐに結果が出るようにした。24Gメモリの3090カードで、blocks=1, batch_size=64の設定で128*128サイズのCelebA HQデータセットで学習した結果、半日で大体の効果を見ることができる。3日学習したモデルのサンプリング結果は以下のようになった。

image.png

実装の過程で、筆者は以下の知見を得た。

  1. 損失関数はMSEではなくユークリッド距離を使うべき。MSEはユークリッド距離に$(幅\times高\timesチャンネル数)$を除算したもので、損失の値が小さくなりすぎて、一部のパラメーターの勾配が0に割り切られる可能性がある。これにより、学習は収束後に発散してしまう。これは低精度ニューラルネットの学習でも頻繁に起こる現象であり、次の記事を参考にされたい:「混合精度とXLAでbert4kerasの学習を高速化」
  2. 正規化はInstance Norm、Layer Norm、Group Normが使えるが、Batch Normは使うべきではない。Batch Normは学習と推論で不一致が起こる問題があり、学習性能は非常にいいのに推論結果が非常に悪い、という問題が起こる
  3. ニューラルネットの構造は論文を完全に再現する必要はない。論文はSOTAを達成するために論文を出しているので、大きくて遅いネットワークになっている。普通にU-Netの要領でオートエンコーダーを設計すれば、大体の効果を得ることができる。解いているのは純粋な回帰問題なので、学習自体は容易である
  4. パラメーター$t$の入力について、論文ではSinusoidal位置エンコーディングを使っているが、学習可能なEmbeddingに直接置き換えても大差はないことが分かった
  5. 言語モデル事前学習の慣習に則って、学習率の調整が便利なLAMB最適化アルゴリズムを利用したが、基本的にどの初期化方式でも$10^{-3}$の学習率で学習できる

総合評価

拡散生成モデル漫談(一):DDPM=解体+建設」と本記事を読むことで、DDPMに対して自分なりのイメージを持てたであろう。DDPMの長所・短所および改善の方向性をおおよそ理解できたはずだ。

DDPMには分かりやすい長所がある。訓練が容易で、生成した画像も高品質だ。訓練が容易というのはGANと比べた場合の話だ。GANはmin-max過程であり、学習過程には様々な不確定要素があり、崩壊が起こりやすい。一方でDDPMは単純に回帰的な損失関数を最小化するだけの手法であり、訓練は非常に安定している。また、「解体‐建設」の喩えを通して、DDPMは分かりやすさでもGANに劣らないことが分かるだろう。

DDPMの短所も分かりやすい。最大の短所はサンプリングが遅すぎることで、前向き計算をT回繰り返さなければならない(論文では$T=1000$も必要だった)。これは1回の計算で完結するGANよりもT倍遅いことになる。後続研究ではこの課題を改善することを目指したものが多い。GANのランダムノイズからのデータ生成は一意な変換であり、ノイズは生成結果の潜在変数と見なせるので、変数の内挿を行ったり、特定の次元を編集して生成結果を制御することができる。しかしDDPMの生成過程は完全に確率的な過程で、潜在変数と生成結果に明示的な関係が無いため、このような制御能力は無い。DDPMの論文でも内挿生成がデモンストレーションされているが、元の画像をノイズでぼやけさせてから新しい画像を「想像」させるだけのものなので、語義的(semantic)な融合を行うことは難しい。

上の欠点を改善する以外に、DDPMにはまだ研究できる課題が残っている。例えば、これまで示したDDPMは条件なしの生成だが、そうなると条件付きDDPMはどうするのか、自然に思い付くだろう。VAEからC-VAEへ、GANからC-GANへ進んだのと同じ流れだ。これは目下のDDPMの主流的な応用でもある。例えばGoogleのImagenは拡散モデルによるtext-to-image手法と超解像度手法を組み合わせたものだが、これは本質的には条件付きの拡散モデルになる。

あるいは、これまでのDDPMは連続的な変数をモデリングしてきたが、仕組み的には離散的な変数にも適用できるはずだ。では離散的変数に対するDDPMはどう設計すべきなのか?

関連研究

DDPMの関連研究と言うと、多くの人は伝統的な拡散モデルやエネルギーベースモデル、またはDenoising Autoencoderを思い浮かべるかもしれない。しかし筆者がここで触れたいのはこれらではなく、このブログでも以前紹介した、DDPMの一般化形態ともいえる「強力なNVAE:VAEの画像がぼやけるのはもう過去の話」である。

VAEの視点から見て、伝統的なVAEは出力画像がぼやける傾向がある一方、DDPMは(筆者が知る限り)高品質な画像を出力できる「2つ目の」VAEと言える。最初のVAEがこのNVAEである。NVAEの形式を振り返ると、DDPMと類似した部分が非常に多いことが分かる。たとえば、NVAEでも大量の潜在変数$z=\{z_1, z_2,\cdots,z_L\}$を導入していて、これらの変数は再帰的な関係を持っている。ゆえにNVAEのサンプリング過程もDDPMとかなり近い。

理論的な視点から見ると、DDPMはかなり簡潔化されたNVAEと見なすことができる。つまり、DDPMの潜在変数の再帰的関係はただのマルコフ的な条件正規分布であり、NVAEのような非マルコフ的な形式ではない。生成モデル自身も同じニューラルネットを使った反復計算に過ぎず、$z=\{z_1, z_2,\cdots,z_L\}$を同時入力する巨大モデルを使用したNVAEとは違う。ただしNVAEが$z=\{z_1, z_2,\cdots,z_L\}$を入力する際もパラメーター共有機構を導入しており、同じニューラルネットで反復計算をするDDPMと近い思想も含んでいる。

まとめ

本記事は変分自己符号化器VAEの視点からDDPMを導出した。この視点から見ると、DDPMは簡潔化した自己回帰型VAEであり、以前紹介したNVAEとかなり類似している。本記事は更に筆者のDDPM実装と実践的な経験を共有し、DDPMに対して総合的な評価を行った。

1
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
0