LoginSignup
2
3

minDiffusionを使ってみる

Posted at

初めに

最近話題のdiffusion modelの理解のため、下記記事を参考に実装をしつつ理解を進めました。
下記は、どうやら必要最低限の実装でDDPMを実装しているようですが、しっかりと学習と生成ができました。
VAEやアテンション層は入っていないみたいなので、今後アテンションを理解したら追加実装して確認してみたいです。
https://github.com/cloneofsimo/minDiffusion
因みに、自分はこちらの記事からこの存在を知りました。
https://note.com/gcem156/n/nd8d00f0a3159
また、今回はほぼこの「minDiffusion」の紹介になります。
コード内にコメントを追記しているので、写経時の参考になればと思います。
※間違えていたら、ご指摘いただけると幸いです。

コード

class DDPM(nn.Module):
    def __init__(
        self,
        eps_model: nn.Module,
        betas: Tuple[float, float],
        n_T: int,
        criterion: nn.Module = nn.MSELoss(),
    ) -> None:
        super(DDPM, self).__init__()
        self.eps_model = eps_model

        # register_buffer allows us to freely access these tensors by name. It helps device placement.
        for k, v in ddpm_schedules(betas[0], betas[1], n_T).items():
            self.register_buffer(k, v)

        self.n_T = n_T
        self.criterion = criterion

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Makes forward diffusion x_t, and tries to guess epsilon value from x_t using eps_model.
        This implements Algorithm 1 in the paper.
        """
        
        # ノイズを付与する総ステップの中からランダムに選択
        # t ~ Uniform(0, n_T)
        _ts = torch.randint(1, self.n_T + 1, (x.shape[0],)).to(x.device)
        
        # ノイズを標準正規分布から生成
        # eps ~ N(0, 1)
        eps = torch.randn_like(x)  

        # 時刻t時点でのノイズを付与した画像x_tを生成
        x_t = (
            self.sqrtab[_ts, None, None, None] * x
            + self.sqrtmab[_ts, None, None, None] * eps
        )  # This is the x_t, which is sqrt(alphabar) x_0 + sqrt(1-alphabar) * eps
        # We should predict the "error term" from this x_t. Loss is what we return.

        # 損失関数に、「加えたノイズepsと、Unetの出力」を与える
        return self.criterion(eps, self.eps_model(x_t, _ts / self.n_T))

    # ランダムにノイズを生成し、学習済みモデルを用いて画像を生成する
    def sample(self, n_sample: int, size, device) -> torch.Tensor:

        # 正規分布からランダムにノイズを生成
        x_i = torch.randn(n_sample, *size).to(device)  # x_T ~ N(0, 1)

        # This samples accordingly to Algorithm 2. It is exactly the same logic.

        # step数の逆から0までfor文で推論
        for i in range(self.n_T, 0, -1):
            # 正規分布から乱数Zを出力
            z = torch.randn(n_sample, *size).to(device) if i > 1 else 0

            # 数式ではeps(xt, t)となっているが、t=(i / self.n_T)となっているが、
            # 学習時のtも(i / self.n_T)で与えていれば問題ない?
            # 損失関数の与え方も_ts / self.n_Tとなっているため、対応が取れており問題無さそう
            eps = self.eps_model(
                x_i, torch.tensor(i / self.n_T).to(device).repeat(n_sample, 1)
            )
            x_i = (
                self.oneover_sqrta[i] * (x_i - eps * self.mab_over_sqrtmab[i])
                + self.sqrt_beta_t[i] * z
            )

        return x_i

unetも見ていきます。

"""
Simple Unet Structure.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F


# Unetで用いられる畳み込み層の基本構造
class Conv3(nn.Module):
    def __init__(
        self, in_channels: int, out_channels: int, is_res: bool = False
    ) -> None:
        super().__init__()
        # 最初の畳み込み構造
        self.main = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, 1, 1),
            nn.GroupNorm(8, out_channels),
            nn.ReLU(),
        )
        # ダウンサンプリング部の畳み込み構造
        self.conv = nn.Sequential(
            nn.Conv2d(out_channels, out_channels, 3, 1, 1),
            nn.GroupNorm(8, out_channels),
            nn.ReLU(),
            nn.Conv2d(out_channels, out_channels, 3, 1, 1),
            nn.GroupNorm(8, out_channels),
            nn.ReLU(),
        )

        self.is_res = is_res

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.main(x)
        # 残差ブロックが存在する場合は、畳み込む前の値を加える
        if self.is_res:
            x = x + self.conv(x)
            return x / 1.414
        else:
            return self.conv(x)


# ダウンサンプリング層の定義
class UnetDown(nn.Module):
    def __init__(self, in_channels: int, out_channels: int) -> None:
        super(UnetDown, self).__init__()
        layers = [Conv3(in_channels, out_channels), nn.MaxPool2d(2)]
        self.model = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:

        return self.model(x)


# アップサンプリング層の定義
class UnetUp(nn.Module):
    def __init__(self, in_channels: int, out_channels: int) -> None:
        super(UnetUp, self).__init__()
        layers = [
            nn.ConvTranspose2d(in_channels, out_channels, 2, 2),
            Conv3(out_channels, out_channels),
            Conv3(out_channels, out_channels),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor, skip: torch.Tensor) -> torch.Tensor:
        x = torch.cat((x, skip), 1)
        x = self.model(x)

        return x

# 時刻情報を埋め込むためのsin関数
# 周期を変数とすれば、時刻毎に異なる数値をとることができる
class TimeSiren(nn.Module):
    def __init__(self, emb_dim: int) -> None:
        super(TimeSiren, self).__init__()

        self.lin1 = nn.Linear(1, emb_dim, bias=False)
        self.lin2 = nn.Linear(emb_dim, emb_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x.view(-1, 1)
        x = torch.sin(self.lin1(x))
        x = self.lin2(x)
        return x


# 上記のパーツを用いてUnetを定義
class NaiveUnet(nn.Module):
    def __init__(self, in_channels: int, out_channels: int, n_feat: int = 256) -> None:
        super(NaiveUnet, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels

        self.n_feat = n_feat

        self.init_conv = Conv3(in_channels, n_feat, is_res=True)

        self.down1 = UnetDown(n_feat, n_feat)
        self.down2 = UnetDown(n_feat, 2 * n_feat)
        self.down3 = UnetDown(2 * n_feat, 2 * n_feat)

        self.to_vec = nn.Sequential(nn.AvgPool2d(4), nn.ReLU())

        self.timeembed = TimeSiren(2 * n_feat)

        self.up0 = nn.Sequential(
            nn.ConvTranspose2d(2 * n_feat, 2 * n_feat, 4, 4),
            nn.GroupNorm(8, 2 * n_feat),
            nn.ReLU(),
        )

        self.up1 = UnetUp(4 * n_feat, 2 * n_feat)
        self.up2 = UnetUp(4 * n_feat, n_feat)
        self.up3 = UnetUp(2 * n_feat, n_feat)
        self.out = nn.Conv2d(2 * n_feat, self.out_channels, 3, 1, 1)

    def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:

        x = self.init_conv(x)

        down1 = self.down1(x)
        down2 = self.down2(down1)
        down3 = self.down3(down2)

        thro = self.to_vec(down3)
        # 時刻情報の埋め込みベクトル算出
        temb = self.timeembed(t).view(-1, self.n_feat * 2, 1, 1)
        # 時刻情報の埋め込み
        thro = self.up0(thro + temb)

        up1 = self.up1(thro, down3) + temb
        up2 = self.up2(up1, down2)
        up3 = self.up3(up2, down1)

        out = self.out(torch.cat((up3, x), 1))

        return out

学習と生成時のコード

import torch
print(torch.cuda.is_available())
print(torch.__version__ )

torch.cuda.current_device()

from PIL import Image
import os

from torch.utils.data import Dataset, DataLoader

from torchvision import transforms
class CatDataset(Dataset):
    def __init__(self, path):
        files = os.listdir(path)
        self.file_list = [os.path.join(path,file) for file in files]
        self.transform = transforms.Compose(
        [
        transforms.Resize(64),
        transforms.RandomHorizontalFlip(0.5),
        transforms.ToTensor(), 
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
         ]
    )
    def __len__(self):
        return len(self.file_list)
    def __getitem__(self, i):
        img = Image.open(self.file_list[i])
        return self.transform(img)

from typing import Dict, Optional, Tuple
from sympy import Ci
from tqdm import tqdm

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

from torchvision import transforms
from torchvision.utils import save_image, make_grid

from mindiffusion.unet import NaiveUnet
from mindiffusion.ddpm import DDPM

from pathlib import Path

def train_mindiff(
    n_epoch: int = 100, device: str = "cuda:0", load_pth: Optional[str] = None
) -> None:

    ###設定############
    n_feat = 256
    batch_size = 32
    lr = 5e-5
    train_path = Path("./train")
    dataset_dir = train_path
    ###################

    ddpm = DDPM(eps_model=NaiveUnet(3, 3, n_feat=n_feat), betas=(1e-4, 0.02), n_T=1000)

    if load_pth is not None:
        ddpm.load_state_dict(torch.load(load_pth))

    ddpm.to(device)

    dataset = CatDataset(dataset_dir)

    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4)
    optim = torch.optim.Adam(ddpm.parameters(), lr=lr)

    for i in range(n_epoch):
        print(f"Epoch {i} : ")
        ddpm.train()

        pbar = tqdm(dataloader)
        loss_ema = None
        for x in pbar:
            print(x.shape[0])
            optim.zero_grad()
            x = x.to(device)
            loss = ddpm(x)
            loss.backward()
            if loss_ema is None:
                loss_ema = loss.item()
            else:
                loss_ema = 0.9 * loss_ema + 0.1 * loss.item()
            pbar.set_description(f"loss: {loss_ema:.4f}")
            optim.step()

        ddpm.eval()
        with torch.no_grad():
            xh = ddpm.sample(8, (3, 64, 64), device)
            xset = torch.cat([xh, x[:8]], dim=0)
            grid = make_grid(xset, normalize=True, value_range=(-1, 1), nrow=4)
            save_image(grid, f"./contents2/ddpm_sample_cat{str(i).zfill(3)}.png")

            # save model
            torch.save(ddpm.state_dict(), f"./ddpm_cat{i%3}_3.pth")

if __name__ == '__main__':
    train_mindiff(300)

今回の目的は、上手く生成するというよりはスクラッチ(写経)実装しつつ、モデルの動作原理を理解するところにあるため、学習時のコードは先ほど紹介したhttps://note.com/gcem156/n/nd8d00f0a3159 記載のものをほぼそのまま使用しております。
また、学習データは下記を使用しました。
https://paperswithcode.com/dataset/afhq
この内、猫の画像のみを用いています。

生成例

バッチサイズ128,画像サイズ32×32で学習した時の結果です。他のパラメータはほぼ上記の学習例のものになっています。
また、載せている画像は、上段8枚が生成画像で下段8枚が対応するノイズを付与する前の画像です。

epoch:0
ddpm_sample_anime000.png
まだほぼノイズですね。

epoch:51
ddpm_sample_anime051.png
微妙に輪郭やパーツが見えてきました!
しかし、まだまだ不安定な感じですね。

epoch:98
ddpm_sample_anime098.png
だいぶマシになりましたが、まだまだではありますね。

では、次に画像サイズを倍(64×64)に、バッチサイズを32にして学習してみました。(本記事投稿時、絶賛学習中です。。。)

epoch:0
ddpm_sample_cat000.png
相変わらず、最初はノイズですね。

epoch:26
ddpm_sample_cat026.png
26epoch目で既に、ある程度の輪郭を生成できている画像も存在しています。

epoch:35
ddpm_sample_cat035.png
この時点で最新のものですが、やはりまだまだ学習できていないですね。
3060とかのGPUを積んだPCが欲しいですね。。。

最後に

如何でしたでしょうか?
「minDiffusion」は最低限の実装のため、生成モデルの構造や動作を理解しやすいと思います。
自分は、ここにVAEやATTENTION_blockを追加してみようかなと考えております。
追加前後の学習速度や生成結果の比較もできますしね!
因みに、動作環境の制約でどうしても画像サイズや学習時間に制限が生じてしまう場合は、
VAEを追加すると大分変わるのかなと思います。
論文だと確か「512×512」を入力しますが、VAEで「64×64」にサイズダウン(重要な特徴のみに絞る)し、ノイズ付与やUnetを追加した最後に「512×512」にデコードします。
そのため、重たい学習(Unet)は圧縮した画像で行えるので、多少は変わるのかなと思います。
VAEは次に確認しようと思っているので、追加実装して上手く動作したら再度記事化しようと思います。

以上です。
本記事が何かしらの参考になれば幸いです。

2
3
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
3