2023年2月に発売されたPFN岡野原さんの著書「拡散モデルデータ生成技術の数理」を実装しながら解説します。
今回は、拡散モデル本の1.5.1で解説されているランジュバン・モンテカルロ法についてPytorchを使用して実装します。
今回実装したコードはこちらからも参照できます。
ランジュバン・モンテカルロ法とは?
ランジュバン・モンテカルロ法とは、MCMC法の一種であり、未知の確率分布 $p(\mathbf{x})$ の対数尤度の勾配であるスコアを用いて、$p(\mathbf{x})$ に近い確率分布から $\mathbf{x}$ をサンプリングすることができる手法です。
ここで、確率分布 $p(\mathbf{x})$ のスコア $s(\mathbf{x})$ とは以下の式で定義されます。
$$
s(\mathbf{x}) = \nabla_\mathbf{x} \log p(\mathbf{x}): R^d \to R^d
$$
$\nabla_\mathbf{x}$ は $\mathbf{x}$ の各成分($x_0, x_1, ... ,x_d$)による微分を示します。
上式のスコアは、微分の公式により、
$$
s(\mathbf{x}) = \nabla_\mathbf{x} \log p(\mathbf{x}) = \dfrac{\nabla_\mathbf{x} p(\mathbf{x})}{p(\mathbf{x})}
$$
となり、ある点 $\mathbf{x}$ の確率分布 $p(\mathbf{x})$ の勾配を確率分布 $p(\mathbf{x})$ そのもので割った値になります。
ランジュバン・モンテカルロ法では、このスコアを用いて $\mathbf{x}$ をサンプリングします。
以下、ランジュバン・モンテカルロ法のアルゴリズムを示します。
$$
\mathbf{x_{k+1}} = \mathbf{x_k} + \alpha \nabla_\mathbf{x} \log p(\mathbf{x_k}) + \sqrt{2\alpha} \mathbf{u_k} \
\mathbf{x_0} \sim N(\mathbf{x_0};\mathbf{0},\mathbf{I}) \
\mathbf{u_k} \sim N(\mathbf{u_k};\mathbf{0},\mathbf{I}) \
\alpha: stepsize
$$
上式のように、初め $\mathbf{x_0}$ は平均 $\mathbf{0}$ 、共分散行列 $\mathbf{I}$ の多次元正規分布から得られた乱数で初期化され、その後スコアの示す方向、つまり、対数尤度が急激に大きくなる方向に向かってノイズ $\mathbf{u_k}$ を含みながら更新されます。
この過程を繰り返し、$k \to \infty$ かつ $\alpha \to 0$ のとき、$\mathbf{x_k}$ は $p(\mathbf{x})$ からのサンプリングに収束します。
ランジュバン・モンテカルロ法は、
- 対数尤度の勾配であるスコア $s(\mathbf{x})$ を指針に更新されるため、尤度の高い領域を効率的に探索できる
- ノイズ $\mathbf{u_k}$ を含んで更新されるため、局所的に尤度の高い極大値から抜け出しやすい
という二つの大きな利点を持つサンプリング手法となっています。
実装: ランジュバン・モンテカルロ法でサンプリング
単一の二次元正規分布からサンプリング
実際にPythonでランジュバン・モンテカルロ法を実装したコードがこちらです。
まずは、ライブラリのインポートです。今回は乱数を生成するので、プログラム実行時に同じ結果となるように乱数シードを固定しています。
# ライブラリのインポート
import os
import sys
import numpy as np
import matplotlib.pyplot as plt
import torch
from torch.distributions.multivariate_normal import MultivariateNormal
from torch.distributions import MixtureSameFamily, Categorical
# 乱数シードを固定
torch.manual_seed(1234)
次にサンプリングしたい確率分布を定義します。今回は、単純な平均 $\mathbf{0}$ 、共分散行列 $\mathbf{I}$ の二次元正規分布からランジュバン・モンテカルロ法を使ってサンプリングが可能かを確認します。
こちらのコードでは、定義した二次元正規分布を可視化しています。
# 次元を定義
dim = 2
# 平均0、共分散行列が二次元の単位行列の二次元正規分布を作成
dist = MultivariateNormal(torch.zeros((dim)), torch.eye(dim))
# 二次元の格子点の座標を作成
ls = np.linspace(-2, 2, 1000)
x, y = np.meshgrid(ls, ls)
point = torch.tensor(np.vstack([x.flatten(), y.flatten()]).T)
# 格子点の座標における尤度を算出
p = torch.exp(dist.log_prob(point))
# 二次元正規分布を可視化
plt.title('2-dim normal distribution')
plt.pcolormesh(x, y, p.reshape(x.shape), cmap='viridis')
plt.gca().set_aspect('equal', adjustable='box')
plt.xlim([-2, 2])
plt.ylim([-2, 2])
plt.colorbar()
plt.show()
次にランジュバン・モンテカルロ法を実装します。
今回は、Pytorchのtorch.autograd.grad関数を使って対数尤度の勾配を計算しています。
# ランジュバン・モンテカルロ法の実装
def langevin_monte_carlo(dist, num_samples, num_steps, step_size):
# 初期サンプルを乱数から生成
x = torch.randn(num_samples, dim)
for i in range(num_steps):
x.requires_grad_()
log_p = dist.log_prob(x)
score = torch.autograd.grad(log_p.sum(), x)[0]
with torch.no_grad():
noise = torch.randn(num_samples, dim)
x = x + step_size * score + np.sqrt(2 * step_size) * noise
return x
実装したランジュバン・モンテカルロ法を使ってサンプリングを行います。
サンプリング結果のヒストグラムを可視化すると、先ほど可視化した二次元正規分布に近い分布が可視化できており、意図したサンプリングができていることがわかります。
# ランジュバン・モンテカルロ法のパラメータ
num_samples = 100000
num_steps = 1000
step_size = 0.001
# サンプリングの実行
samples = langevin_monte_carlo(dist, num_samples, num_steps, step_size)
# サンプリング結果の可視化
plt.title('langevin monte carlo sampling')
plt.hist2d(
samples[:,0],
samples[:,1],
range=((-2, 2), (-2, 2)),
cmap='viridis',
bins=50,
)
plt.gca().set_aspect('equal', adjustable='box')
plt.xlim([-2, 2])
plt.ylim([-2, 2])
plt.colorbar()
plt.show()
複数の二次元正規分布からなる混合分布からサンプリング
続いて、もう少し複雑な分布に対してランジュバン・モンテカルロ法が機能するかについて見ていきます。
こちらのコードでは、複数の平均値と共分散行列を持つ二次元正規分布を定義し、それらの混合分布を作成しています。
# 平均ベクトル
means = torch.tensor([[0.0, 0.0], [2.0, 2.0], [-2.0, -2.0], [2.0, -2.0]])
# 共分散行列
covs = torch.Tensor([
[[ 1.0, 0.0],
[ 0.0, 1.0]],
[[ 0.6, 0.1],
[ 0.1, 0.9]],
[[ 0.8, -0.2],
[-0.2, 0.8]],
[[ 0.3, 0.2],
[0.2, 0.8]],
])
# 混合係数
mixture_weights = torch.tensor([0.2, 0.2, 0.4, 0.2])
# 混合正規分布を作成
mixture_dist = MixtureSameFamily(
Categorical(mixture_weights),
MultivariateNormal(means, covs)
)
混合分布の可視化結果はこちらです。先ほどの単一の二次元正規分布よりも複雑な分布になっていることがわかります。
# 二次元の格子点の座標を作成
ls = np.linspace(-5, 5, 1000)
x, y = np.meshgrid(ls, ls)
point = torch.tensor(np.vstack([x.flatten(), y.flatten()]).T)
# 格子点の座標における尤度を算出
p = torch.exp(mixture_dist.log_prob(point))
# 二次元正規分布を可視化
plt.title('2-dim mixture normal distribution')
plt.pcolormesh(x, y, p.reshape(x.shape), cmap='viridis')
plt.gca().set_aspect('equal', adjustable='box')
plt.xlim([-5, 5])
plt.ylim([-5, 5])
plt.colorbar()
plt.show()
それでは、実際にランジュバン・モンテカルロ法でサンプリングを行なってみます。
こちらのコードで同様にサンプリングを行い、サンプリング結果をヒストグラムにして可視化しています。
結果を見ると、先ほど可視化した混合分布に近い分布が得られていることがわかります。
しかしながら、今回のサンプリングではステップサイズを0.1と設定したのですが、これを0.01、0.001と変化させていくと、上手く所望の分布が得られませんでした。
# ランジュバン・モンテカルロ法のパラメータ
num_samples = 100000
num_steps = 1000
step_size = 0.1
# サンプリングの実行
samples = langevin_monte_carlo(mixture_dist, num_samples, num_steps, step_size)
# サンプリング結果の可視化
plt.title('langevin monte carlo sampling')
plt.hist2d(
samples[:,0],
samples[:,1],
range=((-5, 5), (-5, 5)),
cmap='viridis',
bins=50,
)
plt.gca().set_aspect('equal', adjustable='box')
plt.xlim([-5, 5])
plt.ylim([-5, 5])
plt.colorbar()
plt.show()
複数のステップサイズでサンプリング結果を比較
実際に、ステップサイズを0.1から徐々に小さくした時のサンプリングを行なってみます。
こちらでは、8パターンのステップサイズでランジュバン・モンテカルロ法を試しています。
サンプリング結果のヒストグラムを見ると、ステップサイズが0.1~0.005辺りまでは、意図したサンプリング結果になっているのですが、0.001~0.00005ではステップサイズが小さくなるにつれて中心部分の頻度が多くなり、正解の混合分布とは程遠い分布になっていることがわかります。
# ランジュバン・モンテカルロ法のパラメータ
num_samples = 100000
num_steps = 1000
step_size_list = [0.1, 0.05, 0.01, 0.005, 0.001, 0.0005, 0.0001, 0.00005]
# 結果の保存先
samples_list = []
# 複数のステップサイズでサンプリング
for step_size in step_size_list:
# サンプリングの実行
samples = langevin_monte_carlo(mixture_dist, num_samples, num_steps, step_size)
# サンプリング結果の追加
samples_list.append(samples)
# サンプリング結果の可視化
fig, axes = plt.subplots(2, len(step_size_list)//2, figsize=(8,4))
for i,step_size in enumerate(step_size_list):
axes[i//4, i-i//4*4].set_title(step_size)
im = axes[i//4, i-i//4*4].hist2d(
samples_list[i][:,0],
samples_list[i][:,1],
range=((-5, 5), (-5, 5)),
cmap='viridis',
bins=50,
)
axes[i//4, i-i//4*4].set_aspect('equal', adjustable='box')
axes[i//4, i-i//4*4].axis('off')
cax = fig.add_axes((0.92, 0.13, 0.02, 0.74))
cbar = plt.colorbar(im[3], cax=cax, orientation='vertical', ticks=mticker.NullLocator())
cbar.ax.set_yticklabels([])
plt.show()
拡散モデル本の第2章にも記述がありましたが、今回の混合分布のように多峰性をもつ確率分布は、ノイズを含みながら更新を行うランジュバン・モンテカルロ法であっても中途半端に尤度の高い極大値から抜け出しづらく、抜け出せたとしても膨大なステップ数が必要になります。
今回の混合分布は、構成要素として平均 $0$ の二次元正規分布をもっているため、ステップ数一定のまま、ステップサイズを徐々に小さくしていくと、$\mathbf{x}$ が極大値である $0$ 付近から抜け出せず、最終的に $0$ 付近に頻度が高い分布になったと考えられます。
まとめ
本記事では、Pytorchを使用してランジュバン・モンテカルロ法を実装する方法を紹介し、単純な二次元正規分布とそれらを複数組み合わせた混合分布でサンプリングを行いました。
次回は、拡散モデル本の1.5.5で解説されているデノイジングスコアマッチングを実装する予定です。
以上、ご一読いただき、ありがとうございました。