LoginSignup
11
9

最新の高性能 Diffusion Models (2024年)

Last updated at Posted at 2024-02-11

最近、Transformerベースのdiffusion modelが高いパフォーマンス(ImageNetのFID)を出している。ということで、特に性能の高い最新モデルを2つ紹介する。加えて、これらを調査していたら、それらの性能をさらに底上げする手法とCNNベースでさらに高い性能を出してSOTAを達成したぞという論文にもさらに行き着いたので、それら2本も併せて追加で紹介する。

(追記)2024/2/23に発表されたStable Diffusion 3や2/15に発表されたOpenAIのSoraでは、今回紹介するDiTがDiffusion Transformer(拡散トランスフォーマー)のベース技術に採用されている。

※以降の図は論文からそのまま引用、もしくはそれに多少の加工を加えたものとなる。
※以降のpaperswithcodeの順位は2024/1時点

目次

  1. DiT (ICCV'23): Diffusion Model + Transformer
  2. DiffiT: Diffusion Model + Transformer (+ U-Net)
  3. CADS (ICLR'24): 多様性を向上させるsampling
  4. EDM2: CNNベースの高性能 Diffusion Model
  5. 感想

DiT (ICCV'23): Diffusion Model + Transformer

ImageNet 256 - 10位 (paperswithcode)

概要

Diffusion modelの基幹であるCNNベースのU-Netを、Transformerベースのアーキテクチャに置き換えることで性能向上を実現している。
image.png

モデル解説

従来のU-Net構造を捨てて、ViTベースの生成モデルアーキテクチャDiTが提案されている。DiTは、ViTから以下に説明する入出力とself-attention blockの2点を変更した生成モデルとなる。
※全体アーキテクチャのベースはLDM (Stable Diffusion)
※ロスは通常のDDPM

(ViTからの変更点1) 入出力

image.png
ViTとはタスクが異なるため、上図(左)のように入出力を以下に変更する。

  • ViT
    • [入力] 画像
    • [出力] クラス分類結果
  • DiT
    • [入力] 画像+時刻t+クラスラベルy
      • LDMと同様、画像はVAEのencoderによりコンパクト化された状態で扱う
        • そのため、画像はNoised Latent / Noiseと表記される
      • textを扱いたい場合はtext embeddingを入力に追加する必要がある
    • [出力] 画像(+標準偏差Σ)
      • 画像を出力させるために、最終段でunpatchify処理(patchifyの逆)を行っている
      • 標準偏差Σはオプション(Diffusion modelの学習方式次第)
(ViTからの変更点2) DiT Block

image.png
ViTには存在しない入力である条件情報(=Conditioning=時刻t+クラスラベルy)を扱えるように変更する。具体的には、ViTのself-attention blockに対して、上図のようにAdaLN機構(γ/βによるScale/Shift)を追加する。このAdaLN機構により、Scale γ, Shift βを通じた条件情報の注入が達成される。

このように、基本的には元のself-attentionにAdaLNを追加するだけのシンプルな変更であるが、性能向上のために追加で以下の2つの工夫を加えている。

  • AdaLNのγ・βに加え、Scale αを新規に追加: 条件情報を参照するパラメータを増やすことで、条件情報をより強く画像に反映させることができる
    • Scale αの実装: 上図にあるように残差結合の直前にα倍のスケーリング処理を追加する
  • AdaLNのγ・βを算出するMLPの重み・バイアスを0初期化: 学習初期は条件情報を全く注入しないことになり、学習が安定化する
    • まともな画像を出力できていない段階で、条件情報でその画像をさらに大きく変化させてしまうと、学習が安定しないため

※Tensorの並び順等の都合で実装上はAdaLNであるが、動作としてはAdaINをイメージすれば良い。
※論文上ではDiT blockの候補として他に3案を挙げているが、ここでは最終的に採用された方式のみを説明した。

Experiment

classifier-free guidance付きの高GFlops(patch size=2、XLargeサイズ)なDiTモデルを用意して、SOTA級のモデルら(GANで現状最強のStyleGAN-XL や classifier-free guidance付きLDM)と比較した。結果、以下の通りFID=2.27でDiTが首位を獲得した。
image.png
※Precisionは画像の品質、Recallは画像のcover率(≒多様性)

DiffiT: Diffusion Model + Transformer (+ U-Net)

ImageNet 256 - 2位 (paperswithcode)
ImageNet 512 - 7位 (paperswithcode)

概要

上述のDiTと同じで、Diffusion modelの基幹であるCNNベースのU-Netを、Transformerベースのアーキテクチャに置き換えることで性能向上を実現している。この際、U-Net構造を継承する効率的なバージョンも用意しており、U-Netの高い効率性を得ることができる。
image.png

モデル解説

DiTと同様にViTベースの構成をしており、主な特徴(≒DiTとの差分)は独自のDiffiT blockとオプションのU-Net構造にある。(上図はU-Net構造なし版)

入出力

ViTに対して、DiTと同様の変更が行われている。
つまり、入力は画像(noised)に加えて時刻・クラス情報が追加され、出力は画像(noised)に変更されている。

DiffiT Block

DiTと同様に、ViTからself-attention blockが変更されている。DiTではself-attention処理にAdaLNを適用することで条件情報の注入を実現したが、DiffiTでは下図のようにself-attentionの入力であるQ, K, Vに条件情報トークンを加算することで注入する。また図には載っていないが、Relative positional encodingを採用することで、時刻情報を追加で与えている。
image.png
image.png

(オプション) U-Net構造

規模の小さい(解像度が低い、複雑性が低い)データセットでは、下図のようなU-Net構造による軽量化・効率化が提案されている。本構造では、DownsamplingやWindow attentionの導入によりself-attentionの計算量を削減することで、大幅な軽量化を実現している。
※ImageNetの規模だと大量のパラメータが必要なためか、U-Net構造は採用されていない
image.png

Experiments

以下の通り、DiT越えのFID=1.73を達成した。(ImageNet-256)
image.png

以下の表の上2つはFFHQ-64データセット、下2つはImageNet-256データセット(Latent空間)の結果である。DiffiTは、FFHQでは高効率なU-Net有りモデルを採用し、ImageNetでは高性能なU-Net無しモデルを採用している。それぞれ同規模の従来モデルと比較しているが、DiffiTの方が高い性能(FID)を出せている。
image.png

CADS (ICLR'24): 多様性を向上させるsampling

ImageNet 256 - 1位 (paperswithcode) ‥ DiTベース

概要

Diffusion modelのアーキテクチャ変更はなしで、samplingの工夫のみで多様性の向上を実現している。下図の左が通常sampling、右がCADS samplingの例であるが、CADSの方が高い多様性を表現できているのが見て取れる。
CADSはsamplingの工夫であるため、任意のDiffusion modelへ容易に適用可能である。本論文では、上述したDiTにCADSを適用することでImageNet 256のSOTAを叩き出している。
image.png
image.png

CADS Sampling手法

やり方は極めてシンプルで、条件情報yを徐々に弱まるノイズnで補正するだけである。(ここでの条件はクラスやテキストを指す)
image.png
sはノイズスケール、γ(t)∈[0,1] はノイズ割合調整用係数(時刻と共に徐々に小さくなる)である。

ノイズ付与によりyの平均・標準偏差が変質してしまうのを防ぐために、以下の追加のrescale手法も提案されている。ただし、これを適用すると品質は上がるものの多様性向上効果が下がるので、resclaeをどの程度行うかはハイパーパラメータψで調整可能にしている。
image.png

ちなみに、このようなノイズ付与機構なんかを追加するんでなく、単にCFG重み値(=条件の強さ)を徐々に増やせば同等のことをできるのではないか(=Dynamic CFG)、と考える人もいるかもしれない。だが、著者らもそれを考えて実際に比較実験をした結果、下表のようにCADSが性能・多様性の両面で圧倒した。画像デノイズ(=Reverse process)の初期段階で、ノイズで条件をかき乱して画像の方向性を固定化させないことが多様性に繋がるっぽい。
image.png
※Recallは画像のcover率(≒多様性)

Experiment

結果は以下の表で、DiTにCADSを適用することで、DiffiT(FID=1.73)を上回るFID=1.70を達成した。(ImageNet-256)
image.png

EDM2: CNNベースの高性能 Diffusion Model

ImageNet 512 - 1位 (paperswithcode)

概要

CNNベースのDiffusion modelを、最新の知見を生かしてチューニングしていくことで高性能・高効率なモデルを実現している。ImageNet 512では、DiT・DiffiT(下図に載ってないが)を上回る性能・効率を達成した。
image.png

モデル解説(チューニング)

下表のA→...→Gのように、モデルや学習方式を少しずつ改良していくことで高性能・高効率なモデルを作り上げた。
image.png
特に、全体の値のスケールを上手く調整するような構造(Magnitude Preserving)を取ることで、微妙に邪魔だった正規化層を大きく排することに成功しており、これによりかなり簡素なレイヤー構造になっているのが特徴的である。
image.png
ちなみにこの論文では、少数パターンのEMAパラメータのモデルから、任意のEMAパラメータのモデルを高精度に近似する手法も提案しており、この分析手法により最適なEMAパラメータを効率的に求めている。

Experiment

以下の表の通り、ImageNet-512でDiffiT(FID=2.63)やCADS(FID=2.31)を上回るFID=1.91を達成した。(ImageNet-256の評価は無い)
image.png

感想

Diffusion modelでもCNNからTransformerにしてくのが良いのかと思ったら、EDM2でCNNベースが猛追してきて、現状どっちが優れてるのかまだまだ判断できないという感じだった。
(ViTに対するConvNeXtと同じような構図)

11
9
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
11
9