こんにちは!今回は画像生成に使われる「拡散モデル」について学んでいきます。
いろいろな拡散モデルが世に出ていますが、まずは基礎となる「DDPM(Denoising Diffusion Probabilistic Model)ノイズ除去拡散確率モデル」について記載します。
※ 数式はなく、直感的に理解できることを目的としています
概念図
大きく以下2つのフェーズに分かれています。
-
学習フェーズ(=訓練)
画像に徐々にノイズを付与していきます。
どの時刻でどれぐらいのノイズが混じっているかを学習しています。 -
生成フェーズ(=推論)
学習フェーズを逆方向に実施することで完全なノイズから徐々にノイズを取り除き、元の鮮明な画像を生成していきます。
どれぐらいノイズを取り除けばいいかは、事前に学習しているのでそれを使って推定します。
では、この2つのフェーズの中で実際に何が行われているかを深堀していきます!
画像については私の好きなほんやら堂のキャラクター「なまけたろう」を拝借しています(__)
学習フェーズ(=訓練)
学習フェーズがやりたいことは、今の画像$x_t$(任意の時刻$t$)にどんなノイズが含まれているかを予測することです。この当たりはずれの誤差が小さくなるように重みを更新していきます。このときのノイズは正規分布からサンプリングしています。
学習に使われているネットワークとは、U-Netが主に使われています。
U-Net
U-Netの構造をお示しします。
例として、入力画像(572×572×1)⇒出力画像(388×388×2)としています。
U-Netで行っていることは ① 縮めて要約(Encoder)、② 広げて復元(Decoder) です。
さらに、特徴的なポイントとして、③ スキップ接続があります。
① 縮めて要約(Encoder)
行っていることはダウンサンプリングです。
ダウンサンプリングとは、畳み込み処理で空間解像度(H,W)を小さくしながらチャネル数を増やしていくことで表現力を高める処理のことです。
形状の変化のポイントは以下3点です
- conv は毎回 H,W を2ずつ削る(stride=1, padding=0 のため)
- Max Pool で 1/2
- 段を下るごとにチャネル数は倍増(64→128→256→512→1024)
②広げて復元(Decoder)
空間解像度を段ごとに2倍することで、元の解像度に戻していきます。これをアップサンプリングと言います。
2倍にするやり方としては、転置畳み込みなどを使用します。
形状の変化のポイントは以下4点です
- Up の直後は「2倍」に拡大される
- 続く conv でまた 2 ずつ削れる
- skip の後はチャネルが一時的に増える → conv で整形
- 最終的に出力が入力より小さいのは、全段で padding=0 の convを使っているから
③スキップ接続
ダウンサンプリング時の隠れ層の値は、対応するアップサンプル時の隠れ層の値に結合(concat)されています。
これにより、ダウンサンプル時の情報をアップサンプル時に利用することができます。
特にダウンサンプリングでは、細部の情報を失い低解像度の抽象的な画像になっていくので、アップサンプリングの同じ段に橋渡しすることで復元を狙います。
なお、スキップ接続はconcat(チャネル方向に結合)なので、H,Wを一致させるためにEncoder側の特徴をcrop(切り取り)します。チャネル数はconcatで連結された後、次の畳み込みで整えることになります。
拡散モデルで使われているU-Net
拡散モデルで使われる通常のU-Netと以下の点で異なります。
① 時刻埋め込み
どの時刻$t$のノイズ段階かを各層に伝える必要があるので、正弦時刻埋め込み(sin/cos)→MLPを通してベクトルを作り、各残差ブロックへ付加します。
② Wide ResNet系の残差ブロック
残差ブロックの幅(チャネル数)を増やして、深さを抑える設計です。
これにより学習が速くなります。
③ Attention
設計にもよりますが、ダウン/アップサンプリング時にはLinear Attention(または省略)を、谷底の最小解像度の際にDot-Product Attention(通常のAttention)を使うことが多いです。
通常のAttentionでは重み行列がN×Nとなって$N^2d$の行列ができあがりますが、Linear Attentionでは行列の近似・再配置によって二乗を回避することで$Nd$程度の行列になるので高解像度(大きなN)の画像でもメモリ不足になりません。そのため、ダウン/アップサンプリング時に使われます。
④ Group Normalization
BatchNormの弱点を克服するものとして、GroupNormが使われています。
BatchNormはバッチ次元で平均・分散をとるため、小バッチの場合バッチ内の統計が不安定になってしまいます。
GroupNormはサンプル内のチャネルをG個のグループに分けて正規化するので、少バッチでも安定して学習することができます(=バッチサイズ非依存)。
⑤ 畳み込み時のpadding=1
オリジナルのU-Netは入力画像の形状が 572×572×1 で出力画像の形状が 388×388×2 となっています。
もともとU-Netはセグメンテーションタスクに使われ、入力画像にグレースケール画像(医用、衛星、深度など)が多く用いられました(もちろん3チャンネルのRGB画像を入力とすることもできます)。
そして、セグメンテーションのクラス数が出力のチャンネル数となっています(最後の 1×1convはチャネル数をクラス数へ写像する層)。
一方、拡散モデルに使われるU-Netは3×3の畳み込みの際に、padding=1 を入れるためサイズ(H,W)が変わりません。最終的な出力は入力と同じサイズになっています。
そのため、各段でのskip結合では、cropは不要となります。これで入力と同じ解像度のノイズ推定ができるようになります。
学習フェーズまとめ
概念図では、少しずつノイズを足していくように記載しましたが、実装では毎イテレーションで時刻$t$をランダムに一つ選び、その時刻の$x_t$を一発で作ります。「元画像$x_0$と標準正規ノイズ$ε$を決まった係数で合成」すれば、段階を踏んだものと同じ分布の$x_t$が得られます。これは、ノイズ付加はパラメータ(ノイズ量スケジュール)が固定で、学習対象ではないからです。
つまり、任意の時刻$t$の画像$x_t$は、決まった係数*$x_0$+決まった係数×εにより一発で作ることができ、U-Netに時刻$t$と$x_t$を入れることで、ノイズの予測値$\hat{\varepsilon}$を得ることができます。そして、真の$ε$と$\hat{\varepsilon}$の平均二乗誤差が小さくなるようにU-Netの重みを更新することが学習で行われいるということです。
生成フェーズ(=推論)
純ノイズ$x_T$から始め、各時刻で先ほど学習した U-Net を使ってノイズを推定し$x_{T-1}$(一つ前の時刻の画像)を計算します。これを繰り返して元の画像である$x_0$を得ます。
ちなみに、今回の「DDPM(Denoising Diffusion Probabilistic Model)ノイズ除去拡散確率モデル」では、各段で少しランダム性を足すことになります(ノイズ再注入)。
対して、「DDIM(Denoising Diffusion Implicit Model)ノイズ除去拡散暗黙モデル」では、ランダムな値を入れないで決定的な更新をします。
なお、Stable Diffusion系は VAEで画像を潜在空間に圧縮してから同じ手続きを回し、最後にVAEで画像に戻します。これは、計算を軽くするための工夫で、U-Netの役割(各時刻で引くべき成分を推定する)は同じです。
生成フェーズのまとめ
最初は純ノイズ $x_T$ から始めて、時刻を $T \rightarrow 0$ とひとつずつ下げながら、各ステップで次の処理を行います。
- U-Net に $(x_t, t,[, \text{条件}])$ を入力し、「いま混ざっているノイズ」の推定値 $\hat{\varepsilon}_t$ を得る。
- そのノイズ分を 引いて、少しだけきれいな $x_{t-1}$ を計算する。
(DDPM はこのとき少しランダムを再注入する/DDIM は入れない=決定的)
この 「当てる → 引く」 を繰り返して、最終的に $x_0$(ノイズのない画像)に到達します。
以上です。読んでいただきありがとうございました。