3
5

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

拡散モデルを解説するッ!!!①

Posted at

はじめに

拡散モデルについて勉強したので,それについてのまとめです!(記事内で間違いあったら教えてもらえると助かります)
今回の記事はPFNの岡野原 大輔氏が書かれた拡散モデル データ生成技術の数理を参考にしています.
一応サポートページのリンクも貼っておきます(サポートページ

イントロ(ざっくり理解したい人はここ読めば十分です)

 あるデータ$\mathbf{x}$は真の確率分布

    p_{data}(\mathbf{x})

からサンプリングされているとします.もしこの分布の形がわかれば,その分布に基づいてサンプリングすることで,$\mathbf{x}'$が得られます.すなわちデータを生成できるということです.例えば,人間の顔画像の確率分布がわかっていればその分布からサンプリングするとランダムな顔の画像が生成されるという具合ですね.最近ではこのような画像生成を実際のアプリケーション上で使ったことがある方も多いかと思いますが,あれは生成したい画像をプロンプトで指示できるように条件付きの確率分布で生成モデルを定義しています.

 さてこれから拡散モデルの解説をしていこうと思っているのですが,まず全体の話の流れを簡単に説明しておきます.生成モデル$p_{data}(\mathbf{x})$をニューラルネットワークなどで近似したいんですが,直接近似するのは実は大変(素直な実装ではできるんですが,性能が良くない、、、).そこで,代わりにスコア(対数尤度の入力による微分)と呼ばれる関数を学習し,そのスコアを使うことで求めたい確率分布からのサンプルを得ます(具体的な話は後述します,とにかくスコアを使ってもほしい結果は得られるということが大事).しかし,実際はスコアを使ったサンプリングではあまり効率的にほしいサンプルが得られないことも知られています(なんやねん).なので,データ分布を攪乱した(ノイズを加えた)攪乱後分布のスコアを学習し,そのスコアを使ってサンプリングを行います.これでなぜ上手くいくのか,というのは後述するとして.さて,この攪乱する時は正規分布からのノイズをデータに加えていくのですが,何回もデータにそれを加えていくにつれて,その正規分布ノイズが支配的になるため拡散の最終ステップでは元がどんなデータであっても正規分布に落ち着きます.で,その逆拡散過程を学習する,すなわち正規分布からデータ分布に戻してあげる(デノイジングする)過程を学習すると,その結果として得られたモデルは正規分布から何かデータを生成するモデルだということですね.このようなモデルがスコアベースモデルとか拡散モデルとか言われるやつですね.
 で,両者は別々に生み出されたんですがJonathan Ho et al.,2020の論文で実は両者が同じものであることが示され,さらにシグナルノイズ比という考えを導入してDiederik Kingma et al.,2021で完全に統一化されました.そんな全体的な流れを踏まえつつ順に解説および実装して行こうかと思います.
 というわけで,第一回は背景の説明とSBMの説明をしてきます〜.

スコアを用いたサンプリングについて

 まず,マルコフ連鎖モンテカルロ(Markov Chain Monte Carlo,MCMC)を用いたサンプリングについて簡単に説明します.そも,マルコフ連鎖っていうのはある遷移確率に従って状態遷移させる確率過程のことで,現在の状態は一つ前の状態にのみ依存するという性質を持ちます.そして,ある定常分布$P$を持つマルコフ連鎖を上手く設計し,そのマルコフ連鎖にしたがって適当な初期値を状態遷移させます.すると適当な回数状態遷移させた後の値は,$P$からのサンプルと見做せるという性質があります.この手法を用いると,分配関数(確率分布を規格化するときの値で,積分する必要があります)の計算が困難なときでも求めたい確率分布からサンプリングできます(例えばMCMCの一つであるメトロポリス法を見てみると遷移確率が尤度比になっているので,分配関数を計算する必要がなくなっていることがわかります).しかし,局所解(確率の大きいところ)にトラップされたり,受容・拒否のステップの存在により更新に時間がかかることが知られています.そこで勾配情報(スコア)を用いたサンプリング手法であるランジュバン・モンテカルロ法

\mathbf{x}_t = \mathbf{x}_{t-1}+\epsilon\nabla_x\log p(\mathbf{x_{t-1}})+\sqrt{2\epsilon} \mathbf{u}_t

を用いてサンプリングを行います.サンプリングしたい分布の勾配などがわかっている必要がありますが,メトロポリス法よりも効率的にサンプリングできます.そこで,この勾配をニューラルネットワークで近似します.それをおこなっているのが,スコアマッチングです.

スコアマッチングについて

 この節では明示的スコアマッチング(ESM) と暗黙的スコアマッチング(ISM)とデノイジングスコアマッチング(DSM)について説明します.コスト関数を$J$,ニューラルネットワークのパラメータを$\mathbf{\theta}$とします.この時,それぞれのコスト関数は

\begin{align}
J_{\rm{ESM}}(\mathbf{\theta}) &= \frac{1}{2}\mathbb{E}_{p(\mathbf{x})}\left[ \left\| \nabla_\mathbf{x} \log{p(\mathbf{x})} - \mathbf{s}_{\theta}(\mathbf{x}) \right\|^2 \right] \\
J_{\rm{ISM}}(\mathbf{\theta}) &= \mathbb{E}_{p(\mathbf{x})}\left[ \frac{1}{2}\left\| \mathbf{s}_{\theta}(\mathbf{x}) \right\|^2 + tr\left( \nabla_x \mathbf{s}_{\theta}(\mathbf{x})\right)\right] \\
J_{\rm{DSM}}(\mathbf{\theta})&=\frac{1}{2}\mathbb{E}_{p_{\sigma}(\tilde{\mathbf{x}}|\mathbf{x})p(\mathbf{x})} \left[ \left\| \nabla_\mathbf{x} \log{p(\tilde{\mathbf{x}}|\mathbf{x})} - \mathbf{s}_{\theta}(\tilde{\mathbf{x}}, \sigma) \right\|^2 \right]\\
&= \frac{1}{2} \mathbb{E}_{\epsilon \sim \mathcal{N}(\mathbf{0}, \sigma^2 I),\mathbf{x} \sim p_{data}(\mathbf{x})}\left[ \left\| -\frac{1}{\sigma^2}\epsilon - \mathbf{s}_{\theta}(\tilde{\mathbf{x}}, \sigma) \right\|^2 \right]
\end{align}

です.ここで$\tilde{\mathbf{x}}$は$\mathbf{x}$にノイズを加えた攪乱後データです.証明は割愛しますが,これらの最適化の結果得られる関数は同じです(不思議だと思う方は証明を追ってみるとわかると思います,双対問題的な話ですね).なのでどのコスト関数を使ってもいいんですが。。。実際はそれじゃダメで.まず,$J_{\rm{ESM}}(\mathbf{\theta})$については$\nabla_{\mathbf{x}}\log{p(\mathbf{x})}$がわかっている時しか使えません.そして,$J_{\rm{ISM}}(\mathbf{\theta})$は$tr\left(\nabla_x\mathbf{s}_{\theta}(\mathbf{x})\right)$の計算コストが高く,非効率であることに加え過学習を起こしやすいです.一方で,$J_{\rm{DSM}}(\mathbf{\theta})$は$\frac{1}{\sigma^2}\epsilon$は計算できる量です.なので,$J_{\rm{DSM}}(\mathbf{\theta})$を使ってスコアを近似していきます.
 さて,突然攪乱後データを考えるってなんのこっちゃと思ったかもしれないですね.なのでこれの背景を簡単に説明すると,データ分布は多様体仮説に基づいて高次元では実際のデータが存在する場所(確率が大きい場所)はめっちゃ少ないという話がありまして(一応補足:厳密性は全くないのできちんと調べた方が良いです.この記事では多様体仮説の細かい部分まで踏み込みません,悪しからず).それが下の図を見てくれたらいいんですが,download.png
こんな感じになっています(という仮説ですが).で,ランジュバン・モンテカルロを用いてサンプリングするときには,異なるモード(図で見えるところの各確率の大きい山に見える部分)を遷移する必要があります(一つの山しか探索できていないと,例えば男性と女性のデータを学習させたにもかかわらずサンプルされる=生成されるデータがどちらか片方しかないというようなケースが起こり得ます).ですが,確率が高いところから低いところを通っていくというのが実は大変.なぜなら,アルゴリズム自体が確率が高い場所を効率的に探索していくアルゴリズムですからね.確率の低い部分を移動したければランジュバン・モンテカルロ法第3項にあるノイズの力のみをつかって確率の高い部分を脱出してあげる必要があります.これには無限回のサンプリングを行えば理論的に大丈夫,すなわちステップ数を莫大にすれば良いのですが,効率が悪いですよね.そこで次の図のようにfig1.1.png
確率の高い部分を少し慣らして,別のモードに行きやすいようにします.その元でスコアを学習すれば効率的にいけます.
この慣らしが,攪乱です.データに少しの量を加えてあげることで確率分布を広げてあげることができました.
この攪乱後分布内でランジュバン・モンテカルロ法に従って一定ステップ状態遷移→最後の状態を初期値としてノイズの小さい分布内で再び状態遷移→…という手順でサンプリングを行います.よって,攪乱後データを使ってスコアを近似してあげることで広い範囲を探索し,多様性のあるサンプリングを行うことができるようになるわけです.これと似たようなことをやっているのがSBMとかDDPMです.

SBMについて

先ほどの攪乱する,すなわちノイズを加えることを複数回行う(データをノイズによって拡散する)過程を考えます.

SBMの拡散過程

SBMの拡散過程は,

\sigma_{\rm{min}}=\sigma_1<\cdots<\sigma_T=\sigma_{\rm{max}}

を用いて,

\mathbf{x}_t=\mathbf{x}_{t-1}+\sigma_t \epsilon, \,\,\epsilon \sim\mathcal{N}(\mathbf{0}, I) 

によって定義される過程です(ただし$\mathbf{x}_0$は元のデータです).
このように段階的な攪乱によって,完全なノイズ($\mathbf{x}_T \sim \mathcal{N}(\mathbf{0},\sigma_T^2I)$)となります.
次のコードと結果をご覧くださいませませ.

以下が元データです.こいつにノイズを加えていきます.
download.png

# 各ステップで加えるノイズ
variances = np.array([0.01, 0.1, 0.2, 0.3, 1, 5, 25, 125, 500, 700, 1000, 1e5, 1e10])
x_t = train_dataset[0][0].squeeze().numpy().flatten()
t=5
for i in range(t):
    x_t = x_t + variances[i] * rng.standard_normal((28*28))
plt.imshow(x_t.reshape(28,28))
plt.title("t={}".format(t))
plt.show()

いかが結果(ノイズを加えた画像)です(途中まで)
).
download-1.png
download-2.png
download-3.png
download-4.png
download-5.png

こんな感じでノイズを加える過程をSBMによる拡散過程とします.

攪乱後分布上でのスコアを考える

SBMはスコア(対数尤度の入力による微分)を推定するモデルで,

\mathbf{s}_{\theta}(\tilde{\mathbf{x}}, \sigma_t)

のようにモデル化されます.
このモデルは,ある時点での位置座標と分散を引数にとってその位置での対数尤度の微分を返す関数です.
また,コスト関数は

\sum_{t=1}^T w_t \mathbb{E}_{x \sim p_{data}(\mathbf{x}),\tilde{\mathbf{x}}\sim\mathcal{N}(\mathbf{x},\sigma_t^2I)} \left[\left\| \frac{\mathbf{x} - \tilde{\mathbf{x}}}{\sigma_t^2} - \mathbf{s}_{\theta}(\tilde{\mathbf{x}}, \sigma_t)
 \right\|^2 \right]

です.これを最適化して得られたスコアを用いてLangevin Monte Carloでサンプリングすれば近似的に$p_{data}(\mathbf{x})$からのサンプリングが得られます(すなわち生成).
なかなか香ばしい式だとは思いますが,実装上は$M$をデータ数として

\begin{align}
J(\theta) &= \sum_{t=1}^T w_t \mathbb{E}_{x \sim p_{data}(\mathbf{x}),\tilde{\mathbf{x}}\sim\mathcal{N}(\mathbf{x},\sigma_t^2I)} \left[\left\| \frac{\mathbf{x} - \tilde{\mathbf{x}}}{\sigma_t^2} - \mathbf{s}_{\theta}(\tilde{\mathbf{x}}, \sigma_t)
 \right\|^2 \right] \\
 &\simeq \sum_{t=1}^T w_t \left[\frac{1}{M} \sum_{m=1}^M \left\| \frac{\mathbf{x}_m - \tilde{\mathbf{x}}_m}{\sigma_t^2} - \mathbf{s}_{\theta}(\tilde{\mathbf{x}}_m, \sigma_t) \right\|^2 \right]
\end{align}

とモンテカルロ平均に近似し,この時$\tilde{\mathbf{x}}_m=\mathbf{x}_m+\sigma_t\epsilon$とサンプリングしてあげればいけると思います(ちゃんと調べるべきですね。笑)
ともあれ,これでスコアを近似できます.

SBMの実装

まずはニューラルネットワークの定義から行います.
今回は,攪乱後データ+その拡散時の分散なので,データの次元+1次元が入力となります.
さらに全結合型で実装していますが,扱ったデータが画像なのでCNNやResNetを用いた方が精度は上がるのかなと思っています.

class ScoreBaseNet(nn.Module):
    def __init__(self, in_dim):
        super().__init__()
        out_dim = in_dim - 1
        self.layer1 = nn.Linear(in_dim, 512)
        self.layer2 = nn.Linear(512, 256)
        self.layer3 = nn.Linear(256, 128)
        self.layer4 = nn.Linear(128, 64)
        self.output = nn.Linear(64, out_dim)

    def forward(self, x):
        x = F.relu(self.layer1(x))
        x = F.relu(self.layer2(x))
        x = F.relu(self.layer3(x))
        x = F.relu(self.layer4(x))
        return self.output(x)

そして肝心の学習ですが,こんな感じ.今回はMNISTのデータセットを使っています.

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)
epochs = 100
in_dim = 784+1
T = 100
sigma_0 = torch.randn(1)
r = 1.3
T = 20
sigmas = [sigma_0 * (r ** t) for t in range(T)]
sigmas = torch.tensor(sigmas)

score = ScoreBaseNet(in_dim)
optimizer = torch.optim.Adam(score.parameters(), lr=1e-4)
for epoch in range(epochs):
    for x, _ in train_loader:
        x = x.view(x.size(0), -1).to(device)
        loss = 0
        for t in range(T):
            w_t = sigmas[t]**2
            for k in range(K):
                x_tilde = x + torch.randn_like(x) * sigmas[t]
                sigma_t = torch.ones(batch_size) * sigmas[t]
                target = - (x - x_tilde)/(sigmas[t]**2)
            
                input = torch.cat((x_tilde, sigma_t.view(-1, 1)),dim=1)
                s_theta = score(input)
            loss += w_t * dsm(s_theta, target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    print(f"epoch={epoch+1}, loss={loss.item()}")
        
torch.save(score, 'scorebasemodel.pth')

現状あんまり精度が出てないのでもしコードに間違いとかあれば教えてほしいです.
また,質問などあればお答えするのでぜひ.
次は拡散確率モデルから雑音除去拡散確率モデルをやっていきます.

参考文献

[1]岡野原 大輔,"拡散モデル データ生成技術の数理", 岩波書店
[2] Yang Song and Stefano Ermon, "Generative Modeling by Estimating Gradients of the Data Distribution", NeurIPS 2019
[3] Jascha Sohl-Dickstein et al.,"Deep Unsupervised Learning using Nonequilibrium Thermodynamics", ICML 2015
[4]Jonathan Ho et al.,"Denoising Diffusion Probabilistic Models", NeurIPS 2020
[5]Diederik Kingma et al.,"Variational Diffusion Models", NeurIPS 2021

3
5
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
3
5

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?