当記事では近年画像生成に採用されることの多いDDPM(Denoising Diffusion Probabilistic Models)の大まかな仕組みと実装について取りまとめました。作成にあたっては「ゼロから作るDeepLearning⑤」を参考にしました。
大まかな仕組み
概要
DDPM(Denoising Diffusion Probabilistic Models)の学習では「ランダムに選んだ画像にノイズを加えて得た画像から加えたノイズを予測するタスク」に基づいて学習を行います。このような学習を行うことで、「画像と同様のサイズでサンプリングされた潜在変数」についてノイズを予測することで画像の生成を行うことが可能です。
DDPMの学習
\begin{align}
\mathbf{x}_{t} &= \sqrt{\bar{\alpha}_{t}}\mathbf{x}_{0} + \sqrt{1-\bar{\alpha}_{t}} \varepsilon \quad (1) \\
L(\mathbf{x}_{0}; \theta) &= || \varepsilon_{\theta}(\mathbf{x}_{t},t) - \varepsilon ||^{2} \quad (2) \\
\frac{\partial}{\partial \theta}L(\mathbf{x}_{0}; \theta) &= (\varepsilon_{\theta}(\mathbf{x}_{t},t) - \varepsilon) \frac{\partial}{\partial \theta}\varepsilon_{\theta}(\mathbf{x}_{t},t) \quad (3)
\end{align}
上記の式などを用いることで下記の手順でDDPMの学習を行うことができます。
以下の演算を繰り返す: |
---|
1) $\mathbf{x}_{0}$を学習データからランダムに取得する |
2) $t \sim U\{1,T\}$に基づいて$t$をサンプリングする(1~Tの間の整数を当確率で選択する) |
3) $\varepsilon \sim \mathcal{N}(\mathbf{0},\mathbf{I})$に基づいて加えるノイズをサンプリングする(無相関の正規分布に基づいて生成されるので、各変数毎に標準正規分布に基づいてサンプリングを行う) |
4) $(1)$式を計算する |
5) 損失関数($(2)$式)を計算する |
6) $(3)$式を計算し、勾配法に基づいてパラメータのUpdateを行う |
学習結果に基づく生成
\begin{align}
\sigma_{q}(t) &= \sqrt{\frac{(1-\alpha_{t})(1-\bar{\alpha}_{t-1})}{1-\bar{\alpha}_{t}}} \quad (4) \\
\mathbf{x}_{t-1} &= \frac{1}{\alpha_{t}} \left( \mathbf{x}_{t} - \sqrt{\frac{1-\alpha_{t}}{1-\bar{\alpha}_{t}}} \varepsilon_{\theta}(\mathbf{x}_{t},t) \right) + \sigma_{q}(t) \varepsilon \quad (5)
\end{align}
DDPMにおける画像の生成は上記の式を用いて下記の手順を実行することで行うことができます。
以下の演算によって画像の生成を行う: |
---|
1) $\mathcal{N}(\mathbf{0},\mathbf{I})$から$\mathbf{x}_{T}$のサンプリングを行う |
2) for t in [T, ..., 1]
|
2-1) $t>1$のとき$\mathcal{N}(\mathbf{0},\mathbf{I})$から$\varepsilon$のサンプリングを行う、$t=0$のときは$\varepsilon$はゼロ行列 |
2-2) $(4)$式に基づいて$\sigma_{q}(t)$を計算する |
2-3) $(5)$式に基づいて$\mathbf{x}_{t-1}$を計算する |
3) $\mathbf{x}_{0}$を出力する |
DDPMの実装
UNet
DDPMでは加えたノイズの$\varepsilon \sim \mathcal{N}(\mathbf{0},\mathbf{I})$の予測にあたって、UNetを用います。UNetの入力が「ノイズを加えた結果の画像」、出力が「画像に加えたノイズの予測」にそれぞれ対応します。ゼロから作るDeepLearning⑤の9章では下記のようにUNetが構築されます。
class UNet(nn.Module):
def __init__(self, in_ch=1, time_embed_dim=100):
super().__init__()
self.time_embed_dim = time_embed_dim
self.down1 = ConvBlock(in_ch, 64, time_embed_dim)
self.down2 = ConvBlock(64, 128, time_embed_dim)
self.bot1 = ConvBlock(128, 256, time_embed_dim)
self.up2 = ConvBlock(128 + 256, 128, time_embed_dim)
self.up1 = ConvBlock(128 + 64, 64, time_embed_dim)
self.out = nn.Conv2d(64, in_ch, 1)
self.maxpool = nn.MaxPool2d(2)
self.upsample = nn.Upsample(scale_factor=2, mode='bilinear')
def forward(self, x, timesteps):
v = pos_encoding(timesteps, self.time_embed_dim, x.device)
x1 = self.down1(x, v)
x = self.maxpool(x1)
x2 = self.down2(x, v)
x = self.maxpool(x2)
x = self.bot1(x, v)
x = self.upsample(x)
x = torch.cat([x, x2], dim=1)
x = self.up2(x, v)
x = self.upsample(x)
x = torch.cat([x, x1], dim=1)
x = self.up1(x, v)
x = self.out(x)
return x
上記ではself.maxpool
とself.upsample
の演算が2回ずつ実行されており、$1/2$のダウンサンプリングを2度行った後に$2$倍のアップサンプリングを2度行っていることが確認できます。上記で用いられるConvBlock
は下記のように実装されています。timesteps
に基づいた位置エンコーディングのv
の作成にあたって用いられるpos_encoding
は次項で詳しく確認します。
class ConvBlock(nn.Module):
def __init__(self, in_ch, out_ch, time_embed_dim):
super().__init__()
self.convs = nn.Sequential(
nn.Conv2d(in_ch, out_ch, 3, padding=1),
nn.BatchNorm2d(out_ch),
nn.ReLU(),
nn.Conv2d(out_ch, out_ch, 3, padding=1),
nn.BatchNorm2d(out_ch),
nn.ReLU()
)
self.mlp = nn.Sequential(
nn.Linear(time_embed_dim, in_ch),
nn.ReLU(),
nn.Linear(in_ch, in_ch)
)
def forward(self, x, v):
N, C, _, _ = x.shape
v = self.mlp(v)
v = v.view(N, C, 1, 1)
y = self.convs(x + v)
return y
上記では画像のx
に位置エンコーディングのv
を加え、nn.Conv2d
に基づく畳み込み処理が行われていることが確認できます。v
は入力の画像のチャネル毎に値を持っており、空間方向には同じ値を加える点に注意しておくとよいです(v.view(N, C, 1, 1)
より画像の縦横方向の値は1つしか持たないことが確認できる)。
正弦波位置エンコーディング
\mathbf{v}_{i} = \left\{
\begin{array}{ll}
\sin{\left( \frac{t}{10000^{\frac{i}{D}}} \right)} \qquad iが奇数のとき \\
\sin{\left( \frac{t}{10000^{\frac{i}{D}}} \right)} \qquad iが偶数のとき
\end{array}
\right.
ゼロから作るDeepLearning⑤の実装では位置エンコーディングに上記のような式で表される正弦波位置エンコーディングが用いられます。
def _pos_encoding(time_idx, output_dim, device='cpu'):
t, D = time_idx, output_dim
v = torch.zeros(D, device=device)
i = torch.arange(0, D, device=device)
div_term = torch.exp(i / D * math.log(10000))
v[0::2] = torch.sin(t / div_term[0::2])
v[1::2] = torch.cos(t / div_term[1::2])
return v
def pos_encoding(timesteps, output_dim, device='cpu'):
batch_size = len(timesteps)
device = timesteps.device
v = torch.zeros(batch_size, output_dim, device=device)
for i in range(batch_size):
v[i] = _pos_encoding(timesteps[i], output_dim, device)
return v
上記の実装に基づいて正弦波位置エンコーディングを行うことができます。入力された時刻の配列のtimesteps
毎にoutput_dim
次元の位置エンコーディングが出力されます。
v = pos_encoding(torch.tensor([1, 2, 3]), 16)
print(v.shape) # (3,16)が出力される
たとえば上記のようにpos_encodeing
を実行することで正弦波位置エンコーディングを得ることができます。
拡散過程
ゼロから作るDeepLearning⑤では拡散過程を取り扱うにあたって下記のように実装されるDiffuser
クラスが用いられます。
class Diffuser:
def __init__(self, num_timesteps=1000, beta_start=0.0001, beta_end=0.02, device='cpu'):
self.num_timesteps = num_timesteps
self.device = device
self.betas = torch.linspace(beta_start, beta_end, num_timesteps, device=device)
self.alphas = 1 - self.betas
self.alpha_bars = torch.cumprod(self.alphas, dim=0)
def add_noise(self, x_0, t):
T = self.num_timesteps
assert (t >= 1).all() and (t <= T).all()
t_idx = t - 1 # alpha_bars[0] is for t=1
alpha_bar = self.alpha_bars[t_idx] # (N,)
N = alpha_bar.size(0)
alpha_bar = alpha_bar.view(N, 1, 1, 1) # (N, 1, 1, 1)
noise = torch.randn_like(x_0, device=self.device)
x_t = torch.sqrt(alpha_bar) * x_0 + torch.sqrt(1 - alpha_bar) * noise
return x_t, noise
def denoise(self, model, x, t):
T = self.num_timesteps
assert (t >= 1).all() and (t <= T).all()
t_idx = t - 1 # alphas[0] is for t=1
alpha = self.alphas[t_idx]
alpha_bar = self.alpha_bars[t_idx]
alpha_bar_prev = self.alpha_bars[t_idx-1]
N = alpha.size(0)
alpha = alpha.view(N, 1, 1, 1)
alpha_bar = alpha_bar.view(N, 1, 1, 1)
alpha_bar_prev = alpha_bar_prev.view(N, 1, 1, 1)
model.eval()
with torch.no_grad():
eps = model(x, t)
model.train()
noise = torch.randn_like(x, device=self.device)
noise[t == 1] = 0 # no noise at t=1
mu = (x - ((1-alpha) / torch.sqrt(1-alpha_bar)) * eps) / torch.sqrt(alpha)
std = torch.sqrt((1-alpha) * (1-alpha_bar_prev) / (1-alpha_bar))
return mu + noise * std
def reverse_to_img(self, x):
x = x * 255
x = x.clamp(0, 255)
x = x.to(torch.uint8)
x = x.cpu()
to_pil = transforms.ToPILImage()
return to_pil(x)
def sample(self, model, x_shape=(20, 1, 28, 28)):
batch_size = x_shape[0]
x = torch.randn(x_shape, device=self.device)
for i in tqdm(range(self.num_timesteps, 0, -1)):
t = torch.tensor([i] * batch_size, device=self.device, dtype=torch.long)
x = self.denoise(model, x, t)
images = [self.reverse_to_img(x[i]) for i in range(batch_size)]
return images
add_noise
は拡散過程における処理、denoise
は逆拡散過程における処理、sample
メソッドは逆拡散過程に基づく画像の生成にそれぞれ対応します。add_noise
を前節の「DDPMの学習の1)~4)」、denoise
を前節の「学習結果に基づく生成1)~3)」に対応させて抑えておくと良いです。
preprocess = transforms.ToTensor()
dataset = torchvision.datasets.MNIST(root='./data', download=True, transform=preprocess)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
diffuser = Diffuser(num_timesteps, device=device)
model = UNet()
model.to(device)
optimizer = Adam(model.parameters(), lr=lr)
losses = []
for epoch in range(epochs):
loss_sum = 0.0
cnt = 0
# generate samples every epoch ===================
# images = diffuser.sample(model)
# show_images(images)
# ================================================
for images, labels in tqdm(dataloader):
optimizer.zero_grad()
x = images.to(device)
t = torch.randint(1, num_timesteps+1, (len(x),), device=device)
x_noisy, noise = diffuser.add_noise(x, t)
noise_pred = model(x_noisy, t)
loss = F.mse_loss(noise, noise_pred)
loss.backward()
optimizer.step()
loss_sum += loss.item()
cnt += 1
loss_avg = loss_sum / cnt
losses.append(loss_avg)
print(f'Epoch {epoch} | Loss: {loss_avg}')
また、学習の実行にあたっては上記のようなコードを実行すれば良いです。torch.nn.functional.mse_loss
で「実際に加えたノイズ(画像と同じサイズ)」と「予測したノイズ(画像と同じサイズ)」の平均二乗和誤差(前節の$(2)$式に対応)を計算し、optimizer = Adam(model.parameters(), lr=lr)
のように生成されるoptimizer
オブジェクトに基づいて最適化が行われます。