拡散モデルの最新の考え方についてまとめた記事になります。
入門①~④を書いた後にこの記事は書いており、入門①~③をまとめたような記事になっています。
この記事を読めば①~③は読む必要がなかったり…
・拡散モデル関係の記事
入門①:DDPMの理論とMNISTの実装
入門②:SDE/ODEの基礎理論(Tensorflow実装付き)
入門③:EDMの解説とMNISTの実装
入門④:条件付きU-Net(MNIST実装付き)
応用編:拡散モデルにGRPOを使ってファインチューニングしてみた
応用を作っていたら拡散モデルにもっと踏み込まないと理解できない部分があったので書きました。
この記事は拡散モデルを理解するところに重点を置いています。
なので曖昧な表現がある点はご了承ください。
TL;DR
この記事のまとめです。
・拡散モデルはウィーナー過程であり、SDE(PF-ODE)で表現できた
・これによりODEの強力な理論体系を使うことができ、いろいろ改善点が見えてきた
・改善点を元にいろいろ改善した
この記事を読むとこれが分かるようになります。
コード
今回はGoogleColabにも実装しました。
(学習を簡単にするために0のみ学習しています)
事前知識
拡散モデルを理解するための事前知識です。
これがなかったので私は苦労しました…
知っている人は読み飛ばしてください。
ODEソルバー(オイラー法)
数値計算とは、数学的な問題をコンピュータを用いて近似的に解く計算手法のことです。
その中でも、常微分方程式(ODE)を数値的に解く手法を ODEソルバー と呼びます。
本記事では、ODEソルバーの中で最も基本的な手法である オイラー法 について見ていきます。
数値計算やオイラー法についてはヨビノリたくみさんの以下の動画がとても参考になったので是非見てみてください。
(Youtube)数値計算の基本(微分方程式の扱い) (予備校のノリで学ぶ「大学の数学・物理」)
EDMではODEソルバーとしてルンゲ・クッタ法が使われていますが、これはオイラー法を改良した手法です。
本記事では理解を優先するためにオイラー法のみに注目します。
初期値問題と常微分方程式
ある関数$y$を$x$で微分した以下の式が与えられたとします。
$$
\frac{dy}{dx} = 2y
$$
この時に $y(1)=1$ の時、元の関数 $y(x)$ を求めよ、というような問題が初期値問題です。
この例では厳密解が計算できるので計算してみます。(出来ない場合もあり、拡散モデルでは出来ない場合のほうがメインになります)
計算過程
変形分離法より、$p(x)=2$、$q(y)=y$ とすると以下に変形できます。
\begin{align}
\int \frac{1}{q(y)}dy=&\int p(x)dx (変形分離法)\\
\int \frac{1}{y}dy=&\int 2dx \\
\ln |y| =& 2x + C \\
\end{align}
$$
y = e^{2x+C}
$$
積分によって積分定数$C$が現れるため、その値を求めます。
条件 $y(1)=1$ より、$C=-2$ になり以下が厳密解となります。
$$
y = e^{2x-2}
$$
ここで初期値問題として最初に登場した方程式は、常微分方程式(ODE; Ordinary Differential Equation)といい、一般的には以下の形式で表されます。
$$
\frac{dy}{dx} = f(y, x)
$$
・参考
https://ja.wikipedia.org/wiki/%E5%A4%89%E6%95%B0%E5%88%86%E9%9B%A2
https://ja.wikipedia.org/wiki/%E5%88%9D%E6%9C%9F%E5%80%A4%E5%95%8F%E9%A1%8C
変数分離形の微分方程式の解法と例題(高校数学の美しい物語)
オイラー法
次に厳密解ではなく、コンピューターで近似値を出す事を考えます。
まず初期値問題として、ある未知の関数$x$を$t$で微分した関数 $x'(t)$ が与えられているとし、初期値を $x(t_0) = x_0$ とします。
この時の $x(t_1) = x_1$ を予測してみます。
微分関数が与えられているので$t_0$時点での傾きを得ることができます。
その傾きを元に間隔$h$進んだ$t_1$での$x_1$を予測できます。
\begin{align}
t_1 =& t_0 + h \\
x_1 =& x_0 + h x'(t_0) \\
\end{align}
$t_0$,$x_0$を元に$x_1$の予測ができました。
次は$t_1$,$x_1$を元に$x_2$を予測・・・、と予測を繰り返していく手法がオイラー法です。
図は間隔が広いので誤差が大きいですが、間隔を狭くすれば未知の関数を再現できそうなことが分かるかと思います。
最後に数式を整理します。
$x'(t)$ を変数$x$と$t$からなる関数であることを明示的にするために、$x'(t)=f(x,t)$ と書きます。
($f$は微分された傾きを表す関数である点に注意)
時刻$t$のインデックスを$i$とした場合以下となります。
\begin{align}
t_{i+1} =& t_i + h \\
x_{i+1} =& x_i + h f(x_i,t_i) \\
\end{align}
・参考
オイラー法をわかりやすく解説(高校数学の美しい物語)
後退オイラー法
上記オイラー法は前進オイラー法ともいい、現在の値を使って次の値を直接計算する方法です。
一方、後退オイラー法は、次の値を含む方程式を解いて求める方法となります。
\begin{align}
t_{i+1} =& t_i + h \\
x_{i+1} =& x_i + h f(x_{\color{red}{i+1}}, t_{\color{red}{i+1}}) \\
\end{align}
もし逆方向の傾きが分かるなら以下のように簡単に逆算できます。
\begin{align}
t_i =& t_{i+1} - h \\
x_i =& x_{i+1} - h f(x_{i+1}, t_{i+1}) \\
\end{align}
確率微分方程式(SDE; Stochastic Differential Equations)
ここではSDEがODEで表せることが分かる事を目標に解説します。
確率過程を含んだ微分方程式を確率微分方程式といい、一般的には以下の形で表されます。
$$
dX_t = f(X_t, t) dt + g(X_t, t) dW_t
$$
この式を理解していきます。
Xt
まず$X$ですがこれは確率変数を指します。
確率変数は例えばサイコロを考えた場合、各出目のことです。
サイコロの確率を $P(X)$ と置いた場合、$P(2)=\frac{1}{6}$ (2が出る場合の確率)のように確率が決まる値 $X$ を確率変数と言います。
また$X_t$は確率過程を指し、確率変数が時間とともに変化する事を指します。
サイコロの例だと時間と共にサイコロが6面ダイス→8面ダイス→16面ダイスみたいに変化して出せる目が変わっていく感じですかね。
もし確率過程 $X_t$ が確定的なら、それは常微分方程式と同じ $X_t = x$ と見ることができます。
この記事では $x$ を確率過程と意識して扱わなければならない場面はないため、単に $x$ と見ても問題ないかと思います。
ランダムウォーク
ウィーナー過程の前にランダムウォークについて話します。
アイシアさんの以下の動画がとても分かりやすく参考にしています。
(Youtube)【ブラックショールズ方程式への道①】ランダムウォークとブラウン運動【確率微分方程式の基礎】 #VRアカデミア #039
ランダムウォークの最も単純な例として、ある点が時間とともに確率50%で上か下に動く場合を考えます。
実際に20step動かしてみた図が以下です。
この動きには以下の性質があります。(各ステップで移動する量を $\sigma$ と置く)
傾き(微分):$\pm \sigma$
平均:0(期待値が0なので)
分散:$\sigma^2 + \sigma^2 + ... \sigma^2 = \sigma^2t$(時間経過と共に分散が増えていきます)
これを踏まえてウィーナー過程を見てみます。
ウィーナー過程(ブラウン運動)
元々はブラウン運動と呼ばれる物理現象があり、それを数学的にモデル化したものがウィーナー過程となります。
ランダムウォークでは時間の間隔が離散でした。
ウィーナー過程は上記ランダムウォークの時間間隔を極限まで短くして連続的にしたものになります。
これを計算すると1step毎にガウス分布の乱数で動くグラフになるようです。
※ウィーナー過程は連続値ですが、図は任意の間隔で切り取って離散化しています。
ウィーナー過程の性質は以下です。
傾き(微分):なし(時間を短くするほど変化が激しくなる性質があり、瞬間的な傾きを求めると値が発散する)
平均:0
分散:$\sqrt{t-s}$($t-s$は切り取った間隔)
ここではガウス分布の乱数で動くという事が重要となります。
ちなみにステップ数を増やすと以下みたいな感じになります。(後で似た図が出てきたり)
常微分方程式におけるdtの変形
SDEではウィーナー過程があるので微分 $\frac{dx}{dt}$ が扱えません。
なので常微分方程式を変形しておきます。
\begin{align}
\frac{dx}{dt} &= f(x, t) \\
dx &= f(x, t)dt \\
\end{align}
問題はこの意味ですが、まず微分 $\frac{dx}{dt}$ について $dx$ はxの変化量、$dt$ は$t$の変化量と解釈できます。
日本語で表すと以下です。
\begin{align}
\frac{xの変化量}{tの変化量} &= 傾き \\
xの変化量 &= 傾き×tの変化量 \\
\end{align}
図だとこんなイメージ。
$f(x, t)dt$ は$x$の変化量となります。
dWtのd
$W_t$はウィーナー過程を指します。
$d$は微分で見られる変化量を表す意味の$d$となり、$dW_t$ はウィーナー過程の変化量を表します。
拡散係数 g(Xt,t)
拡散係数はウィーナー過程のランダム要素を制御する項となります。
最も単純な拡散係数は $g(X_t, t)=\sigma$ でこの場合は標準ブラウン運動と呼ぶようです。
SDEまとめ
\begin{align}
dX_t &= f(X_t, t) dt + g(X_t, t) d W_t \\
xの変化量 &= (傾き×tの変化量)+ 拡散係数×ウィーナー過程の変化量 \\
\end{align}
$f(X_t, t) dt$ はドリフト項、$g(X_t, t) d W_t$ は拡散項とも言います。
オイラー法の図で見ると以下のイメージです。
ランダム性があるODE(PF-ODE; Probability Flow ODE)と見ることができますね。
・参考
初期値問題(Wikipedia)
確率微分方程式(Wikipedia)
ウィーナー過程(Wikipedia)
導関数 dy/dx を分数扱いする(Qiita)
【ブラウン運動の数理】ランダムウォークからブラウン運動へ(ケィオスの時系列解析メモランダム)
拡散モデルの確率微分方程式(henatips)
拡散モデルの謎の解明
NVIDIAが公開している上記ブログをメインに拡散モデルについて見ていきます。
(翻訳ではなく解釈した上で独自に書いているので注意)
一応上記ブログと関係のあるEDMの論文も張っておきます。
https://arxiv.org/abs/2206.00364
拡散モデルとウィーナー過程
今までの説明でSDEが時間方向にガウスノイズが追加されていく様子(ブラウン運動)を表している事が分かったかと思います。
これを拡散モデルで表すと以下です。
図はブログより、徐々にガウスノイズを追加していく様子(ランダムウォーク)
図でオレンジの線は画像のある1ドットが実際に取った値です。
この軌跡を複数の画像で見てみると以下のようになります。
図はブログより、画像データのランダムな軌道が分布の密度を形成する様子
ガウス分布っぽい色の濃淡が出てきました。
図の左端の特徴的なデータ(例えば猫とか犬など)は徐々に混ざり合っていき、右端で特徴のないただのノイズ(ガウス分布)に収束しています。
今後の説明のためにオイラー法による対応も書いておきます。(あまり自信がないので私の認識ということで…)
未知の関数ですが、拡散モデルにおけるSDEでは決定的な部分がガウスノイズの平均になるので、ガウスノイズの平均を表す関数になります。
次に画像生成ですが、画像生成では右端から左端の点にちゃんと移動できれば画像が生成できることになります。
これは可能でしょうか?
ブログではSDEを使うと、時間の進行方向を反転させることができ、その結果としてデータを引き寄せるための追加の項(力)が自然に現れるということでした。
話を簡単にするために逆方向は決定的として考えます。(拡散項を除いてドリフト項のみを仮定する)
その場合は乱数がないのでなめらかな曲線で元の画像を生成できます。
さて、ここで曲線の曲がりが大きいと精度が落ちることが予想されます。
図はブログより、直線が曲線の近似値としては不十分になる可能性を表した図
オイラー法では傾きで次の地点を予測するので曲がっていると誤差がひどくなります。
これを改善する手法をブログでは3つ挙げています。
- 曲線を直線化する
- 離散化する間隔を検討する
- 高次元のソルバー(オイラー法ではなくルンゲ=クッタ法など)を検討する
高次元ソルバーについては本記事では省略し、上2つを見ていきます。
曲線を直線化する
オイラー法で分かる通り、求めたい関数が直線なら話は簡単です。
これはノイズのスケジュールで調整できます。
EDMでは直線に調整することで生成ステップの回数を大幅に減らす事ができます。
離散化する間隔の検討
傾きをなるべく直線にしましたが左側に近づくにつれて曲線がどうしても現れます。
図はブログより、右から左にかけて時間ステップを短くした方がいい様子
画像の例のように右側から左側にかけて山が1つから2つに変化する場合、その山に沿って移動する必要があるのでどうしても曲がりが発生してしまいます。
逆に言えば右側ではほぼ直線なので速度を上げて移動できます。
左側に近づくほどゆっくり進む事で移動速度を改善します。
EDMでは以下のようにステップ間隔を調整しています。
$$
t_i = \left( \sigma_{\text{max}}^{\frac{1}{\rho}} + \frac{i}{N-1} \left( \sigma_{\text{min}}^{\frac{1}{\rho}} - \sigma_{\text{max}}^{\frac{1}{\rho}} \right) \right)^{\rho}
$$
def create_timesptes(N: int, sigma_min=0.002, sigma_max=80, rho=7):
i = np.arange(N)
t = (sigma_max ** (1 / rho) + i / (N - 1) * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho
t = np.append(t, 0) # 最後に0を追加
return t
学習
最もシンプルに学習コードを書くと以下です。
for clean_image in training_data:
# ノイズレベルをランダムに選択
sigma = np.random.uniform(0, 80)
# ノイズ入り画像を作成
noisy_image = clean_image + sigma * torch.randn_like(clean_image)
# ノイズ入り画像+ノイズレベルより、ノイズ除去画像を予測
denoised_image = denoise(noisy_image, sigma)
# ノイズ除去画像が元の画像に近づくように学習
loss = (denoised_image - clean_image).square().sum()
オイラー法の図では以下です。
モデルは最終的にガウスノイズの平均を学習します。
実際に学習済みモデルからdenoiseされた画像を見てみると、ノイズレベルが大きいほどぼやけた(平均的な)画像が生成されるのを確認できます。
以降はこの学習方法の改善内容です。
ネットワークの学習と改善
一般的な拡散モデルは以下のように画像を予測します。
- ネットワークがノイズ入り画像からノイズを予測する
- 予測したノイズを現ステップのノイズレベルにスケールする
- ノイズ入り画像から現ステップのノイズを引くことでノイズを除去する
式だと以下のイメージです。
denoised_image = noisy_image - sigma * network_output
これはノイズレベルが低い場合、ネットワークは少しの変化だけを学習すればよく、またネットワークの予測が間違っていてもノイズレベルのスケールによりその影響が小さくなるので良いアプローチです。
しかしノイズレベルが高い場合、入力画像から本来の画像情報がほぼ失われ、ネットワークの予測誤差が大きくなります。
さらにノイズレベルのスケールで誤差がより大きくなるので悪いアプローチです。
この問題を解決するために入力画像とネットワークの出力を混ぜて学習します。
- ネットワークに「ノイズと綺麗な画像の混合状態」を学習させる
- 低ノイズのときは「入力画像をより多く活用」し、高ノイズのときは「ネットワークの出力に頼る」ようにする
denoised_image = c_skip * noisy_image + c_out * network_output
$c_{skip}$ は入力画像の再利用量を調整する係数で、ノイズが少ないときには大きく、ノイズが多い時には小さくします。
$c_{out}$ はネットワークの出力を調整する係数で、ノイズが少ないときには小さく、ノイズが多い時には大きくします。
具体的には以下です。
$$
c_{skip} = \frac{ \sigma_{\text{data}}^{2} }{\sigma^2 + \sigma_{\text{data}}^{2}}
$$
$$
c_{out} = \sigma \frac{\sigma_{\text{data}}}{ \sqrt{\sigma^2 + \sigma_{\text{data}}^{2}}}
$$
sigma_data = 0.5 # 学習データの標準偏差(ImageNetのおおよその標準偏差が0.5)
c_skip = sigma_data**2 / (sigma**2 + sigma_data**2)
c_out = sigma * sigma_data / tf.sqrt(sigma**2 + sigma_data**2)
sigma_data
はImageNetのおおよその標準偏差らしいです。
MNISTだともうちょっと高く0.6ぐらいでした。
from tensorflow import keras
(x_train, y_train), (_, _) = keras.datasets.mnist.load_data()
x_train = (x_train / 255.0) * 2 - 1
print(np.std(x_train)) # 0.6162156077129245
入力の正規化
入力ですが、そのままだと分散80^2というとても大きな値のデータが来て学習が不安定になるのでスケールさせます。
また同様にノイズも-1から1の範囲で均等にばらけるように対数でスケールさせます。
sigma_data = 0.5 # 学習データの標準偏差
def denoise(noisy_image, sigma):
# ノイズ後の画像の分散 = ノイズの分散 + 元データの分散
noisy_image_variance = sigma**2 + sigma_data**2
# 標準偏差1に正規化
scaled_noisy_image = (noisy_image / noisy_image_variance) ** 0.5
# ノイズのスケール: ln(σ)/4
c_noise = np.log(sigma) / 4
return net(scaled_noisy_image, c_noise)
学習の分散と損失の正規化
学習ですが、重要な点を重点的に学習をすることは学習コストを下げることに繋がるので重要です。
拡散モデルでは損失のスケール(図の真ん中)と、サンプリングの場所(図の右側)が重要になります。
まず損失のスケールですが、損失の大きさ(矢印の長さ)と頻度(矢印の数)はノイズレベルに依存します。
EDMでは以下の式で標準化して揃えています。
$$
\lambda(\sigma) = \frac{\sigma^2 + \sigma_{\text{data}}^2}{\sigma \cdot \sigma_{\text{data}}^{2}}
$$
次にサンプリングの場所(ノイズレベル)ですが、ノイズレベルが低い場所はノイズの予測ができないのであまり有意義ではありません。
逆にノイズレベルが高すぎると復元できる情報がほとんどないのでこれも有意義ではありません。
中間あたりのノイズレベルが学習に最適となります。
具体的な計算方法ですが、経験則によるところが大きいようです。
EDMではガウス分布の乱数を対数にして使っています。
P_mean = -1.2
P_std = 1.2
sigma = torch.exp(P_mean + P_std * torch.randn([]))
学習まとめ
sigma_data = 0.5 # 学習データの標準偏差
def denoise(noisy_image, sigma):
# ノイズ画像を正規化
c_in = 1 / torch.sqrt(sigma_data**2 + sigma**2)
scaled_noisy_image = c_in * noisy_image
# sigmaを正規化
c_noise = torch.log(sigma) / 4
# 入力画像とネットワーク出力を元にノイズ除去画像を生成
c_out = sigma * sigma_data / torch.sqrt(sigma**2 + sigma_data**2)
c_skip = sigma_data**2 / (sigma**2 + sigma_data**2)
network_output = net(scaled_noisy_image, c_noise)
return c_skip * noisy_image + c_out * network_output
for clean_image in training_data:
# ノイズレベルをサンプリング
P_mean = -1.2
P_std = 1.2
sigma = torch.exp(P_mean + P_std * torch.randn([]))
# ノイズ入り画像を生成
noisy_image = clean_image + sigma * torch.randn_like(clean_image)
# ノイズ除去画像を作成
denoised_image = denoise(noisy_image, sigma)
# 損失を計算(ノイズ除去画像が元の画像になるように学習)
weight = (sigma**2 + sigma_data**2) / (sigma * sigma_data)**2
loss = weight * (denoised_image - clean_image).square().sum()
# ... lossを元に誤差逆伝播でモデルをトレーニング
生成
最後に生成です。
(後退)オイラー法をそのまま使う方法が一番シンプルです。
num_steps = 50
timesteps = create_timesptes(num_steps) # 80~0のノイズレベルを生成
x = torch.randn(img_shape) * timesteps[0] # 最初のノイズレベルで画像を生成
for t_curr in timesteps[:-1]:
x = denoise(x, t_curr)
図示すると以下です。
・実際の生成過程
このオイラー法は決定的なので次に乱数を入れた生成方法を紹介します。
拡散項を追加するだけですね。
num_steps = 50
timesteps = create_timesptes(num_steps) # 80~0のノイズレベルを生成
x = torch.randn(img_shape) * timesteps[0] # 最初のノイズレベルで画像を生成
for t_curr, t_next in zip(timesteps[:-1], timesteps[1:]):
e = torch.randn(img_shape) * t_curr # 拡散項
x = denoise(x, t_curr) + e
・実際の生成過程(上がdenoise後の画像、下がノイズ追加後の画像)
最後にプログで書かれているアルゴリズムです。
(ブログではこれは非常に単純なアルゴリズムなので拡散モデルが出てきた当初(2015年)になんで出てこなかったのかが不思議と書かれていました)
num_steps = 50
timesteps = create_timesptes(num_steps) # 80~0のノイズレベルを生成
x = torch.randn(img_shape) * timesteps[0] # 最初のノイズレベルで画像を生成
for t_curr, t_next in zip(timesteps[:-1], timesteps[1:]):
blend = t_next / t_curr # ブレンドする割合を計算
# ブレンド
x = blend * x + (1-blend) * denoise(x, t_curr)
図示すると以下です。
・実際の生成過程(上がdenoise後の画像、下がブレンド画像)
同じモデルでも生成方法で結構変わりますね。
GoogleColabではEDMに書いてある2次ルンゲ・クッタ法も実装しています。
おわりに
拡散モデルの理論体系をまとめたと言ってるだけあってやっとEDMの凄さが分かりました。
ただ難しい…
素人が書いているので間違っていたらすいません。
細かい間違いはスルーしていただければと思いますが、致命的な間違いがあれば教えてもらえると助かります。