LoginSignup
22
16

More than 1 year has passed since last update.

Diffusion modelは何故1個しかないのか?

Posted at

Diffusion model(拡散モデル)に関して何故、拡散モデルが1個なのか論文や解説を読んでも疑問は解決されなかった。
実装を調べた上での理解を推測を含めつつ自分の言葉でまとめてみる。

拡散モデルはノイズ⇒画像である

image.png
まず、拡散モデルの順伝播がノイズから画像を作るのか画像からノイズを作るのかが分かりづらい。そもそも、この図でモデルは単なる矢印で書かれるだけでどのようなモデルなのか詳細がない。
拡散モデルはノイズ⇒画像を変換するモデルである。

Samplingとは?

拡散モデルの学習を終えた状態でノイズから画像を生成するのをSamplingという。
逆に画像からノイズを作る過程をForward noise processという。
便宜上、ガウス分布を円で示せば、ノイズプロセスが進むと徐々にガウス分布の半径(偏差)は大きくなる。
その逆にSampling(denoise)を行えばノイズはどんどん小さくなる。
image.png

基本は小さいノイズを何度も与える

例えばある論文では$T=1000$で$\beta_1=10^{-4},\beta_T=0.02$である。
ここで$N(\mu_1,\sigma_1^2)$と$N(\mu_2,\sigma_2^2)$の合計は$N(\mu_1+\mu_2,\sigma_1^2+\sigma_2^2)$である。例えば偏差1の分布2個の合計の偏差は$2$ではなく$\sqrt{2}$になる。つまり、Forward noise processは偏差1のノイズを与えて初期値とノイズの中間を与える、というよりは小さなノイズ(特に最初は小さい)を沢山与えて徐々に偏差を大きくしていくという方がイメージに近いと思われる。これをDDPMという。

一方の偏差1のノイズεを1回与えて初期値とこのノイズの中間を与えるという考え方はおそらくDDIMと呼ばれる。こうすることによって一段ずつ復元するよりも10倍くらい短いstep数で復元できるらしい。

x_t=\sqrt{\alpha}x_0+\sqrt{1-\alpha}\epsilon

Forward noise processで多段のデータを作る

ノイズプロセスによってノイズを加えた画像を作り、一方の正解出力を1個ずらして、加えたノイズを減らす方向の拡散モデルを作成する。徐々にノイズを減らす拡散モデルが多段あり、この拡散モデルをすべて繋げればノイズから画像を生成できると思われる。
lossはT個の損失関数(MSEかKL)の足し合わせで表現できるが、実際にはもっとシンプルに出来るらしい。しかし、その理論はよく分からない。
image.png

拡散モデルは1個しかない

こういう拡散モデルの論文を読んで拡散ステップが$T=1000$だから拡散モデルも$1000$個分あるのかなと最初自分は考えた。何故なら、多段のモデルを考えた場合、拡散モデルの重みはタイムステップによって異なる筈だからである。
しかし、実際の拡散モデルを覗いてみると拡散モデルは1個しかない。実はこの1個の拡散モデルがタイムステップの入力によって各拡散モデルに変化するのである。拡散モデルは大抵Unetなのだが、この中のResBlockが入力$x$以外にembも入力とする。そしてこのembはタイムステップからembeddingした値と、ラベルからembeddingしたものの和である。
このembも入力とすることで1000段のUnetではなく1個のUnetで拡散モデルを表現できる。

class UNetModel(nn.Module):
...
    def forward(self, x, timesteps, y=None):
...
        emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
        emb = emb + self.label_emb(y)
class ResBlock(TimestepBlock):
...
    def forward(self, x, emb):
...

image.png

Classifier Guidance

分類モデル(Classifier)の勾配を使ってSamplingを行う。
image.png

    def cond_fn(x, t, y=None):
...
        x_in = x.detach().requires_grad_(True)
        logits = classifier(x_in, t)
        log_probs = F.log_softmax(logits, dim=-1)
        selected = log_probs[range(len(logits)), y.view(-1)]
        return th.autograd.grad(selected.sum(), x_in)[0] * args.classifier_scale

最初、自分はGANにおけるGeneratorとDiscriminatorの関係に近いのかと勘違いした
つまり、GANにてGeneratorを学習するのにDiscriminatorの勾配を使うように、diffusion modelを学習するのにClassifierの勾配を利用するのかと考えた。ところが分類モデル(Classifier)の勾配を使うのはあくまでノイズから画像を生成するSampling時であって拡散モデルの学習時に分類モデルは使用されない。

また、分類モデルは拡散モデルと全くの別モデルであり、構造や重みを共有したりはしない。分類モデルのパラメータ数は拡散モデルのUNetの1/10程度である。ならば何故この拡散モデルと全く関係ない分類モデルの勾配がSampling時に使えるのかという謎がある。また、分類モデルの学習には拡散モデルは不要であるが、画像にノイズを加えるノイズプロセスが必要である。また、入力にタイムステップも必要である。そういう意味だとpretrain済みの単純な分類モデルの重みではない。

最終的な自分の理解としては入力xがノイズの時に特定ラベルの勾配を使って出るノイズのような画像を足すという理解である。例えば以下の例。また、敵対的サンプルでパンダをテナガザルに変えるノイズにも近い(これは誤った正解ラベルを与えた場合の勾配だが)。とはいえ実際、どういう原理で精度が向上するのか不明である。ほかに勾配を使う例としてはGrad-CAMとかも使うがこれは勾配の層がGlobalPooling層の手前であり、今回のような入力層ではない。

image.png
image.png

latent diffusion model

今までの拡散モデルはピクセル空間においてノイズプロセスを掛けていたが、latent diffusion modelは潜在空間においてUNetを学習させる。最近、DALL-E2やImagenのText-to-Imageの分野で進歩が目覚ましいが、それらにこの潜在空間での拡散モデルが使われている。

image.png

実際の出力結果比較

拡散モデルタイプのADMの結果とGANタイプStyleGAN_XLの出力を比較してみた。
両方とも256x256のモデルで、ADMはUpsamplingモデルではない。
StyleGAN_XLだとclass=000の場合、人の顔が上手く作成されないが、ADMならきれいに作られる。一方、class=069の三葉虫だとあまり差が分からない。class=292の虎ではどちらのモデルも一長一短に見える。ADMだと虎の舌に違和感がある。
ADMのデメリットは画像を生成するのに250stepもかかる事である。この為、画像生成(Sampling)が非常に遅い。
またImageNetは画像の多様性が高いために学習は比較的困難で、他に学習成功した例はBigGAN、VQ-VAE2などがある。またADMの結果を上回る拡散モデルだと一応、Cascaded Diffusion Models(CDM)がある。

ADM class=000(tench:テンチという名の魚)

ADM_class0000.png

StyleGAN_XL class=000(tench:テンチという名の魚)

stylegan_xl_class0000.png

ADM class=069(trilobite:三葉虫)

ADM_class0069.png

StyleGAN_XL class=069(trilobite:三葉虫)

stylegan_xl_class0069.png

ADM class=282(tiger cat:虎)

ADA_class292.png

StyleGAN_XL class=282(tiger cat:虎)

stylegan_xl_class0292.png

ADMにてclass=69の画像を保存する変更

classifier_sample.py
    j = 0
    while len(all_images) * args.batch_size < args.num_samples:
...
        model_kwargs["y"] = th.ones(size=(args.batch_size,), device=dist_util.dev(), dtype = th.int64) * 69
...
        sample = sample.contiguous()

        for i in range(args.batch_size):
            PIL.Image.fromarray(sample.to('cpu').detach().numpy().copy()[i], 'RGB').save(f'sample_{(j+i):06d}.png')
        j += args.batch_size

まとめ

何で拡散モデルが多段ではなく1個しかないのかという事について書いた。
分類モデルの勾配を使うと何故Samplingが向上するのかについては理解が及んでいない。

22
16
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
22
16