2
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

DDPM(Denoising Diffusion Probabilistic Models)の大まかな仕組みと実装

Last updated at Posted at 2024-11-27

当記事では近年画像生成に採用されることの多いDDPM(Denoising Diffusion Probabilistic Models)の大まかな仕組みと実装について取りまとめました。作成にあたっては「ゼロから作るDeepLearning⑤」を参考にしました。

大まかな仕組み

概要

DDPM(Denoising Diffusion Probabilistic Models)の学習では「ランダムに選んだ画像にノイズを加えて得た画像から加えたノイズを予測するタスク」に基づいて学習を行います。このような学習を行うことで、「画像と同様のサイズでサンプリングされた潜在変数」についてノイズを予測することで画像の生成を行うことが可能です。

DDPMの学習

\begin{align}
\mathbf{x}_{t} &= \sqrt{\bar{\alpha}_{t}}\mathbf{x}_{0} + \sqrt{1-\bar{\alpha}_{t}} \varepsilon \quad (1) \\
L(\mathbf{x}_{0}; \theta) &= || \varepsilon_{\theta}(\mathbf{x}_{t},t) - \varepsilon ||^{2} \quad (2) \\
\frac{\partial}{\partial \theta}L(\mathbf{x}_{0}; \theta) &= (\varepsilon_{\theta}(\mathbf{x}_{t},t) - \varepsilon) \frac{\partial}{\partial \theta}\varepsilon_{\theta}(\mathbf{x}_{t},t) \quad (3)
\end{align}

上記の式などを用いることで下記の手順でDDPMの学習を行うことができます。

以下の演算を繰り返す:
1) $\mathbf{x}_{0}$を学習データからランダムに取得する
2) $t \sim U\{1,T\}$に基づいて$t$をサンプリングする(1~Tの間の整数を当確率で選択する)
3) $\varepsilon \sim \mathcal{N}(\mathbf{0},\mathbf{I})$に基づいて加えるノイズをサンプリングする(無相関の正規分布に基づいて生成されるので、各変数毎に標準正規分布に基づいてサンプリングを行う)
4) $(1)$式を計算する
5) 損失関数($(2)$式)を計算する
6) $(3)$式を計算し、勾配法に基づいてパラメータのUpdateを行う

学習結果に基づく生成

\begin{align}
\sigma_{q}(t) &= \sqrt{\frac{(1-\alpha_{t})(1-\bar{\alpha}_{t-1})}{1-\bar{\alpha}_{t}}} \quad (4) \\
\mathbf{x}_{t-1} &= \frac{1}{\alpha_{t}} \left( \mathbf{x}_{t} - \sqrt{\frac{1-\alpha_{t}}{1-\bar{\alpha}_{t}}} \varepsilon_{\theta}(\mathbf{x}_{t},t) \right) + \sigma_{q}(t) \varepsilon \quad (5)
\end{align}

DDPMにおける画像の生成は上記の式を用いて下記の手順を実行することで行うことができます。

以下の演算によって画像の生成を行う:
1) $\mathcal{N}(\mathbf{0},\mathbf{I})$から$\mathbf{x}_{T}$のサンプリングを行う
2) for t in [T, ..., 1]
2-1) $t>1$のとき$\mathcal{N}(\mathbf{0},\mathbf{I})$から$\varepsilon$のサンプリングを行う、$t=0$のときは$\varepsilon$はゼロ行列
2-2) $(4)$式に基づいて$\sigma_{q}(t)$を計算する
2-3) $(5)$式に基づいて$\mathbf{x}_{t-1}$を計算する
3) $\mathbf{x}_{0}$を出力する

DDPMの実装

UNet

DDPMでは加えたノイズの$\varepsilon \sim \mathcal{N}(\mathbf{0},\mathbf{I})$の予測にあたって、UNetを用います。UNetの入力が「ノイズを加えた結果の画像」、出力が「画像に加えたノイズの予測」にそれぞれ対応します。ゼロから作るDeepLearning⑤の9章では下記のようにUNetが構築されます。

step09/diffusion_model.py
class UNet(nn.Module):
    def __init__(self, in_ch=1, time_embed_dim=100):
        super().__init__()
        self.time_embed_dim = time_embed_dim

        self.down1 = ConvBlock(in_ch, 64, time_embed_dim)
        self.down2 = ConvBlock(64, 128, time_embed_dim)
        self.bot1 = ConvBlock(128, 256, time_embed_dim)
        self.up2 = ConvBlock(128 + 256, 128, time_embed_dim)
        self.up1 = ConvBlock(128 + 64, 64, time_embed_dim)
        self.out = nn.Conv2d(64, in_ch, 1)

        self.maxpool = nn.MaxPool2d(2)
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear')

    def forward(self, x, timesteps):
        v = pos_encoding(timesteps, self.time_embed_dim, x.device)

        x1 = self.down1(x, v)
        x = self.maxpool(x1)
        x2 = self.down2(x, v)
        x = self.maxpool(x2)

        x = self.bot1(x, v)

        x = self.upsample(x)
        x = torch.cat([x, x2], dim=1)
        x = self.up2(x, v)
        x = self.upsample(x)
        x = torch.cat([x, x1], dim=1)
        x = self.up1(x, v)
        x = self.out(x)
        return x

上記ではself.maxpoolself.upsampleの演算が2回ずつ実行されており、$1/2$のダウンサンプリングを2度行った後に$2$倍のアップサンプリングを2度行っていることが確認できます。上記で用いられるConvBlockは下記のように実装されています。timestepsに基づいた位置エンコーディングのvの作成にあたって用いられるpos_encodingは次項で詳しく確認します。

step09/diffusion_model.py
class ConvBlock(nn.Module):
    def __init__(self, in_ch, out_ch, time_embed_dim):
        super().__init__()
        self.convs = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU()
        )
        self.mlp = nn.Sequential(
            nn.Linear(time_embed_dim, in_ch),
            nn.ReLU(),
            nn.Linear(in_ch, in_ch)
        )

    def forward(self, x, v):
        N, C, _, _ = x.shape
        v = self.mlp(v)
        v = v.view(N, C, 1, 1)
        y = self.convs(x + v)
        return y

上記では画像のxに位置エンコーディングのvを加え、nn.Conv2dに基づく畳み込み処理が行われていることが確認できます。vは入力の画像のチャネル毎に値を持っており、空間方向には同じ値を加える点に注意しておくとよいです(v.view(N, C, 1, 1)より画像の縦横方向の値は1つしか持たないことが確認できる)。

正弦波位置エンコーディング

\mathbf{v}_{i} = \left\{
\begin{array}{ll}
\sin{\left( \frac{t}{10000^{\frac{i}{D}}} \right)} \qquad iが奇数のとき \\
\sin{\left( \frac{t}{10000^{\frac{i}{D}}} \right)} \qquad iが偶数のとき
\end{array}
\right.

ゼロから作るDeepLearning⑤の実装では位置エンコーディングに上記のような式で表される正弦波位置エンコーディングが用いられます。

step09/diffusion_model.py
def _pos_encoding(time_idx, output_dim, device='cpu'):
    t, D = time_idx, output_dim
    v = torch.zeros(D, device=device)

    i = torch.arange(0, D, device=device)
    div_term = torch.exp(i / D * math.log(10000))

    v[0::2] = torch.sin(t / div_term[0::2])
    v[1::2] = torch.cos(t / div_term[1::2])
    return v

def pos_encoding(timesteps, output_dim, device='cpu'):
    batch_size = len(timesteps)
    device = timesteps.device
    v = torch.zeros(batch_size, output_dim, device=device)
    for i in range(batch_size):
        v[i] = _pos_encoding(timesteps[i], output_dim, device)
    return v

上記の実装に基づいて正弦波位置エンコーディングを行うことができます。入力された時刻の配列のtimesteps毎にoutput_dim次元の位置エンコーディングが出力されます。

step09/diffusion_model.py
v = pos_encoding(torch.tensor([1, 2, 3]), 16)
print(v.shape)  # (3,16)が出力される

たとえば上記のようにpos_encodeingを実行することで正弦波位置エンコーディングを得ることができます。

拡散過程

ゼロから作るDeepLearning⑤では拡散過程を取り扱うにあたって下記のように実装されるDiffuserクラスが用いられます。

step09/diffusion_model.py
class Diffuser:
    def __init__(self, num_timesteps=1000, beta_start=0.0001, beta_end=0.02, device='cpu'):
        self.num_timesteps = num_timesteps
        self.device = device
        self.betas = torch.linspace(beta_start, beta_end, num_timesteps, device=device)
        self.alphas = 1 - self.betas
        self.alpha_bars = torch.cumprod(self.alphas, dim=0)

    def add_noise(self, x_0, t):
        T = self.num_timesteps
        assert (t >= 1).all() and (t <= T).all()

        t_idx = t - 1  # alpha_bars[0] is for t=1
        alpha_bar = self.alpha_bars[t_idx]  # (N,)
        N = alpha_bar.size(0)
        alpha_bar = alpha_bar.view(N, 1, 1, 1)  # (N, 1, 1, 1)

        noise = torch.randn_like(x_0, device=self.device)
        x_t = torch.sqrt(alpha_bar) * x_0 + torch.sqrt(1 - alpha_bar) * noise
        return x_t, noise

    def denoise(self, model, x, t):
        T = self.num_timesteps
        assert (t >= 1).all() and (t <= T).all()

        t_idx = t - 1  # alphas[0] is for t=1
        alpha = self.alphas[t_idx]
        alpha_bar = self.alpha_bars[t_idx]
        alpha_bar_prev = self.alpha_bars[t_idx-1]

        N = alpha.size(0)
        alpha = alpha.view(N, 1, 1, 1)
        alpha_bar = alpha_bar.view(N, 1, 1, 1)
        alpha_bar_prev = alpha_bar_prev.view(N, 1, 1, 1)

        model.eval()
        with torch.no_grad():
            eps = model(x, t)
        model.train()

        noise = torch.randn_like(x, device=self.device)
        noise[t == 1] = 0  # no noise at t=1

        mu = (x - ((1-alpha) / torch.sqrt(1-alpha_bar)) * eps) / torch.sqrt(alpha)
        std = torch.sqrt((1-alpha) * (1-alpha_bar_prev) / (1-alpha_bar))
        return mu + noise * std

    def reverse_to_img(self, x):
        x = x * 255
        x = x.clamp(0, 255)
        x = x.to(torch.uint8)
        x = x.cpu()
        to_pil = transforms.ToPILImage()
        return to_pil(x)

    def sample(self, model, x_shape=(20, 1, 28, 28)):
        batch_size = x_shape[0]
        x = torch.randn(x_shape, device=self.device)

        for i in tqdm(range(self.num_timesteps, 0, -1)):
            t = torch.tensor([i] * batch_size, device=self.device, dtype=torch.long)
            x = self.denoise(model, x, t)

        images = [self.reverse_to_img(x[i]) for i in range(batch_size)]
        return images

add_noiseは拡散過程における処理、denoiseは逆拡散過程における処理、sampleメソッドは逆拡散過程に基づく画像の生成にそれぞれ対応します。add_noiseを前節の「DDPMの学習の1)~4)」、denoiseを前節の「学習結果に基づく生成1)~3)」に対応させて抑えておくと良いです。

step09/diffusion_model.py
preprocess = transforms.ToTensor()
dataset = torchvision.datasets.MNIST(root='./data', download=True, transform=preprocess)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

diffuser = Diffuser(num_timesteps, device=device)
model = UNet()
model.to(device)
optimizer = Adam(model.parameters(), lr=lr)

losses = []
for epoch in range(epochs):
    loss_sum = 0.0
    cnt = 0

    # generate samples every epoch ===================
    # images = diffuser.sample(model)
    # show_images(images)
    # ================================================

    for images, labels in tqdm(dataloader):
        optimizer.zero_grad()
        x = images.to(device)
        t = torch.randint(1, num_timesteps+1, (len(x),), device=device)

        x_noisy, noise = diffuser.add_noise(x, t)
        noise_pred = model(x_noisy, t)
        loss = F.mse_loss(noise, noise_pred)

        loss.backward()
        optimizer.step()

        loss_sum += loss.item()
        cnt += 1

    loss_avg = loss_sum / cnt
    losses.append(loss_avg)
    print(f'Epoch {epoch} | Loss: {loss_avg}')

また、学習の実行にあたっては上記のようなコードを実行すれば良いです。torch.nn.functional.mse_lossで「実際に加えたノイズ(画像と同じサイズ)」と「予測したノイズ(画像と同じサイズ)」の平均二乗和誤差(前節の$(2)$式に対応)を計算し、optimizer = Adam(model.parameters(), lr=lr)のように生成されるoptimizerオブジェクトに基づいて最適化が行われます。

2
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?