初めに
最近話題のdiffusion modelの理解のため、下記記事を参考に実装をしつつ理解を進めました。
下記は、どうやら必要最低限の実装でDDPMを実装しているようですが、しっかりと学習と生成ができました。
VAEやアテンション層は入っていないみたいなので、今後アテンションを理解したら追加実装して確認してみたいです。
https://github.com/cloneofsimo/minDiffusion
因みに、自分はこちらの記事からこの存在を知りました。
https://note.com/gcem156/n/nd8d00f0a3159
また、今回はほぼこの「minDiffusion」の紹介になります。
コード内にコメントを追記しているので、写経時の参考になればと思います。
※間違えていたら、ご指摘いただけると幸いです。
コード
class DDPM(nn.Module):
def __init__(
self,
eps_model: nn.Module,
betas: Tuple[float, float],
n_T: int,
criterion: nn.Module = nn.MSELoss(),
) -> None:
super(DDPM, self).__init__()
self.eps_model = eps_model
# register_buffer allows us to freely access these tensors by name. It helps device placement.
for k, v in ddpm_schedules(betas[0], betas[1], n_T).items():
self.register_buffer(k, v)
self.n_T = n_T
self.criterion = criterion
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Makes forward diffusion x_t, and tries to guess epsilon value from x_t using eps_model.
This implements Algorithm 1 in the paper.
"""
# ノイズを付与する総ステップの中からランダムに選択
# t ~ Uniform(0, n_T)
_ts = torch.randint(1, self.n_T + 1, (x.shape[0],)).to(x.device)
# ノイズを標準正規分布から生成
# eps ~ N(0, 1)
eps = torch.randn_like(x)
# 時刻t時点でのノイズを付与した画像x_tを生成
x_t = (
self.sqrtab[_ts, None, None, None] * x
+ self.sqrtmab[_ts, None, None, None] * eps
) # This is the x_t, which is sqrt(alphabar) x_0 + sqrt(1-alphabar) * eps
# We should predict the "error term" from this x_t. Loss is what we return.
# 損失関数に、「加えたノイズepsと、Unetの出力」を与える
return self.criterion(eps, self.eps_model(x_t, _ts / self.n_T))
# ランダムにノイズを生成し、学習済みモデルを用いて画像を生成する
def sample(self, n_sample: int, size, device) -> torch.Tensor:
# 正規分布からランダムにノイズを生成
x_i = torch.randn(n_sample, *size).to(device) # x_T ~ N(0, 1)
# This samples accordingly to Algorithm 2. It is exactly the same logic.
# step数の逆から0までfor文で推論
for i in range(self.n_T, 0, -1):
# 正規分布から乱数Zを出力
z = torch.randn(n_sample, *size).to(device) if i > 1 else 0
# 数式ではeps(xt, t)となっているが、t=(i / self.n_T)となっているが、
# 学習時のtも(i / self.n_T)で与えていれば問題ない?
# 損失関数の与え方も_ts / self.n_Tとなっているため、対応が取れており問題無さそう
eps = self.eps_model(
x_i, torch.tensor(i / self.n_T).to(device).repeat(n_sample, 1)
)
x_i = (
self.oneover_sqrta[i] * (x_i - eps * self.mab_over_sqrtmab[i])
+ self.sqrt_beta_t[i] * z
)
return x_i
unetも見ていきます。
"""
Simple Unet Structure.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
# Unetで用いられる畳み込み層の基本構造
class Conv3(nn.Module):
def __init__(
self, in_channels: int, out_channels: int, is_res: bool = False
) -> None:
super().__init__()
# 最初の畳み込み構造
self.main = nn.Sequential(
nn.Conv2d(in_channels, out_channels, 3, 1, 1),
nn.GroupNorm(8, out_channels),
nn.ReLU(),
)
# ダウンサンプリング部の畳み込み構造
self.conv = nn.Sequential(
nn.Conv2d(out_channels, out_channels, 3, 1, 1),
nn.GroupNorm(8, out_channels),
nn.ReLU(),
nn.Conv2d(out_channels, out_channels, 3, 1, 1),
nn.GroupNorm(8, out_channels),
nn.ReLU(),
)
self.is_res = is_res
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.main(x)
# 残差ブロックが存在する場合は、畳み込む前の値を加える
if self.is_res:
x = x + self.conv(x)
return x / 1.414
else:
return self.conv(x)
# ダウンサンプリング層の定義
class UnetDown(nn.Module):
def __init__(self, in_channels: int, out_channels: int) -> None:
super(UnetDown, self).__init__()
layers = [Conv3(in_channels, out_channels), nn.MaxPool2d(2)]
self.model = nn.Sequential(*layers)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.model(x)
# アップサンプリング層の定義
class UnetUp(nn.Module):
def __init__(self, in_channels: int, out_channels: int) -> None:
super(UnetUp, self).__init__()
layers = [
nn.ConvTranspose2d(in_channels, out_channels, 2, 2),
Conv3(out_channels, out_channels),
Conv3(out_channels, out_channels),
]
self.model = nn.Sequential(*layers)
def forward(self, x: torch.Tensor, skip: torch.Tensor) -> torch.Tensor:
x = torch.cat((x, skip), 1)
x = self.model(x)
return x
# 時刻情報を埋め込むためのsin関数
# 周期を変数とすれば、時刻毎に異なる数値をとることができる
class TimeSiren(nn.Module):
def __init__(self, emb_dim: int) -> None:
super(TimeSiren, self).__init__()
self.lin1 = nn.Linear(1, emb_dim, bias=False)
self.lin2 = nn.Linear(emb_dim, emb_dim)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = x.view(-1, 1)
x = torch.sin(self.lin1(x))
x = self.lin2(x)
return x
# 上記のパーツを用いてUnetを定義
class NaiveUnet(nn.Module):
def __init__(self, in_channels: int, out_channels: int, n_feat: int = 256) -> None:
super(NaiveUnet, self).__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.n_feat = n_feat
self.init_conv = Conv3(in_channels, n_feat, is_res=True)
self.down1 = UnetDown(n_feat, n_feat)
self.down2 = UnetDown(n_feat, 2 * n_feat)
self.down3 = UnetDown(2 * n_feat, 2 * n_feat)
self.to_vec = nn.Sequential(nn.AvgPool2d(4), nn.ReLU())
self.timeembed = TimeSiren(2 * n_feat)
self.up0 = nn.Sequential(
nn.ConvTranspose2d(2 * n_feat, 2 * n_feat, 4, 4),
nn.GroupNorm(8, 2 * n_feat),
nn.ReLU(),
)
self.up1 = UnetUp(4 * n_feat, 2 * n_feat)
self.up2 = UnetUp(4 * n_feat, n_feat)
self.up3 = UnetUp(2 * n_feat, n_feat)
self.out = nn.Conv2d(2 * n_feat, self.out_channels, 3, 1, 1)
def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
x = self.init_conv(x)
down1 = self.down1(x)
down2 = self.down2(down1)
down3 = self.down3(down2)
thro = self.to_vec(down3)
# 時刻情報の埋め込みベクトル算出
temb = self.timeembed(t).view(-1, self.n_feat * 2, 1, 1)
# 時刻情報の埋め込み
thro = self.up0(thro + temb)
up1 = self.up1(thro, down3) + temb
up2 = self.up2(up1, down2)
up3 = self.up3(up2, down1)
out = self.out(torch.cat((up3, x), 1))
return out
学習と生成時のコード
import torch
print(torch.cuda.is_available())
print(torch.__version__ )
torch.cuda.current_device()
from PIL import Image
import os
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
class CatDataset(Dataset):
def __init__(self, path):
files = os.listdir(path)
self.file_list = [os.path.join(path,file) for file in files]
self.transform = transforms.Compose(
[
transforms.Resize(64),
transforms.RandomHorizontalFlip(0.5),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
]
)
def __len__(self):
return len(self.file_list)
def __getitem__(self, i):
img = Image.open(self.file_list[i])
return self.transform(img)
from typing import Dict, Optional, Tuple
from sympy import Ci
from tqdm import tqdm
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.utils import save_image, make_grid
from mindiffusion.unet import NaiveUnet
from mindiffusion.ddpm import DDPM
from pathlib import Path
def train_mindiff(
n_epoch: int = 100, device: str = "cuda:0", load_pth: Optional[str] = None
) -> None:
###設定############
n_feat = 256
batch_size = 32
lr = 5e-5
train_path = Path("./train")
dataset_dir = train_path
###################
ddpm = DDPM(eps_model=NaiveUnet(3, 3, n_feat=n_feat), betas=(1e-4, 0.02), n_T=1000)
if load_pth is not None:
ddpm.load_state_dict(torch.load(load_pth))
ddpm.to(device)
dataset = CatDataset(dataset_dir)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4)
optim = torch.optim.Adam(ddpm.parameters(), lr=lr)
for i in range(n_epoch):
print(f"Epoch {i} : ")
ddpm.train()
pbar = tqdm(dataloader)
loss_ema = None
for x in pbar:
print(x.shape[0])
optim.zero_grad()
x = x.to(device)
loss = ddpm(x)
loss.backward()
if loss_ema is None:
loss_ema = loss.item()
else:
loss_ema = 0.9 * loss_ema + 0.1 * loss.item()
pbar.set_description(f"loss: {loss_ema:.4f}")
optim.step()
ddpm.eval()
with torch.no_grad():
xh = ddpm.sample(8, (3, 64, 64), device)
xset = torch.cat([xh, x[:8]], dim=0)
grid = make_grid(xset, normalize=True, value_range=(-1, 1), nrow=4)
save_image(grid, f"./contents2/ddpm_sample_cat{str(i).zfill(3)}.png")
# save model
torch.save(ddpm.state_dict(), f"./ddpm_cat{i%3}_3.pth")
if __name__ == '__main__':
train_mindiff(300)
今回の目的は、上手く生成するというよりはスクラッチ(写経)実装しつつ、モデルの動作原理を理解するところにあるため、学習時のコードは先ほど紹介したhttps://note.com/gcem156/n/nd8d00f0a3159 記載のものをほぼそのまま使用しております。
また、学習データは下記を使用しました。
https://paperswithcode.com/dataset/afhq
この内、猫の画像のみを用いています。
生成例
バッチサイズ128,画像サイズ32×32で学習した時の結果です。他のパラメータはほぼ上記の学習例のものになっています。
また、載せている画像は、上段8枚が生成画像で下段8枚が対応するノイズを付与する前の画像です。
epoch:51
微妙に輪郭やパーツが見えてきました!
しかし、まだまだ不安定な感じですね。
epoch:98
だいぶマシになりましたが、まだまだではありますね。
では、次に画像サイズを倍(64×64)に、バッチサイズを32にして学習してみました。(本記事投稿時、絶賛学習中です。。。)
epoch:26
26epoch目で既に、ある程度の輪郭を生成できている画像も存在しています。
epoch:35
この時点で最新のものですが、やはりまだまだ学習できていないですね。
3060とかのGPUを積んだPCが欲しいですね。。。
最後に
如何でしたでしょうか?
「minDiffusion」は最低限の実装のため、生成モデルの構造や動作を理解しやすいと思います。
自分は、ここにVAEやATTENTION_blockを追加してみようかなと考えております。
追加前後の学習速度や生成結果の比較もできますしね!
因みに、動作環境の制約でどうしても画像サイズや学習時間に制限が生じてしまう場合は、
VAEを追加すると大分変わるのかなと思います。
論文だと確か「512×512」を入力しますが、VAEで「64×64」にサイズダウン(重要な特徴のみに絞る)し、ノイズ付与やUnetを追加した最後に「512×512」にデコードします。
そのため、重たい学習(Unet)は圧縮した画像で行えるので、多少は変わるのかなと思います。
VAEは次に確認しようと思っているので、追加実装して上手く動作したら再度記事化しようと思います。
以上です。
本記事が何かしらの参考になれば幸いです。