LoginSignup
0
1

拡散モデルのDDPMの理論と実装

Last updated at Posted at 2024-04-15

拡散モデルの、Denoising Diffusion Probabilistic Model(DDPM)がいろいろと凄かったので、なんとなくまとめておこうと思います。

DDPMとは?

元の画像を$x_0$として、徐々にガウシアンノイズを乗せる事を考えたモデルです。
$$x_t=\sqrt{1-\beta_t}x_{t-1}+\sqrt{\beta_t}\epsilon_t\quad(t=1,2,...,T)$$

拡散過程

$\epsilon_t 〜 N(0, I)\quad(I: 単位行列)$ とし、
$x_t = \sqrt{1-\beta_t}x_{t-1} + \sqrt{\beta_t}\epsilon_tt\quad(\beta: ノイズの大きさのハイパーパラメータ)$
とするマルコフ過程を考えると、
$q(x_t | x_{t-1}) = N(x_t; \sqrt{1 - \beta_t}x_{t-1}, \beta_tI)$
$∴x_t = \sqrt{\bar{\alpha_t}}x_0 + \sqrt{1 - \bar{\alpha_t}}\epsilon\quad(\alpha_t = 1 - \beta_t, \bar{\alpha} = \prod_{t=1}(\alpha_t))\quad...1$
また、$q(x_t | x_0) = N(x_t; \sqrt{\bar{\alpha}}x_0, \sqrt{1 - \bar{\alpha}}I)$

条件付き逆拡散過程

ベイズの定理とマルコフ性より、
$q(x_{t-1} | x_t, x_0) = \frac{q(x_t | x_{t-1})q(x_{t-1} | x_0)}{q(x_t | x_0)}$
それぞれの条件付き確率を拡散過程の正規分布より求めると、
$q(x_{t-1} | x_t, x_0) ∝ exp[-\frac{1}{2}\frac{(x_{t-1} - \bar{\mu}_t(x_t, x_0))^2}{\bar{\beta}_t}]$

$∴q(x_t | x_t, x_0) = N(x_{t-1}; \bar{\mu}_t(x_t, x_0), \bar{\beta}_tI)$

また、
$\quad\bar{\mu}_t = \frac{1}{\sqrt{\alpha_t}}(x_t - \frac{\beta_t}{\sqrt{1 - \bar{\alpha}_t}}\epsilon) …①$

$\quad\bar{\beta_t} = \frac{1 - \bar{\alpha}_{t-1}}{1 - \bar{\alpha}_t}\beta_t$

ここで、$\sigma_t^2 = \bar{\beta_t}$とし、
$x_{t-1} = \bar{\mu}_t + \bar{\beta}_tz_t\quad(z_t 〜 N(0, I))$
$\quad= \frac{1}{\sqrt{\alpha_t}}(x_t - \frac{\beta_t}{\sqrt{1 - \bar{\alpha}_t}}\epsilon_t) + \sigma_t^2z_t$
$\quad= \frac{1}{\sqrt{\alpha_t}}(x_t - \frac{1 - α_t}{\sqrt{1 - \bar{\alpha}_t}}\epsilon_t) + \bar{\beta}z_t\quad...2$

損失関数

ニューラルネットワークのパラメータをθとすると、
$x_{t-1} = \mu_\theta(x_t, t) + \sigma_t^2z_t\quad(z_t 〜 N(0, I))$
ここで、画像生成のため、平均の負の対数尤度を考えると、
マルコフ性と正規分布のカルバックライブラーダイバージェンスより、
$L_t = E_q[\frac{1}{2\sigma_t^2} || \bar{\mu}_t(x_t, x_0) - \mu_θ(x_t, t) ||^2]$

①より、$\mu_\theta = \frac{1}{\sqrt{\alpha_t}}(x_t - \frac{\beta_t}{\sqrt{1 - \bar{\alpha_t}}}\epsilon_\theta)$と定義し、係数を簡略化すると、
$L_t^{simple} = E_q[||\epsilon - \epsilon_\theta(\sqrt{\bar{\alpha}_t}x_0 + \sqrt{1 - \bar{\alpha}_t}\epsilon_t, t) ||^2]$

参考

実装

拡散過程

# 線形にノイズの大きさを増やす方式: β_t, (t = num_times: 上記理論のT)
betas = torch.linspace(beta_1, beta_t, num_times).double()

# 上記理論の1のための準備(定義より)
alphas = 1. - betas
alphas_bar = torch.cumprod(alphas, dim=0)
sqrt_alphas_bar = torch.sqrt(alphas_bar)
sqrt_one_minus_alphas_bar = torch.sqrt(1. - alphas_bar)

等と定義し、

t = torch.randint(num_times, size=(x_0.size(0),), device=x_0.device) # 上記理論の∀t∈T
noise = torch.randn_like(x_0)

# 上記理論の1より
x_t = (extract(sqrt_alphas_bar, t, x_0.shape) * x_0
       + extract(sqrt_one_minus_alphas_bar, t, x_0.shape) * noise)

# 上記理論の損失関数の定義より
loss = F.mse_loss(model(x_t, t), noise, reduction='none')

等と損失を得る。
ここで、

def extract(v, t, shape):
    out = torch.gather(v, index=t, dim=0).float().to(t.device)
    out = out.view([t.size(0)] + [1] * (len(shape) - 1)) # (Batch, 1, 1, 1, ...)
    return out

条件付き逆拡散過程

# 拡散過程同様のβ_t
betas = torch.linspace(beta_1, beta_t, num_times).double()

# 上記理論の2のための準備(定義より)
alphas = 1. - betas
alphas_bar = torch.cumprod(alphas, dim=0)
alphas_bar_prev = F.pad(alphas_bar, [1, 0], value=1)[:num_times]
coeff1 = torch.sqrt(1. / alphas)
coeff2 = coeff1 * (1. - alphas) / torch.sqrt(1. - alphas_bar)
var = (1. - alphas_bar_prev) / (1. - alphas_bar) * betas

等と定義し、

def p_mean_variance(x_t, t):
    # ニューラルネットワークによるガウシアンノイズの推定
    eps = model(x_t, t)
    # 上記理論の2より平均
    xt_prev_mean = (extract(coeff1, t, x_t.shape) * x_t
                    - extract(coeff2, t, x_t.shape) * eps)
    # 上記理論の2より分散
    _var = extract(var, t, x_t.shape)
    return xt_prev_mean, _var

# x: ガウシアンノイズ
def generate(x):
    xt_shape = x.shape
    # 条件付き逆拡散過程
    for times in reversed(range(num_times)):
        t = torch.ones([xt_shape[0]], dtype=torch.long, device=x.device) * times # 上記理論のt∈T
        x = x.detach()
        t = t.detach()
        mean, var= p_mean_variance(x, t) # 条件付き逆拡散過程の平均と分散から
        if times > 0:
            noise = torch.randn_like(x)
        else:
            noise = 0
        x = mean + torch.sqrt(var) * noise # x_{t-1}を計算する(上記理論の2)
        assert torch.isnan(x).int().sum() == 0, "NaN in Tensor."
    return torch.clip(x, 0, 1) # 画像生成のためクリップ

等と生成画像$x_0$を得る。

ニューラルネットワーク

位置情報埋め込みを用いたUNet等を用いる。
あとは、拡散過程の誤差逆伝播学習を行った後、条件付き逆拡散過程で画像生成を行う。

全体像サンプル

DDPM (GitHub)

0
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
1