導入
Alphafold3はタンパク質のアミノ酸配列、核酸のDNA/RNA配列、低分子化合物の構造式を入力にして、タンパク質単体、タンパク質複合体、タンパク質核酸複合体、タンパク質リガンド複合体の構造を予測するディープニューラルネットワークのモデルである。
Alphafold2は標準アミノ酸のみからなるタンパク質のアミノ酸配列のみを予測対象としていたが、Alphafold3は修飾基のある非標準のアミノ酸やDNA/RNAポリマー、低分子化合物を含む複合体にも対象が拡大されており、精度も向上しているとされている。
ここでいう予測対象の「構造」とは、全重原子(水素以外の原子)の三次元空間の座標である。AlphaFold2ではデコーダ部分は入力配列のアミノ酸の主鎖原子の座標と側鎖の二面角を直接予測していたが、この方法ではあらかじめ主鎖と側鎖の構造式が固定されている必要があり、標準アミノ酸以外への拡張は難しい。AlphaFold3では、Diffusionモデルを採用することで、正規分布の乱数で生成した全重原子の三次元座標からDenoisingにより全重原子の三次元座標を生成する。
AlphaFold3で採用されているDiffusionモデルは画像生成でよく用いられるDDPMやDDIMではなく、EDMと呼ばれるアルゴリズムである。
EDM(Elucidating the Design Space of Diffusion-Based Generative Models)は、確率微分方程式として定式化された拡散過程を基にしたDiffusionモデルであり、スコア関数(確率密度関数の対数の勾配)をDeep learningのモデルで表現し、データから学習して予測できるようにすることで、デノイズを実現する。EDMは時刻$t$でなく、Diffusion過程で加えられたとするノイズの大きさ$\sigma$を基に、スコア関数を予測する特徴がある。
この記事ではEDMのアルゴリズムについて、論文のAppendixにある理論的背景・導出含めて丁寧に記載する。ただし、確率論・確率微分方程式関連の厳密な議論は避ける。
EDMのDiffusionモデルの定式化、Denoiser定義
確率微分方程式と記号の準備
生成対象のデータの分布を$p_{\mathrm{data}}(\boldsymbol{x})$とする。生成対象のデータの空間は$d$次元であるとする。生成モデルは通常多次元であり、RGBのカラー画像で解像度が$h, w$であれば$d=3hw$であり、タンパク質の3次元構造であれば重原子数が$K$個なら$d=3K$である。
確率微分方程式を
d \boldsymbol{X}_t = f(t) \boldsymbol{X}_t dt + g(t) d \boldsymbol{W}_t
とする。$\boldsymbol{W}_t$は各次元が独立な$d$次元のWiener過程。時刻$t=0$の$\boldsymbol{X}_0$がノイズの加わる前の確率変数で生成対象のデータ分布に従う。時刻$t\ge0$でのデータ分布からの$\boldsymbol{X}_t$の時間発展がこの確率微分方程式で表現される。
確率微分方程式を単純に離散化して整理すると、
\boldsymbol{X}_{t_{i+1}} = (1 + f(t_{i}) (t_{i+1} - t_{i}) )\boldsymbol{X}_{t_{i}} + g(t_{i}) (\boldsymbol{W}_{t_{i+1}} - \boldsymbol{W}_{t_{i}})
のようになる。
(1 + f(t_{i}) (t_{i+1} - t_{i}) ) \boldsymbol{X}_{t_{i}}
は$\boldsymbol{X}$が全体的に$\boldsymbol{0}$に近づくことを表す。($f(t)$は通常負の値を想定する)
g(t) (\boldsymbol{W}_{t_{i+1}} - \boldsymbol{W}_{t_{i}})
は、時刻$t_{j}$までとは独立な正規分布$N(\boldsymbol{0}, g(t) (t_{i+1} - t_{i})I_d)$に従うので、$t_{i+1} - t_{i}$の期間に加えられるノイズを表す。
$p_{\mathrm{data}}(\boldsymbol{x})$を時刻$t=0$における初期分布として、確率微分方程式$d \boldsymbol{X}_t = f(t) \boldsymbol{X}_t dt + g(t) d \boldsymbol{W}_t$に従って時間発展すると、時刻$t$における確率密度関数を$p_t(\boldsymbol{x})$は、
p_t(\boldsymbol{x}) = \frac{1}{s(t)^d}\left[ p_{\mathrm{data}} * N(\boldsymbol{0}, \sigma(t)^2 \mathbf{I}_d)\right] \left( \frac{\boldsymbol{x}}{s(t)} \right)
となる。
ただし、$\mathcal{N}_D(\cdot | \boldsymbol{\mu}, \mathbf{\Sigma})$は平均$\boldsymbol{\mu}$、分散$\mathbf{\Sigma}$の正規分布の確率密度関数を表しており、$s(t), \sigma(t)$は、
\begin{eqnarray}
s(t) &=& e^{\int_0^{t} f(\xi) d\xi} \\
\sigma(t) &=& \sqrt{\int_0^t \frac{g(\tau)^2}{s(\tau)^2} d\tau}
\end{eqnarray}
で定義される。分布間の二項演算$[\cdot * \cdot]$は畳み込みを表し、2つの確率変数の和の分布の確率密度関数となる。
p(\boldsymbol{x};\sigma) = \left[ p_{\mathrm{data}} * N(\boldsymbol{0}, \sigma^2 \mathbf{I}_d)\right] (\boldsymbol{x})
とする。これは、$p_{\mathrm{data}}$に従って分布する$d$次元の確率変数の各次元に対し、平均$0$分散$\sigma^2$の独立な正規分布に従うノイズを加えた確率変数の分布になる。これを用いると、
p_t(\boldsymbol{x}) = \frac{1}{s(t)^d} p\left( \frac{\boldsymbol{x}}{s(t)} ;\sigma(t)\right)
と表せる。
なお、$f(t), g(t)$を$s(t), \sigma(t)$で表すと、
\begin{eqnarray}
f(t) &=& \frac{\dot{s}(t)}{s(t)} \\
g(t) &=& s(t) \sqrt{2 \dot{\sigma}(t) \sigma(t)}
\end{eqnarray}
である。
$\sigma_{\mathrm{data}}^2$を$\boldsymbol{X}_0$の各次元の分散の平均とする。この時、時刻$t$における各次元の分散の平均は
s(t)^2 \left(\sigma_{\mathrm{data}}^2 + \sigma (t)^2\right)
となる。
Probabilistic Flow ODE
確率微分方程式
d \boldsymbol{X}_t = f(t) \boldsymbol{X}_t dt + g(t) d \boldsymbol{W}_t
と各時刻の確率密度関数が一致するような、ノイズ項のない微分方程式Probabilistic Flow ODEを構成する。
各時刻の確率分布を維持した多次元の確率微分方程式の変形の式変形で
\begin{eqnarray}
\boldsymbol{f}(\boldsymbol{x}, t) &=& f(t) \boldsymbol{x} \\
\mathbf{G}(\boldsymbol{x}, t) &=& g(t)
\end{eqnarray}
とすれば、Probabilistic Flow ODEは、
d \boldsymbol{X}_t = \left( f(t) \boldsymbol{X}_t - \frac{1}{2}g(t)^2 \nabla_{\boldsymbol{x}} \log p_t(\boldsymbol{X}_t) \right)dt
である。$(\boldsymbol{x}, t)$に対して時刻$t$の確率密度関数の自然対数の勾配を返す関数$\nabla_{\boldsymbol{x}}\log p_t(\boldsymbol{x})$が既知であれば、このProbabilistic Flow ODEを用いることで、時刻$T$におけるノイズのみの分布からサンプリングした点を起点に、時間を遡る方向に微分方程式を離散化してシミュレーションすることで、時刻$0$のデータ分布$p_{\mathrm{data}}$からのサンプリングを実現できる。
ただ、データ分布$p_{\mathrm{data}}$に依存する$\nabla_{\boldsymbol{x}}\log p_t(\boldsymbol{x})$は一般に既知ではない。Probabilistic Flow ODEを使ったFlow matching系のDiffusionモデルは$\nabla_{\boldsymbol{x}}\log p_t(\boldsymbol{x})$に相当する$(\boldsymbol{x}, t)$の関数を学習可能なパラメータ$\theta$を持つ機械学習モデル$F_{\theta}(\boldsymbol{x}, t)$で表現し、何らかの方法でデータからパラメータ$\theta$を学習して、$F_{\theta}(\boldsymbol{x}, t)$を使って$\nabla_{\boldsymbol{x}}\log p_t(\boldsymbol{x})$を近似することになる。確率密度関数の自然対数の勾配である$\nabla_{\boldsymbol{x}}\log p_t(\boldsymbol{x})$をスコア関数と呼ぶこともある。
$f(t), g(t)$を$s(t), \sigma(t)$に直すと、Probabilistic Flow ODEは
d \boldsymbol{X}_t = \left( \frac{\dot{s}(t)}{s(t)} \boldsymbol{x} - s(t)^2 \dot{\sigma}(t) \sigma(t) \nabla_{\boldsymbol{x}} \log p_t(\boldsymbol{x}) \right)dt
となる。
EDMのDenoiser
データ分布$p_{\mathrm{data}}$からサンプリングされたデータ$\boldsymbol{y}$と、平均$0$標準偏差$\sigma$の独立な$d$次元正規分布のノイズ$\boldsymbol{n}$を想定する。$\boldsymbol{y}+\boldsymbol{n}$と$\sigma$から、$\boldsymbol{y}$を予測するようなDenoiser$\boldsymbol{D}$を考える。
つまり、$\boldsymbol{y} \sim p_{\mathrm{data}}$であれば、
\begin{eqnarray}
&&\boldsymbol{D}(\boldsymbol{y}+\boldsymbol{n}, \sigma) \simeq \boldsymbol{y} \\
&&\boldsymbol{n} \sim N(\boldsymbol{0}, \sigma^2 \mathbf{I}_d)
\end{eqnarray}
となるようなDenoiser$\boldsymbol{D}$を構成する。
このような性質のDenoiser$\boldsymbol{D}$を得るために、各$\sigma$について以下のような損失関数を考える。
\mathcal{L}(\boldsymbol{D};\sigma) = \mathbb{E}_{\boldsymbol{y}\sim p_{\mathrm{data}}, \boldsymbol{n} \sim N(\boldsymbol{0}, \sigma^2 \mathbf{I}_d)} \left[ ||\boldsymbol{D}(\boldsymbol{y}+\boldsymbol{n}, \sigma) - \boldsymbol{y}||^2 \right]
この損失関数$\mathcal{L}(\boldsymbol{D};\sigma)$が最小化の$\boldsymbol{D}$はDenoiserとしては最大限望ましいものと考えられる。
実はこの損失関数$\mathcal{L}(\boldsymbol{D};\sigma)$が最小となる$\boldsymbol{D}$を用いてスコア関数$\nabla_{\boldsymbol{x}}\log p_t(\boldsymbol{x})$を計算することができる。その事実を以下に示す。
$\boldsymbol{x} = \boldsymbol{y}+\boldsymbol{n}$としてこれを変形すると、
\begin{eqnarray}
\mathcal{L}(\boldsymbol{D};\sigma) &=& \int_{\boldsymbol{y} \in \mathbb{R}^d} \left( \int_{\boldsymbol{n} \in \mathbb{R}^d} ||\boldsymbol{D}(\boldsymbol{y}+\boldsymbol{n}, \sigma) - \boldsymbol{y}||^2 \mathcal{N}_d(\boldsymbol{n}; \boldsymbol{0}, \sigma^2 \mathbf{I}_d) d \boldsymbol{n} \right) p_{\mathrm{data}}(\boldsymbol{y}) d \boldsymbol{y} \\
&=& \int_{\boldsymbol{y} \in \mathbb{R}^d} \left( \int_{\boldsymbol{x} \in \mathbb{R}^d} ||\boldsymbol{D}(\boldsymbol{x}, \sigma) - \boldsymbol{y}||^2 \mathcal{N}_d(\boldsymbol{x}; \boldsymbol{y}, \sigma^2 \mathbf{I}_d) d \boldsymbol{x} \right) p_{\mathrm{data}}(\boldsymbol{y}) d \boldsymbol{y} \\
&=& \int_{\boldsymbol{x} \in \mathbb{R}^d} \left( \int_{\boldsymbol{y} \in \mathbb{R}^d} ||\boldsymbol{D}(\boldsymbol{x}, \sigma) - \boldsymbol{y}||^2 \mathcal{N}_d(\boldsymbol{x}; \boldsymbol{y}, \sigma^2 \mathbf{I}_d) p_{\mathrm{data}}(\boldsymbol{y}) d \boldsymbol{y} \right) d \boldsymbol{x} \\
&=& \int_{\boldsymbol{x} \in \mathbb{R}^d} \mathcal{L}(\boldsymbol{D}; \boldsymbol{x}, \sigma) d \boldsymbol{x}
\end{eqnarray}
ただし、
\mathcal{L}(\boldsymbol{D}; \boldsymbol{x}, \sigma) = \int_{\boldsymbol{y} \in \mathbb{R}^d} ||\boldsymbol{D}(\boldsymbol{x}, \sigma) - \boldsymbol{y}||^2 \mathcal{N}_d(\boldsymbol{x}; \boldsymbol{y}, \sigma^2 \mathbf{I}_d) p_{\mathrm{data}}(\boldsymbol{y}) d \boldsymbol{y}
である。
各$\boldsymbol{x}$に対して、独立に$\mathcal{L}(\boldsymbol{D}; \boldsymbol{x}, \sigma)$を最小化するような$\boldsymbol{D}(\boldsymbol{x}, \sigma)$を選べば、$\mathcal{L}(\boldsymbol{D};\sigma)$は最小化される。
$\boldsymbol{x}, \sigma$を固定して、
\mathcal{L}(\boldsymbol{D}; \boldsymbol{x}, \sigma) = \int_{\boldsymbol{y} \in \mathbb{R}^d} ||\boldsymbol{D} - \boldsymbol{y}||^2 \mathcal{N}_d(\boldsymbol{x}; \boldsymbol{y}, \sigma^2 \mathbf{I}_d) p_{\mathrm{data}}(\boldsymbol{y}) d \boldsymbol{y}
が最小になる$\boldsymbol{D} \in \mathbb{R}^d$を求める。
\begin{eqnarray}
\nabla_{\boldsymbol{D}} \mathcal{L}(\boldsymbol{D}; \boldsymbol{x}, \sigma) &=& \int_{\boldsymbol{y} \in \mathbb{R}^d} 2 (\boldsymbol{D} - \boldsymbol{y}) \mathcal{N}_d(\boldsymbol{x}; \boldsymbol{y}, \sigma^2 \mathbf{I}_d) p_{\mathrm{data}}(\boldsymbol{y}) d \boldsymbol{y} \\
&=& 2 \boldsymbol{D} \int_{\boldsymbol{y} \in \mathbb{R}^d} \mathcal{N}_d(\boldsymbol{x}; \boldsymbol{y}, \sigma^2 \mathbf{I}_d) p_{\mathrm{data}}(\boldsymbol{y}) d \boldsymbol{y} - 2 \int_{\boldsymbol{y} \in \mathbb{R}^d} \boldsymbol{y} \mathcal{N}_d(\boldsymbol{x}; \boldsymbol{y}, \sigma^2 \mathbf{I}_d) p_{\mathrm{data}}(\boldsymbol{y}) d \boldsymbol{y}
\end{eqnarray}
なので、$\mathcal{L}(\boldsymbol{D}; \boldsymbol{x}, \sigma)$が最小となる$\nabla_{\boldsymbol{D}} \mathcal{L}(\boldsymbol{D}; \boldsymbol{x}, \sigma)=0$では、
\boldsymbol{D} = \frac{\int_{\boldsymbol{y} \in \mathbb{R}^d} \boldsymbol{y} \mathcal{N}_d(\boldsymbol{x}; \boldsymbol{y}, \sigma^2 \mathbf{I}_d) p_{\mathrm{data}}(\boldsymbol{y}) d \boldsymbol{y}}{\int_{\boldsymbol{y} \in \mathbb{R}^d} \mathcal{N}_d(\boldsymbol{x}; \boldsymbol{y}, \sigma^2 \mathbf{I}_d) p_{\mathrm{data}}(\boldsymbol{y}) d \boldsymbol{y}}
となる。従って、損失関数$\mathcal{L}(\boldsymbol{D};\sigma)$が最小になる$\boldsymbol{D}(\boldsymbol{x}, \sigma)$は、
\boldsymbol{D}(\boldsymbol{x}, \sigma) = \frac{\int_{\boldsymbol{y} \in \mathbb{R}^d} \boldsymbol{y} \mathcal{N}_d(\boldsymbol{x}; \boldsymbol{y}, \sigma^2 \mathbf{I}_d) p_{\mathrm{data}}(\boldsymbol{y}) d \boldsymbol{y}}{\int_{\boldsymbol{y} \in \mathbb{R}^d} \mathcal{N}_d(\boldsymbol{x}; \boldsymbol{y}, \sigma^2 \mathbf{I}_d) p_{\mathrm{data}}(\boldsymbol{y}) d \boldsymbol{y}}
である。
次に、$\nabla_{\boldsymbol{x}}\log p_t(\boldsymbol{x})$を変形する。
p_t(\boldsymbol{x}) = \frac{1}{s(t)^d} p\left( \frac{\boldsymbol{x}}{s(t)} ;\sigma(t)\right)
であるので、
\begin{eqnarray}
\nabla_{\boldsymbol{x}}\log p_t(\boldsymbol{x}) &=& \nabla_{\boldsymbol{x}}\log \left (\frac{1}{s(t)^d} p\left( \frac{\boldsymbol{x}}{s(t)} ;\sigma(t)\right)\right) \\
&=& \nabla_{\boldsymbol{x}} \left( \log p\left( \frac{\boldsymbol{x}}{s(t)} ;\sigma(t)\right) -d \log s(t)\right) \\
&=& \nabla_{\boldsymbol{x}} \log p\left( \frac{\boldsymbol{x}}{s(t)} ;\sigma(t)\right)
\end{eqnarray}
となる。
\begin{eqnarray}
p\left( \frac{\boldsymbol{x}}{s(t)} ;\sigma(t)\right) &=& \left[ p_{\mathrm{data}} * N(\boldsymbol{0}, \sigma^2 \mathbf{I}_d)\right] \left( \frac{\boldsymbol{x}}{s(t)} \right) \\
&=& \int_{\boldsymbol{y} \in \mathbb{R}^d} p_{\mathrm{data}} (\boldsymbol{y}) \mathcal{N}_d\left(\frac{\boldsymbol{x}}{s(t)} ; \boldsymbol{y}, \sigma(t)^2 \mathbf{I}_d \right) d \boldsymbol{y}
\end{eqnarray}
であるので、
\begin{eqnarray}
\nabla_{\boldsymbol{x}} \log p\left( \frac{\boldsymbol{x}}{s(t)} ;\sigma(t)\right) &=& \frac{\int_{\boldsymbol{y} \in \mathbb{R}^d} p_{\mathrm{data}} (\boldsymbol{y}) \left( \nabla_{\boldsymbol{x}} \mathcal{N}_d\left(\frac{\boldsymbol{x}}{s(t)} ; \boldsymbol{y}, \sigma(t)^2 \mathbf{I}_d \right) \right) d \boldsymbol{y}}{\int_{\boldsymbol{y} \in \mathbb{R}^d} p_{\mathrm{data}} (\boldsymbol{y}) \mathcal{N}_d\left(\frac{\boldsymbol{x}}{s(t)} ; \boldsymbol{y}, \sigma(t)^2 \mathbf{I}_d \right) d \boldsymbol{y}}
\end{eqnarray}
と表せる。
\begin{eqnarray}
\nabla_{\boldsymbol{x}} \mathcal{N}_d\left(\frac{\boldsymbol{x}}{s(t)} ; \boldsymbol{y}, \sigma(t)^2 \mathbf{I}_d \right) &=& \nabla_{\boldsymbol{x}} \left( C(\sigma) e^{ -\frac{1}{2\sigma(t)^2} \left\| \frac{\boldsymbol{x}}{s(t)} - \boldsymbol{y} \right\|^2} \right) \\
&=& -\frac{1}{\sigma(t)^2 s(t)} \left( \frac{\boldsymbol{x}}{s(t)} - \boldsymbol{y} \right) \mathcal{N}_d\left(\frac{\boldsymbol{x}}{s(t)} ; \boldsymbol{y}, \sigma(t)^2 \mathbf{I}_d \right)
\end{eqnarray}
となる。ただし、式変形の途中の$C(\sigma)$は$\mathcal{N}_d$が確率密度関数であるための規格化定数。
これより、
\begin{eqnarray}
&&\int_{\boldsymbol{y} \in \mathbb{R}^d} p_{\mathrm{data}} (\boldsymbol{y}) \left( \nabla_{\boldsymbol{x}} \mathcal{N}_d\left(\frac{\boldsymbol{x}}{s(t)} ; \boldsymbol{y}, \sigma(t)^2 \mathbf{I}_d \right) \right) d \boldsymbol{y} \\
&=& \frac{1}{\sigma(t)^2 s(t)} \int_{\boldsymbol{y} \in \mathbb{R}^d} \boldsymbol{y} \mathcal{N}_d\left(\frac{\boldsymbol{x}}{s(t)} ; \boldsymbol{y}, \sigma(t)^2 \mathbf{I}_d \right) p_{\mathrm{data}}(\boldsymbol{y}) d \boldsymbol{y} \\
&& - \frac{\boldsymbol{x}}{\sigma(t)^2 s(t)^2} \int_{\boldsymbol{y} \in \mathbb{R}^d} \mathcal{N}_d\left(\frac{\boldsymbol{x}}{s(t)} ; \boldsymbol{y}, \sigma(t)^2 \mathbf{I}_d \right)p_{\mathrm{data}}(\boldsymbol{y}) d \boldsymbol{y}
\end{eqnarray}
であるため、
\begin{eqnarray}
&&\nabla_{\boldsymbol{x}} \log p\left( \frac{\boldsymbol{x}}{s(t)} ;\sigma(t)\right) \\
&=& \frac{1}{\sigma(t)^2 s(t)} \frac{\int_{\boldsymbol{y} \in \mathbb{R}^d} \boldsymbol{y} \mathcal{N}_d\left(\frac{\boldsymbol{x}}{s(t)} ; \boldsymbol{y}, \sigma(t)^2 \mathbf{I}_d \right) p_{\mathrm{data}}(\boldsymbol{y}) d \boldsymbol{y}}{\int_{\boldsymbol{y} \in \mathbb{R}^d} \mathcal{N}_d\left(\frac{\boldsymbol{x}}{s(t)} ; \boldsymbol{y}, \sigma(t)^2 \mathbf{I}_d \right) p_{\mathrm{data}}(\boldsymbol{y}) d \boldsymbol{y}} - \frac{\boldsymbol{x}}{\sigma(t)^2 s(t)^2}
\end{eqnarray}
となる。
すなわち、
\nabla_{\boldsymbol{x}} \log p\left( \frac{\boldsymbol{x}}{s(t)} ;\sigma(t)\right) = \frac{\boldsymbol{D} \left( \frac{\boldsymbol{x}}{s(t)} ;\sigma(t)\right) - \frac{\boldsymbol{x}}{s(t)}}{\sigma(t)^2 s(t)}
となる。
\nabla_{\boldsymbol{x}}\log p_t(\boldsymbol{x}) = \nabla_{\boldsymbol{x}} \log p\left( \frac{\boldsymbol{x}}{s(t)} ;\sigma(t)\right)
だったので、これで損失関数$\mathcal{L}(\boldsymbol{D};\boldsymbol{x}, \sigma)$が最小になる$\boldsymbol{D} (\boldsymbol{x}; \sigma)$を用いてProbabilistic Flow ODEで必要な$\nabla_{\boldsymbol{x}}\log p_t(\boldsymbol{x})$を計算することができる。
Denoiserの学習
Probabilistic Flow ODEで必要な$\nabla_{\boldsymbol{x}}\log p_t(\boldsymbol{x})$を計算できる完全なDenoiser$\boldsymbol{D}(\boldsymbol{x}, \sigma)$は、$(\boldsymbol{x}, \sigma)$
\mathcal{L}(\boldsymbol{D}; \boldsymbol{x}, \sigma) = \int_{\boldsymbol{y} \in \mathbb{R}^d} ||\boldsymbol{D} - \boldsymbol{y}||^2 \mathcal{N}_d(\boldsymbol{x}; \boldsymbol{y}, \sigma^2 \mathbf{I}_d) p_{\mathrm{data}}(\boldsymbol{y}) d \boldsymbol{y}
を最小化する$\boldsymbol{D}$で定義される。ただ、データ分布$p_{\mathrm{data}}$は未知であるため、この積分を厳密に計算することはできない。$p_{\mathrm{data}}$からサンプリングされた有限個のデータを使って、学習によって十分な性能のDenoiserを得ることを試みる。学習可能なパラメータ$\theta$を持つ機械学習モデルとしてDenoiserを設計し、損失関数
\mathcal{L} (\theta) = \mathbb{E}_{\boldsymbol{y} \sim p_{\mathrm{data}},\boldsymbol{n} \sim N(\boldsymbol{0}, \sigma^2 \mathbf{I}_d), \sigma \sim p_{\mathrm{train}}}\left[ \lambda(\sigma) \left\| \boldsymbol{D}_{\theta} (\boldsymbol{y} + \boldsymbol{n}, \sigma) - \boldsymbol{y} \right\|^2 \right]
を最小化するような$\theta$を勾配法で学習する。$p_{\mathrm{train}}$はDenoiserを学習するためにデータへ加えるノイズの標準偏差$\sigma$の分布であり学習用のハイパーパラメータである。EDMの論文では生成時のノイズスケジューリングを考慮したパラメータの対数正規分布を使用している。$\lambda(\sigma)$は$\sigma$に依存した損失関数の重みで、これも学習のためのハイパーパラメータ。
Alphafold3の学習では、ミニバッチに含まれるデータ$\boldsymbol{y}_m$1つに対して、複数の$\sigma_m, \boldsymbol{n}_m$をそれぞれの分布から独立にランダムサンプリングして平均している。バッチサイズ$B$、ミニバッチに含まれるサンプルを
\boldsymbol{y}_{b_1},...,\boldsymbol{y}_{b_{B}}
とし、1サンプルあたりの$\sigma_j, \boldsymbol{n}_j$のサンプル数を$K$(Diffusion Batchsizeという)とすると、ミニバッチの損失関数は、
\mathcal{L}_{\mathrm{batch}} (\theta) = \frac{1}{BM} \sum_{j=1}^{B} \sum_{k=1}^{K} \lambda(\sigma_{b_j,k})\left\| \boldsymbol{D}_{\theta} (\boldsymbol{y}_{b_j} + \sigma_{b_j,k} \boldsymbol{n}_{b_j,k}, \sigma_{b_j,k}) - \boldsymbol{y}_{b_j} \right\|^2
のように計算される。
EDMの学習済みDenoiserを使った生成
ODE Sampler
データから学習したDenoiser$\boldsymbol{D}_{\theta}$を使って、ノイズ分布からのサンプルを基にデノイズによってデータ分布からのサンプリング(生成)を行う。
Probabilistic Flow ODEは
d \boldsymbol{X}_t = \left( \frac{\dot{s}(t)}{s(t)} \boldsymbol{x} - s(t)^2 \dot{\sigma}(t) \sigma(t) \nabla_{\boldsymbol{x}} \log p_t(\boldsymbol{x}) \right)dt
であり、スコア関数$\nabla_{\boldsymbol{x}} \log p_t(\boldsymbol{x})$は理想的なDenoiser$\boldsymbol{D}$から計算できて、
\nabla_{\boldsymbol{x}} \log p_t(\boldsymbol{x}) = \frac{\boldsymbol{D} \left( \frac{\boldsymbol{x}}{s(t)} ;\sigma(t)\right) - \frac{\boldsymbol{x}}{s(t)}}{\sigma(t)^2 s(t)}
だったので、データから学習したDenoiser$\boldsymbol{D}_{\theta}$で置き換えてProbabilistic Flow ODEを表現すると、
d \boldsymbol{X}_t = \left( \left( \frac{\dot{s}(t)}{s(t)} + \frac{\dot{\sigma}(t)}{\sigma(t)} \right) \boldsymbol{X}_t - s(t) \frac{\dot{\sigma}(t)}{\sigma(t)} \boldsymbol{D}_{\theta} \left( \frac{\boldsymbol{X}_t}{s(t)} ;\sigma(t)\right) \right)dt
となる。
EDMでは$s(t), \sigma(t)$の最適な形として、
\begin{eqnarray}
s(t) &=& 1 \\
\sigma(t) &=& t
\end{eqnarray}
という最もシンプルな関数を提示している。この時、Probabilistic Flow ODEは
d \boldsymbol{X}_t = \frac{\boldsymbol{X}_t - \boldsymbol{D}_{\theta} (\boldsymbol{X}_t, t)}{t} dt
となる。
データの生成では、時刻$t$を離散化してODEをシミュレーションする。ODEシミュレーションのステップサイズを$N$として、時刻を$t_0,t_1,...,t_N$と離散化する。Probabilistic Flow ODEは時刻$t$について正の方向がノイズの加わる方向であるので、デノイズによる生成ではほぼノイズ分布である時刻$t=T$を初期値として、データ分布である時刻$t=0$に向かって、時刻について負の方向にシミュレーションする。つまり、
t_0=T > t_1 > t_2 >...>t_{N-1}>t_N=0
とする。$\sigma(t) = t$とするので、$t$と同じ値で$\sigma$も離散化して$\sigma_i = t_i$と表す。
EDMでは、この$\sigma_i = t_i$の決め方について、等間隔でなく
\sigma_i = \begin{cases}
\left( {\sigma_{\mathrm{max}}}^{\frac{1}{\rho}} + \frac{i}{N-1} \left({\sigma_{\mathrm{min}}}^{\frac{1}{\rho}} - {\sigma_{\mathrm{max}}}^{\frac{1}{\rho}}\right)\right)^{\rho} & (i < N)\\
0 & (i=N)
\end{cases}
の形で設定している。
$\rho$が大きいほど、$\sigma$が大きい範囲でのステップ数が少なく、ある程度小さな$\sigma$まで一気にデノイズし、$\sigma$が小さい範囲でより細かくシミュレーションする設定になる。
$\sigma_{\mathrm{max}}=\sigma_0=t_0=T$, $\sigma_{\mathrm{min}}=\sigma_{N-1}=t_{N-1}$である。
$\rho, \sigma_{\mathrm{max}}, \sigma_{\mathrm{min}}$はいずれも生成のタイムステップ=ノイズスケジュールを決めるパラメータである。Denoiserはタイムステップで定義された$\sigma_i$でなく、連続値の分布$p_{\sigma}$からランダムサンプリングした$\sigma \sim p_{\sigma}$に対して、$\boldsymbol{D}(\boldsymbol{y}+\sigma\boldsymbol{n}, \sigma)$が$\boldsymbol{y}$に近くなるように学習しているため、タイムステップ=ノイズスケジュールの決め方について学習時一致させる必要があるといったような制約はない。ただし、Denoiserが学習した$\sigma$の範囲内でDenoiserの性能が発揮されると考えると、$\sigma_{\mathrm{max}}, \sigma_{\mathrm{min}}$が$p_{\sigma}$でサンプリングされる範囲に含まれるよう、$\sigma_{\mathrm{max}}, \sigma_{\mathrm{min}}$あるいは$p_{\sigma}$を設定することになる。
また、生成においては、時刻$t=T$における分布$p_t$が$t=0$のデータ分布$p_{\mathrm{data}}$を無視できるほどノイズ分布(正規分布)に近づいていると仮定した上で、ノイズ分布$N(\boldsymbol{0}, \sigma(T)^2 \mathbf{I}_d)$からサンプリングしたデータを$t_0=T$の初期値としてProbabilistic Flow ODEで時刻$t_N=0$までシミュレーションする。$s(t)=1, \sigma(t)=t$であれば、
p_T(\boldsymbol{x}) = \left[ p_{\mathrm{data}} * N(\boldsymbol{0}, {\sigma_{\mathrm{max}}}^2 \mathbf{I}_d)\right] (\boldsymbol{x})
となり、データ分布$p_{\mathrm{data}}$に従う確率変数の各次元の標準偏差が概ね等しく$\sigma_{\mathrm{data}}$であるとすれば、時刻$T$の分布$p_T$に従う確率変数の各次元の標準偏差は$\sqrt{{\sigma_{\mathrm{data}}}^2 + {\sigma_{\mathrm{max}}}^2}$となる。
$\sigma_{\mathrm{data}} << \sigma_{\mathrm{max}}$であれば、$p_T$において$p_{\mathrm{data}}$の影響は無視できるほど小さくなり、$p_T \simeq N(\boldsymbol{0}, {\sigma_{\mathrm{max}}}^2 \mathbf{I}_d)$となる。
そのため、$\sigma_{\mathrm{data}}$のスケールを適切に見積った上で$\sigma_{\mathrm{data}} << \sigma_{\mathrm{max}}$になるように$\sigma_{\mathrm{max}}$を設定する必要がある。
EDMでは画像の生成を例として、画像の各ピクセルの輝度0~255を[-1, 1]に規格化した上で、$\sigma_{\mathrm{data}}=0.5$としている。その上で、
\begin{eqnarray}
\sigma_{\mathrm{max}} &=& 80 \\
\sigma_{\mathrm{min}} &=& 0.002 \\
\rho &=& 7 \\
\end{eqnarray}
と設定している。学習時の$\sigma$の分布$p_{\sigma}$は、対数正規分布
LN(P_{\mathrm{mean}}=-1.2, P_{\mathrm{std}}=1.2)
としている。
この設定で生成のノイズスケジュールの$\sigma_0,...\sigma_{N-1}$をヒストグラムで表すと以下のようになる。
そして、学習時の$\sigma$の分布(対数正規分布)の確率密度関数は以下になる。
EDMの設定は学習時の$\sigma$の分布と生成のノイズスケジュールで使う$\sigma_i$の分布が近くなるように設計されていることがわかる。
AlphaFold3でもノイズスケジュールの式はEDMと同じである。ただし、タンパク質等の全重原子の3次元座標(単位Å)を生成対象としており、パラメータには以下に相当する値が設定されている。
\begin{eqnarray}
\sigma_{\mathrm{data}} &=& 16 \\
\sigma_{\mathrm{max}} &=& 2560 \\
\sigma_{\mathrm{min}} &=& 0.0064 \\
\rho &=& 7 \\
\end{eqnarray}
学習時の$\sigma$の分布は
\sigma_{\mathrm{data}} \cdot LN(P_{\mathrm{mean}}=-1.2, P_{\mathrm{std}}=1.5)
としている。つまり対数正規分布$LN(P_{\mathrm{mean}}=-1.2, P_{\mathrm{std}}=1.5)$からサンプリングした値に$\sigma_{\mathrm{data}} = 16$をかけた値を$\sigma$として学習する。
このタイムステップ=ノイズスケジュールを使って、$N(\boldsymbol{0}, {\sigma_{\mathrm{max}}}^2 \mathbf{I}_d)$からランダムサンプリングした$\boldsymbol{x}_0$を初期値として、ODE
\frac{d \boldsymbol{x}}{dt} = \frac{\boldsymbol{x} - \boldsymbol{D}_{\theta} (\boldsymbol{x}, t)}{t}
で$N$ステップの離散シミュレーションをして、$\boldsymbol{x}_N$が学習したDenoizerを使ってデータ分布から生成したサンプルとなる。
ODEの離散サンプリングには様々な方法があるが、EDMではごく単純なEular methodでなく、Heun’s 2nd order methodの使用を提案している。Heun’s 2nd order methodは、以下のように2回微分係数を計算した平均で$\boldsymbol{x}_i$を更新する。
\begin{eqnarray}
&& \boldsymbol{d}_i = \frac{\boldsymbol{x}_i - \boldsymbol{D}_{\theta} (\boldsymbol{x}_i, t_i)}{t_i} \\
&& \boldsymbol{x}'_{i+1} = \boldsymbol{x}_i + (t_{i+1} - t_{i}) \boldsymbol{d}_i \\
&& \boldsymbol{d}'_i = \frac{\boldsymbol{x}'_{i+1} - \boldsymbol{D}_{\theta} (\boldsymbol{x}'_{i+1}, t_{i+1})}{t_{i+1}} \\
&& \boldsymbol{x}_{i+1} = \boldsymbol{x}_i + (t_{i+1} - t_{i}) \left(\frac{1}{2} \boldsymbol{d}_i + \frac{1}{2} \boldsymbol{d}'_{i}\right)
\end{eqnarray}
ただし、$t_N=0$なので$i=N-1$のステップは分母が$0$になるので$\boldsymbol{d}'_{N-1}$は計算せず、
\boldsymbol{x}_{N} = \boldsymbol{x}'_{N}
で代替する。
Stochastic Sampler
Probabilistic Flow ODEを用いたサンプリングは、初期値のノイズ分布からのサンプルが決まると生成結果が決まるDetarministicな過程になる。
これよりProbabilistic Flow ODEによるDetarministicな離散シミュレーションは、その離散化に由来する誤差によって生成性能が低くなる。
EDMではさらに生成性能を高めるために、シミュレーションにノイズ付与による微調整を含めることで生成性能を高める工夫を施したStochastic Samplerを提案している。
$s(t)=1, \sigma(t)=t$の場合、Diffusionモデルの確率微分方程式の$f(t), g(t)$は
\begin{eqnarray}
f(t) &=& 0 \\
g(t) &=& \sqrt{2t}
\end{eqnarray}
であるので、Diffusionモデルの確率微分方程式は
d \boldsymbol{X}_t = \sqrt{2t} d \boldsymbol{W}_t
となる。ある$0 \le t_i < \hat{t}_i \le T$について、この確率微分方程式に従って時刻$t_i$から$\hat{t}_i$まで時間展開することを考える。
\begin{eqnarray}
\boldsymbol{X}_{\hat{t}_i} &=& \boldsymbol{X}_{t_i} + \int_{t_i}^{\hat{t}_i} \sqrt{2t} d \boldsymbol{W}_t \\
&=& \boldsymbol{X}_{t_i} + \lim_{\Delta \rightarrow 0} \sum_{j=1}^{n} \sqrt{2\tau_j} \left(\boldsymbol{W}_{\tau_{j+1}} - \boldsymbol{W}_{\tau_{j}} \right)
\end{eqnarray}
となる。ただし、
\tau_1=t_i, \tau_n=\hat{t}_i, \Delta = \sup_{j} (\tau_{j+1} - \tau_{j})
としている。ここで、
\sqrt{2\tau_j}(\boldsymbol{W}_{\tau_{j+1}} - \boldsymbol{W}_{\tau_{j}}) \sim N(\boldsymbol{0}, 2\tau_i(\tau_{j+1} - \tau_{j}))
であり、
\boldsymbol{W}_{\tau_2} - \boldsymbol{W}_{\tau_1}, \boldsymbol{W}_{\tau_3} - \boldsymbol{W}_{\tau_2}, ... , \boldsymbol{W}_{\tau_n} - \boldsymbol{W}_{\tau_{n-1}}
は独立なので、
\sum_{j=1}^{n} \sqrt{2\tau_j} \left(\boldsymbol{W}_{\tau_{j+1}} - \boldsymbol{W}_{\tau_{j}} \right) \sim N\left(\boldsymbol{0}, \sum_{j=1}^n 2\tau_i(\tau_{j+1} - \tau_{j})\right)
である。分散について極限をとり、
\begin{eqnarray}
\lim_{\Delta \rightarrow 0} \sum_{j=1}^n 2\tau_i(\tau_{j+1} - \tau_{j}) &=& \int_{t_i}^{\hat{t}_i} 2\tau d\tau
&=& {\hat{t}_i}^2 - {t_i}^2
\end{eqnarray}
であるから、
\boldsymbol{X}_{\hat{t}_i} = \boldsymbol{X}_{t_i} + \sqrt{{\hat{t}_i}^2 - {t_i}^2} \boldsymbol{\varepsilon}_i, \, \boldsymbol{\varepsilon}_i \sim N(\boldsymbol{0}, \mathbf{I}_d)
となる。この$t_i$から$\hat{t}_i$への時間展開は確率微分方程式の離散近似でなく、Wiener過程からの解析的な導出であるので、この式に従った時刻の正の方向へのシミュレーションは離散化に由来する誤差を含まない。
EDMのStochastic Samplerは元のDiffusionの確率微分方程式に由来するこの式の時刻の正の方向への確率的シミュレーションと、Probabilistic Flow ODEによる時刻の負の方向への決定論的シミュレーションを交互に行う。これによりProbabilistic Flow ODEの決定論的シミュレーションで発生する離散化の誤差を、元のDiffusionの確率微分方程式に基づいた確率的シミュレーションで緩和し、離散化誤差が蓄積して生成の性能が低下することを防いでいる。
更新式は以下のようになる。
\begin{eqnarray}
&& \hat{t}_i = (1 + \gamma_i) t_i \\
&& \hat{\boldsymbol{x}}_i = \boldsymbol{x}_i + \sqrt{{\hat{t}_i}^2 - {t_i}^2} \boldsymbol{\varepsilon}_i, \, \boldsymbol{\varepsilon}_i \sim N(\boldsymbol{0}, \mathbf{I}_d)\\
&& \boldsymbol{d}_i = \frac{\hat{\boldsymbol{x}}_i - \boldsymbol{D}_{\theta} (\hat{\boldsymbol{x}}_i, \hat{t}_i)}{\hat{t}_i} \\
&& \boldsymbol{x}'_{i+1} = \hat{\boldsymbol{x}}_i + (t_{i+1} - \hat{t}_i) \boldsymbol{d}_i \\
&& \boldsymbol{d}'_i = \frac{\boldsymbol{x}'_{i+1} - \boldsymbol{D}_{\theta} (\boldsymbol{x}'_{i+1}, t_{i+1})}{t_{i+1}} \\
&& \boldsymbol{x}_{i+1} = \hat{\boldsymbol{x}}_i + (t_{i+1} - \hat{t}_i) \left(\frac{1}{2} \boldsymbol{d}_i + \frac{1}{2} \boldsymbol{d}'_{i}\right)
\end{eqnarray}
ただし、$\gamma_i$は小さな正の数でサンプリング時のハイパーパラメータである。
EDMのStochastic Samplerが決定論的なProbabilistic Flow ODEの更新式と異なるのは最初の2行である。ここでは、$t_i$から$\hat{t}_i$へ時刻が正の方向へ、元のDiffusionの確率微分方程式に従って確率的にシミュレーションする。これによりわずかにノイズが付与される。そのあとは、時刻$t_i$の代わりに$\hat{t}_i$からのProbabilistic Flow ODEとして$\hat{\boldsymbol{x}}_i$から1ステップ分デノイズする。
EDMでは、$\gamma_i$の決め方は以下としている
\gamma_i = \begin{cases}
\min\left(\frac{S_{\mathrm{churn}}}{N}, \sqrt{2}-1\right) & t_i \in \left[S_{\mathrm{min}}, S_{\mathrm{max}}\right] \\
0 & \mathrm{otherwise}
\end{cases}
EDMでは時刻の正の方向への確率的なシミュレーションをchurnと呼んでいる。churnの大きさを$S_{\mathrm{churn}}$でハイパーパラメータとして定義し、それをサンプリングのステップ数$N$で割った値を基本的な$\gamma_i$の値としている。ただし、ステップ数$N$が小さい時に$\gamma_i$が大きくなりすぎないように$\sqrt{2}-1$以下であるようにしている。さらに、このchurnを行う$t_i$の範囲も制限し、その範囲外の場合は$\gamma_i=0$としてchurnを行わないようにしている。
ちなみに、AlphaFold3ではこのchurnを含むStochastic Samplerを使用しているが、1stepでDenoiserを2回計算する計算コストの高さからか、Heun’s 2nd order methodによるProbabilistic Flow ODEのシミュレーションは採用されていない。以下のようなごく単純なEular methodを使用している。
\begin{eqnarray}
&& \hat{t}_i = (1 + \gamma_i) t_i \\
&& \hat{\boldsymbol{x}}_i = \boldsymbol{x}_i + \lambda \sqrt{{\hat{t}_i}^2 - {t_i}^2} \boldsymbol{\varepsilon}_i, \, \boldsymbol{\varepsilon}_i \sim N(\boldsymbol{0}, \mathbf{I}_d)\\
&& \boldsymbol{d}_i = \frac{\hat{\boldsymbol{x}}_i - \boldsymbol{D}_{\theta} (\hat{\boldsymbol{x}}_i, \hat{t}_i)}{\hat{t}_i} \\
&& \boldsymbol{x}_{i+1} = \hat{\boldsymbol{x}}_i + \eta (t_{i+1} - \hat{t}_i) \boldsymbol{d}_i \\
\end{eqnarray}
$\lambda$はnoise scale、$\eta$はstep scaleというハイパーパラメータで、AlphaFold3では
\begin{eqnarray}
\lambda &=& 1.003 \\
\eta &=& 1.5
\end{eqnarray}
と設定されている。noise scaleはシミュレーション時に実際に加えるノイズを時刻$\sigma(t)=t$から計算される理論的な値よりもわずかに大きくする効果、step scaleは1stepのデノイズでの変化量をProbabilistic Flow ODEのEular methodによる離散化の理論的な値から定数倍分大きくするような効果を持つ。
Alphafold3の$\gamma_i$の設定は以下である。
\gamma_i = \begin{cases}
0.8 & t_i > 1.0 \\
0 & \mathrm{otherwise}
\end{cases}
Alphafold3では$\sigma_{\mathrm{data}}=16$を想定するので、ノイズの標準偏差がその16分の1の$\sigma_i = t_i \le 1.0$までデノイズで小さくなるとするステップまで進んだらchurnしないということになる。Alphafold3のchurnの$\gamma_i=0.8$はEDMの最大$\gamma_i=\sqrt{2}-1 \simeq 0.414$の設定に比べると大きい。
EDMのDenoiserの設計
Denoiserは、Deep Neural Network(MLP, CNN, Transfomerなど)で作り、データを用いて損失関数の勾配で学習することになるが、EDMではDenoiserをそのまま$\boldsymbol{x}, \sigma$を入力としたDeep Neural Networkに置き換えるのではなく、特に$\sigma$への依存性を考慮して工夫した設計を行っている。
\boldsymbol{D}_{\theta} (\boldsymbol{x}; \sigma) = c_{\mathrm{skip}}(\sigma) \boldsymbol{x} + c_{\mathrm{out}} (\sigma) \boldsymbol{F}_{\theta} (c_{\mathrm{in}} (\sigma) \boldsymbol{x}; c_{\mathrm{noise}} (\sigma))
という形である。
$\boldsymbol{D}_{\theta} (\boldsymbol{x}; \sigma)$は、$\boldsymbol{x}$がデータに標準偏差$\sigma$の正規分布ノイズを加えられたものであるとしてノイズを加えられる前のデータを予測するDenoiserである。なので$\sigma$によるがある程度$\boldsymbol{x}$と近いはずである。それが第一項のスキップコネクションで表現される。
第二項は$\boldsymbol{x}, \sigma$を基に$\boldsymbol{x}$をデノイズする方向を決めるDeep Neural Network部分である。$\boldsymbol{F}_{\theta}$がDeep Neural Networkで、$\sigma$に依存した係数
c_{\mathrm{out}}(\sigma), c_{\mathrm{in}}(\sigma), c_{\mathrm{noise}}(\sigma)
によって、Deep Neural Network部分の$\sigma$依存性をコントロールしている。
この時、Denoiser学習の損失関数は
\begin{eqnarray}
\mathcal{L} (\theta) &=& \mathbb{E}\left[ \lambda(\sigma) \left\| \boldsymbol{D}_{\theta} (\boldsymbol{y} + \boldsymbol{n}, \sigma) - \boldsymbol{y} \right\|^2 \right] \\
&=& \mathbb{E}\left[ \lambda(\sigma) \left\| c_{\mathrm{skip}}(\sigma) (\boldsymbol{y} + \boldsymbol{n}) + c_{\mathrm{out}} (\sigma) \boldsymbol{F}_{\theta} (c_{\mathrm{in}} (\sigma) (\boldsymbol{y} + \boldsymbol{n}) ; c_{\mathrm{noise}} (\sigma)) - \boldsymbol{y} \right\|^2 \right] \\
&=& \mathbb{E}\left[ \lambda(\sigma) {c_{\mathrm{out}} (\sigma)}^2 \left\| \boldsymbol{F}_{\theta} (c_{\mathrm{in}} (\sigma) (\boldsymbol{y} + \boldsymbol{n}) ; c_{\mathrm{noise}} (\sigma)) - \frac{\boldsymbol{y} - c_{\mathrm{skip}}(\sigma) (\boldsymbol{y} + \boldsymbol{n})}{c_{\mathrm{out}} (\sigma)} \right\|^2 \right] \\
&=& \mathbb{E}\left[ w(\sigma) \left\| \boldsymbol{F}_{\theta} (c_{\mathrm{in}} (\sigma) (\boldsymbol{y} + \boldsymbol{n}) ; c_{\mathrm{noise}} (\sigma)) - \boldsymbol{F}_{\mathrm{target}} (\boldsymbol{y}, \boldsymbol{n}; \sigma) \right\|^2 \right]
\end{eqnarray}
となる。ただし、
\begin{eqnarray}
w(\sigma) &=& \lambda(\sigma) {c_{\mathrm{out}} (\sigma)}^2 \\
\boldsymbol{F}_{\mathrm{target}} (\boldsymbol{y}, \boldsymbol{n}; \sigma) &=& \frac{1}{c_{\mathrm{out}} (\sigma)} (\boldsymbol{y} - c_{\mathrm{skip}}(\sigma) (\boldsymbol{y} + \boldsymbol{n}))
\end{eqnarray}
としている。また、損失関数の期待値$\mathbb{E}$は
\boldsymbol{y} \sim p_{\mathrm{data}},\boldsymbol{n} \sim N(\boldsymbol{0}, \sigma^2 \mathbf{I}_d), \sigma \sim p_{\mathrm{train}}
に対して計算される。
まず、Deep Neural Network$\boldsymbol{F}_{\theta}$の入力の分散は一定であることが望ましいので、
\mathrm{Var}[c_{\mathrm{in}} (\sigma) (\boldsymbol{y} + \boldsymbol{n})] = \mathbf{I}_d
を要請する。
\mathrm{Var}[\boldsymbol{y}] = {{\sigma}_{\mathrm{data}}}^2 \mathbf{I}_d,\, \mathrm{Var}[\boldsymbol{n}] = {\sigma}^2 \mathbf{I}_d
を想定しているので、この要請により、
c_{\mathrm{in}} (\sigma) = \frac{1}{\sqrt{{\sigma}^2 + {{\sigma}_{\mathrm{data}}}^2}}
が得られる。
次に、Deep Neural Network$\boldsymbol{F}_{\theta}$の教師ラベルの分散も一定であることが望ましいので、
\mathrm{Var}[\boldsymbol{F}_{\mathrm{target}} (\boldsymbol{y}, \boldsymbol{n}; \sigma)] = \mathbf{I}_d
を要請する。
\begin{eqnarray}
\boldsymbol{F}_{\mathrm{target}} (\boldsymbol{y}, \boldsymbol{n}; \sigma) &=& \frac{1}{c_{\mathrm{out}} (\sigma)} (\boldsymbol{y} - c_{\mathrm{skip}}(\sigma) (\boldsymbol{y} + \boldsymbol{n})) \\
&=& \frac{1}{c_{\mathrm{out}} (\sigma)} ((1 - c_{\mathrm{skip}}(\sigma)) \boldsymbol{y} - c_{\mathrm{skip}}(\sigma)\boldsymbol{n})
\end{eqnarray}
なので、
\mathrm{Var}[\boldsymbol{y}] = {{\sigma}_{\mathrm{data}}}^2 \mathbf{I}_d,\, \mathrm{Var}[\boldsymbol{n}] = {\sigma}^2 \mathbf{I}_d
の想定と$\boldsymbol{n}$が$\boldsymbol{y}$とは独立な正規分布のノイズであることから、この要請により、
c_{\mathrm{out}} (\sigma) = \sqrt{(1 - c_{\mathrm{skip}}(\sigma))^2 {{\sigma}_{\mathrm{data}}}^2 + {c_{\mathrm{skip}}(\sigma)}^2 {\sigma}^2}
を得る。
この関係式を基に、各$\sigma$に対して、$c_{\mathrm{out}}$が最小になるように$c_{\mathrm{skip}}$を選ぶ。
\begin{eqnarray}
{c_{\mathrm{out}}}^2 &=& (1 - c_{\mathrm{skip}})^2 {{\sigma}_{\mathrm{data}}}^2 + {c_{\mathrm{skip}}}^2 {\sigma}^2 \\
&=& ({\sigma}^2 + {{\sigma}_{\mathrm{data}}}^2) {c_{\mathrm{skip}}}^2 - 2 {{\sigma}_{\mathrm{data}}}^2 c_{\mathrm{skip}} + {{\sigma}_{\mathrm{data}}}^2 \\
&=& ({\sigma}^2 + {{\sigma}_{\mathrm{data}}}^2) \left( c_{\mathrm{skip}} - \frac{{{\sigma}_{\mathrm{data}}}^2}{{\sigma}^2 + {{\sigma}_{\mathrm{data}}}^2} \right)^2 + \frac{{\sigma}^2 {{\sigma}_{\mathrm{data}}}^2}{{\sigma}^2 + {{\sigma}_{\mathrm{data}}}^2}
\end{eqnarray}
なので、この選び方で、
\begin{eqnarray}
c_{\mathrm{skip}} (\sigma) &=& \frac{{{\sigma}_{\mathrm{data}}}^2}{{\sigma}^2 + {{\sigma}_{\mathrm{data}}}^2} \\
c_{\mathrm{out}} (\sigma) &=& \frac{\sigma {\sigma}_{\mathrm{data}}}{\sqrt{{\sigma}^2 + {{\sigma}_{\mathrm{data}}}^2}}
\end{eqnarray}
を得る。
$c_{\mathrm{out}}$を最小化する理由は、Deep Neural Network$\boldsymbol{F}_{\theta}$の学習に使う教師ラベル
\boldsymbol{F}_{\mathrm{target}} (\boldsymbol{y}, \boldsymbol{n}; \sigma) = \frac{1}{c_{\mathrm{out}} (\sigma)} ((1 - c_{\mathrm{skip}}(\sigma)) \boldsymbol{y} - c_{\mathrm{skip}}(\sigma)\boldsymbol{n})
に含まれるデータ$\boldsymbol{y}, \boldsymbol{n}$由来の信号を可能な限り減衰しないためである。$c_{\mathrm{out}}$が大きいほど、$\boldsymbol{y}, \boldsymbol{n}$のスケール、つまりDenoiserのスケールで考えた際にDeep Neural Network$\boldsymbol{F}_{\theta}$の誤差は増幅される。
最後に、$w(\sigma)=1$となるように$\lambda(\sigma)$を設定する。
\lambda(\sigma) = \frac{w(\sigma)}{{c_{\mathrm{out}} (\sigma)}^2} = \frac{{\sigma}^2 + {{\sigma}_{\mathrm{data}}}^2}{{\sigma}^2 {{\sigma}_{\mathrm{data}}}^2}
とすればよい。
$w(\sigma)=1$を要請する根拠は、各$\sigma$における損失関数のスケールを少なくとも学習初期において$\sigma$によらず一定にすることである。$\sigma$ごとに損失関数のスケールが大きく異なる場合、様々な$\sigma$で平均を計算した場合に、平均損失において損失のスケールの大きい$\sigma$での損失が支配的になり、損失のスケールが小さい$\sigma$の学習が不十分になる可能性がある。
$w(\sigma)=1$の下では、$\boldsymbol{F}_{\theta} = \boldsymbol{0}$の初期状態を考えると
\begin{eqnarray}
&& \mathbb{E}\left[ w(\sigma) \left\| \boldsymbol{F}_{\theta} (c_{\mathrm{in}} (\sigma) (\boldsymbol{y} + \boldsymbol{n}) ; c_{\mathrm{noise}} (\sigma)) - \boldsymbol{F}_{\mathrm{target}} (\boldsymbol{y}, \boldsymbol{n}; \sigma) \right\|^2 \right] \\
&=& \mathbb{E}\left[ \left\| \boldsymbol{F}_{\mathrm{target}} (\boldsymbol{y}, \boldsymbol{n}; \sigma) \right\|^2 \right] \\
&=& \mathbb{E}\left[ \left\| \frac{\sqrt{{\sigma}^2 + {{\sigma}_{\mathrm{data}}}^2}}{\sigma {\sigma}_{\mathrm{data}}} \left( \frac{{\sigma}^2}{{\sigma}^2 + {{\sigma}_{\mathrm{data}}}^2} \boldsymbol{y} + \frac{{{\sigma}_{\mathrm{data}}}^2}{{\sigma}^2 + {{\sigma}_{\mathrm{data}}}^2} \boldsymbol{n} \right) \right\|^2 \right] \\
&=& \frac{1}{{\sigma}^2 + {{\sigma}_{\mathrm{data}}}^2} \mathbb{E}\left[ \left\| \frac{\sigma}{{\sigma}_{\mathrm{data}}} \boldsymbol{y} + \frac{{\sigma}_{\mathrm{data}}}{\sigma} \boldsymbol{n} \right\|^2 \right] \\
&=& \frac{1}{{\sigma}^2 + {{\sigma}_{\mathrm{data}}}^2} \mathbb{E}\left[ \frac{{\sigma}^2}{{{\sigma}_{\mathrm{data}}}^2} \|\boldsymbol{y}\|^2 + \frac{{{\sigma}_{\mathrm{data}}}^2}{{\sigma}^2} \|\boldsymbol{n}\|^2 + 2\langle \boldsymbol{y}, \boldsymbol{n}\rangle \right] \\
&=& \frac{1}{{\sigma}^2 + {{\sigma}_{\mathrm{data}}}^2} \mathbb{E}\left[ \frac{{\sigma}^2}{{{\sigma}_{\mathrm{data}}}^2} \|\boldsymbol{y}\|^2 + \frac{{{\sigma}_{\mathrm{data}}}^2}{{\sigma}^2} \|\boldsymbol{n}\|^2 + 2\langle \boldsymbol{y}, \boldsymbol{n}\rangle \right] \\
&=& \frac{1}{{\sigma}^2 + {{\sigma}_{\mathrm{data}}}^2} \left( \frac{{\sigma}^2}{{{\sigma}_{\mathrm{data}}}^2} \mathrm{Var}[\boldsymbol{y}] + \frac{{{\sigma}_{\mathrm{data}}}^2}{{\sigma}^2} \mathrm{Var}[\boldsymbol{n}] + 2 \mathrm{Cov}( \boldsymbol{y}, \boldsymbol{n}) \right) \\
&=& 1
\end{eqnarray}
となる。これは各$\sigma$における損失関数のスケールを$\sigma$によらず一定にする望ましい性質となる。ただし、$\boldsymbol{y}$の分布は規格化されて平均が$\boldsymbol{0}$であると想定している。$\mathrm{Cov}( \boldsymbol{y}, \boldsymbol{n})$は$\boldsymbol{y}, \boldsymbol{n}$の共分散であり、$\boldsymbol{n}$が$\boldsymbol{y}$と独立なノイズなので常に$0$である。
注:EDMの論文のAppendixで上記の説明がされているが、実際はデータ分布では$d$次元の各要素は独立でないし、$\boldsymbol{n}$と違って$d$次元正規分布でもないので、
\mathbb{E}\left[\|\boldsymbol{y}\|^2\right] = \mathrm{Var}[\boldsymbol{y}] = {{\sigma}_{\mathrm{data}}}^2
とするのは正確でない。これは$\boldsymbol{y}$も$\boldsymbol{n}$のような$d$次元の正規分布のような分布と一旦仮定した大雑把な説明である。$w(\sigma)=1$というハイパーパラメータを設定する根拠としては十分と考えられる。
Reference
【EDM】Tero Karras, Miika Aittala, Timo Aila, Samuli Laine, "Elucidating the Design Space of Diffusion-Based Generative Models", 11 Oct 2022, https://arxiv.org/pdf/2206.00364
【Probabilistic Flow ODEを使ったスコア関数ベースのDiffusion model】 Yang Song, et al., "SCORE-BASED GENERATIVE MODELING THROUGH STOCHASTIC DIFFERENTIAL EQUATIONS", ICLR 2021, https://arxiv.org/pdf/2011.13456
【AlphaFold3】 Josh Abramson, et al., "Accurate structure prediction of biomolecular interactions with AlphaFold 3", Nature 630, 493–500 (2024), https://www.nature.com/articles/s41586-024-07487-w