2023年2月に発売されたPFN岡野原さんの著書「拡散モデルデータ生成技術の数理」を実装しながら解説します。
前回記事では、拡散モデル本の1.5.1で解説されているランジュバン・モンテカルロ法を実装し、単一の二次元正規分布と、複数の二次元正規分布の混合分布に対してサンプリングを行いました。
今回は、拡散モデル本の1.5.5で解説されているデノイジングスコアマッチングについてPytorchを使用して実装します。
今回実装したコードはこちらからも参照できます。
スコアベースモデルとは?
前回記事で解説したランジュバン・モンテカルロ法では、$p(\mathbf{x})$ の対数尤度の微分であるスコア $s(\mathbf{x}) (= \nabla_\mathbf{x} \log p(\mathbf{x}))$ を用いて、ランダムに初期化されたサンプリング結果 $\mathbf{x}_0$ を徐々に尤度の高いものに更新していき、最終的に $p(\mathbf{x})$ に近い確率分布からサンプリングを得ることができました。
このことから、確率分布 $p(\mathbf{x})$ が未知であってもそのスコアさえわかっていれば、$p(\mathbf{x})$ から効率的にサンプリングを行うことができます。このように、確率分布 $p(\mathbf{x})$ のスコアを学習して、それをサンプリングに応用するモデルをスコアベースモデルと呼びます。
あるパラメータ $\theta$ で構築されるモデルを使って、上記のスコアを学習する際、最小化する目的関数は単純に平均二乗誤差を用いると以下のようになり、これを明示的スコアマッチングと呼びます。
$$
J_{ESM_p}(\theta) = \frac{1}{2} E_{p(\mathbf{x})}[| \nabla_\mathbf{x} \log p(\mathbf{x}) - s_{\theta}(\mathbf{x}) |^2]
$$
しかしながら、一般的に未知の確率分布 $p(\mathbf{x})$ のスコア $\nabla_\mathbf{x} \log p(\mathbf{x})$ は未知であるため、上記の目的関数を用いてスコアを推定するモデルの学習は行うことができません。
そこで、以下の式で記述される暗黙的スコアマッチングと呼ばれる目的関数を用いてスコアを推定するモデルを学習します。
$$
J_{ISM_p}(\theta) = E_{p(\mathbf{x})}[ \frac{1}{2} | s_{\theta}(\mathbf{x}) |^2 + tr(\nabla_x s_{\theta}(\mathbf{x}))]
$$
詳細な証明は拡散モデル本の1.5.4に示されていますが、この二つの関数には、以下の式のような関係が成立し、暗黙的スコアマッチングを最小化するパラメータ $\theta$ は明示的スコアマッチングを最小化する $\theta$ と一致します。ただし、$C_1$ はパラメータ $\theta$ に依存しない定数を示します。
$$
J_{ESM_p}(\theta) = J_{ISM_p}(\theta) + C_1
$$
デノイジングスコアマッチングとは?
暗黙的スコアマッチングを用いれば、明示的スコアマッチングを最小化するスコアベースモデルを学習できることがわかりましたが、この学習方法にも以下の問題点が存在します。
- $E_{p(\mathbf{x})}[tr(\nabla_x s_{\theta}(\mathbf{x}))]$ の計算量が大きい
- データ $\mathbf{x}$ に対して過学習が起こりやすい
そこで、デノイジング・スコアマッチングを使って上記二点の問題を解決します。
デノイジング・スコアマッチングでは、データ $\mathbf{x}$ に対し正規分布からのノイズ $\epsilon \sim N(0,\sigma^2\mathbf{I})$ を加えた $\mathbf{\tilde{x}}$ を学習に用いることで、モデルがデータ $\mathbf{x}$ に過剰にフィッティングすることを抑制します。さらに、暗黙的スコアマッチングを以下のようにデノイジングスコアマッチング関数に修正することで、計算量の問題点も解決します。
$$
J_{DSM_{p_{\sigma}}}(\theta) = \frac{1}{2} E_{\epsilon \sim N(\mathbf{0},\sigma^2 \mathbf{I}), x \sim p(\mathbf{x})}[ | - \frac{1}{\sigma^2} \epsilon + s_{\theta}(\mathbf{x} + \epsilon, \sigma) |^2 ]
$$
つまり、上式を通してモデルはスコアを予測するために、データ $\mathbf{x}$ に加えられたノイズ $\epsilon$ を推定することになります。
実装: デノイジングスコアマッチングによるサンプリング
スコアベースモデルの学習
実際にPythonでデノイジングスコアマッチングを実装したコードがこちらです。
まずは、ライブラリのインポートです。今回は乱数を生成するので、プログラム実行時に同じ結果となるように乱数シードを固定しています。
import os
import sys
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from tqdm.notebook import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions.multivariate_normal import MultivariateNormal
from torch.distributions import MixtureSameFamily, Categorical
# 乱数シードを固定
torch.manual_seed(1234)
次に、サンプリングを行いたい確率分布を定義します。
今回は、2種類の二次元正規分布を混合した比較的シンプルな確率分布を作成します。
# 平均ベクトル
means = torch.tensor([[2.0, 2.0], [-2.0, -2.0]])
# 共分散行列
covs = torch.Tensor([
[[ 1.0, 0.0],
[ 0.0, 1.0]],
[[ 1.0, 0.0],
[ 0.0, 1.0]],
])
# 混合係数
mixture_weights = torch.tensor([0.5, 0.5])
# 混合正規分布を作成
dist = MixtureSameFamily(
Categorical(mixture_weights),
MultivariateNormal(means, covs)
)
# 二次元の格子点の座標を作成
ls = np.linspace(-5, 5, 1000)
x, y = np.meshgrid(ls, ls)
point = torch.tensor(np.vstack([x.flatten(), y.flatten()]).T)
# 格子点の座標における尤度を算出
p = torch.exp(dist.log_prob(point))
# 混合分布を可視化
plt.title('2-dim mixture normal distribution')
plt.pcolormesh(x, y, p.reshape(x.shape), cmap='viridis')
plt.gca().set_aspect('equal', adjustable='box')
plt.xlim([-5, 5])
plt.ylim([-5, 5])
plt.colorbar()
plt.show()
次に、定義した確率分布のスコアを学習するモデルを定義します。
今回は、シンプルな4層の全結合層からなるニューラルネットワークを定義します。また、確率分布からサンプリングしたデータに基づいて、モデルを学習する関数も定義します。
# 学習するニューラルネットワークモデルを定義
# 今回はシンプルな4層の全結合層からなるモデルを定義
class ScoreBaseModel(nn.Module):
def __init__(self, input_dim, mid_dim=64):
super().__init__()
self.input_dim = input_dim
self.mid_dim = mid_dim
self.fc1 = nn.Linear(input_dim+1, mid_dim)
self.fc2 = nn.Linear(mid_dim, mid_dim//2)
self.fc3 = nn.Linear(mid_dim//2, mid_dim//4)
self.fc4 = nn.Linear(mid_dim//4, input_dim)
return
def forward(self, x):
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = F.relu(self.fc3(x))
x = self.fc4(x)
return x
# 学習用の関数を定義
def train(batch_size, num_epoch, dist, model, optimizer, criterion):
# EpochごとのLossを保存するリスト
out_list = []
# モデルを学習モードに変更
model.train()
# 以下num_epoch回数分の学習を実行
for epoch in tqdm(range(num_epoch)):
# 確率分布からBatch Size分だけサンプリング
sample = dist.sample((batch_size,))
# サンプリング結果に付加するするノイズの標準偏差を一様分布から取得(0~10の範囲)
sigma = torch.rand(1) * 10
# サンプリング結果に付加するノイズを平均0、標準偏差Sigmaで生成
noise = torch.normal(mean=0, std=sigma.item(), size=sample.shape)
# サンプリング結果にノイズ付加
noise_sample = sample + noise
# モデルの予測対象を計算
true_y = - noise / sigma / sigma
# Sigmaをモデルに入力できるように(Batch Size, 1)の形状に変換
batch_sigma = torch.ones((batch_size, 1)) * sigma
# ノイズ付加後のサンプリング結果とSigmaをConcatで結合
x = torch.concat([noise_sample, batch_sigma], axis=1)
# それらをモデルに入力して予測を実行
pred_y = model(x)
# 勾配情報を初期化
optimizer.zero_grad()
# Lossの計算
# Lossの水準を一定にするため、Sigmaの二乗で乗算
loss = criterion(true_y, pred_y) * sigma * sigma
# 誤差逆伝播法で勾配計算
loss.backward()
# 計算された勾配に基づいてモデルパラメータの更新
optimizer.step()
# EpochごとのLossをリストに追加
out_list.append([epoch, loss.item()])
df_res = pd.DataFrame(out_list, columns=['epoch','loss'])
return df_res
それでは、以下のコードで実際にスコアベースモデルを学習していきます。
学習時のパラメータはコードに記載の通りで、最適化手法にAdam、損失関数に平均二乗誤差を使用しています。
学習終了後、損失の推移を可視化してみるとおよそ100エポックほどで損失は小さくなり、その後は0~1の範囲を反復して横ばいとなりました。
やはり推定対象そのものがノイズなので、完璧な予測はできず、一定のノイズ成分だけ損失値がブレ続けるのだと思われます。
# 学習時のパラメータを設定
input_dim = 2
lr = 0.001
batch_size = 2000
num_epoch = 10000
# モデル作成
model = ScoreBaseModel(input_dim)
# モデルパラメータの最適化手法を決定
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
# Loss関数を決定
criterion = torch.nn.MSELoss()
# 学習の実行
df_res = train(batch_size, num_epoch, dist, model, optimizer, criterion)
# Lossの推移を可視化
df_res['loss'].plot()
plt.ylabel('Loss')
plt.xlabel('Epoch')
学習されたスコアの可視化
以下のコードでは、先ほど学習したスコアベースモデルを用いてスコアを実際に予測し、推定されたスコアが真値と比較してどのくらい差があるのか確認します。
まずは、ある確率分布のスコアを計算する関数を定義します。
確率分布の微分にはPytorchのtorch.autograd.grad関数を使用しています。
# 確率分布のスコアを計算する関数を定義
def calc_score(dist, x):
# xを自動微分対象に変更
x.requires_grad_()
# 対数尤度を計算
log_p = dist.log_prob(x)
# 対数尤度のxによる微分を計算
score = torch.autograd.grad(log_p.sum(), x)[0]
return score
次に、上記で定義した関数を用いて、二次元の格子点座標における確率分布のスコアを計算します。
# 二次元の格子点の座標を作成
num_point = 20
ls = np.linspace(-5, 5, num_point)
x, y = np.meshgrid(ls, ls)
point = torch.tensor(np.vstack([x.flatten(), y.flatten()]).T).to(torch.float32)
# 格子点における混合分布のスコアを計算
score = calc_score(dist, point)
# 可視化のため、(x座標,y座標,スコアの値)の形状に変換
score = score.reshape((num_point,num_point,input_dim))
さらに、学習済みモデルを用いて同様の格子点座標におけるスコアを推定します。
このとき、データに付加するノイズの標準偏差は1としています。
# sigma=1として格子点におけるスコアを予測
sigma = 1
noise = torch.normal(mean=0, std=sigma, size=point.shape)
noise_sample = point + noise
batch_sigma = torch.ones((noise_sample.shape[0], 1)) * sigma
# モデルを推論モードに変更
model.eval()
# 格子点座標におけるスコアを推定
with torch.no_grad():
pred_y = model(torch.concat([noise_sample, batch_sigma], axis=1))
# 可視化のため、(x座標,y座標,スコアの値)の形状に変換
pred_vec = pred_y.reshape((num_point,num_point,input_dim))
最後に、真値と予測したスコアをベクトル場で可視化して比較します。
下図中の、青い矢印が真値のスコア、赤い矢印が予測したスコアとなり、等高線は初めに定義した確率分布のものとなります。
この可視化結果から、ほとんどの地点でおおよそスコアの向き、つまり、確率分布の対数尤度が大きくなる向きが一致していることがわかります。
# 混合分布の等高線図を可視化するため、格子点の数を多くして各点の尤度を計算
# 二次元の格子点の座標を作成
num_point = 100
ls = np.linspace(-5, 5, num_point)
_x, _y = np.meshgrid(ls, ls)
_point = torch.tensor(np.vstack([_x.flatten(), _y.flatten()]).T).to(torch.float32)
# 格子点の座標における尤度を算出
p = torch.exp(dist.log_prob(_point))
# 混合分布の等高線図を可視化
plt.title('true and predicted score')
plt.contour(_x, _y, p.reshape(_x.shape))
plt.gca().set_aspect('equal', adjustable='box')
plt.xlim([-5, 5])
plt.ylim([-5, 5])
plt.colorbar()
# 混合分布の実際のスコアを可視化
plt.quiver(x,y,score[:,:,0],score[:,:,1],color='blue',angles='xy',label='true')
# 学習済みモデルによって予測されたスコアを可視化
plt.quiver(x,y,pred_vec[:,:,0],pred_vec[:,:,1],color='red',angles='xy',label='pred')
plt.legend(loc='lower right')
plt.show()
学習済みスコアベースモデルを用いたサンプリング
最後に以下のコードで、前回記事で解説したランジュバン・モンテカルロ法を使って、学習済みモデルが予測したスコアから確率分布のサンプリングが可能かどうか検証します。
まず、こちらのコードでは、学習済みスコアベースモデルによってスコアを予測し、そのスコアを用いてランジュバン・モンテカルロ法を行う関数を実装しています。
# モデルベースのランジュバン・モンテカルロ法の実装
def model_based_langevin_monte_carlo(model, num_samples, num_steps, step_size, sigma):
# 初期サンプルを乱数から生成
x = torch.randn(num_samples, model.input_dim)
# モデルを推論モードに変更
model.eval()
# 以下、学習済みモデルによって予測されたスコアを用いてランジュバン・モンテカルロ法を実行
for i in tqdm(range(num_steps)):
with torch.no_grad():
noise = torch.normal(mean=0, std=sigma, size=x.shape)
noise_x = x + noise
batch_sigma = torch.ones((noise_x.shape[0], 1)) * sigma
score = model(torch.concat([noise_x, batch_sigma], axis=1))
# 最終ステップのみノイズ無しでスコアの方向に更新
if i < num_steps - 1:
noise = torch.randn(num_samples, model.input_dim)
else:
noise = 0
x = x + step_size * score + np.sqrt(2 * step_size) * noise
return x
さらにこちらでは、上記で作成した関数を実行し、サンプリングを行っています。
可視化されたサンプリング結果をみると、分散が多少広がりぼやけた分布になってはいますが、大まかに最初に定義した確率分布と同様の特徴を持つ分布が得られていることが確認できました。
# ランジュバン・モンテカルロ法のパラメータ
num_samples = 100000
num_steps = 1000
step_size = 0.1
sigma = 1
# サンプリングの実行
samples = model_based_langevin_monte_carlo(model, num_samples, num_steps, step_size, sigma)
# サンプリング結果の可視化
plt.title('model-based langevin monte carlo sampling')
plt.hist2d(
samples[:,0],
samples[:,1],
range=((-5, 5), (-5, 5)),
cmap='viridis',
bins=50,
)
plt.gca().set_aspect('equal', adjustable='box')
plt.xlim([-5, 5])
plt.ylim([-5, 5])
plt.colorbar()
plt.show()
まとめ
本記事では、Pytorchを使用してデノイジング・スコアマッチングを実装する方法を紹介し、2種類の二次元正規分布の混合分布でのスコアを学習するニューラルネットワークモデルを構築しました。
さらに、学習したモデルによって予測されたスコアに基づき、ランジュバン・モンテカルロ法を使って、上記の確率分布からのサンプリングを模擬することができました。
次回は、拡散モデル本の2.3で解説されているデノイジング拡散確率モデルを実装する予定です。
以上、ご一読いただき、ありがとうございました。