この記事を読む前に
本記事は前記事の続きとなっています.
前記事の内容を理解していることを前提として本記事は執筆されています.理論を理解したい!という方はぜひ前記事も読んでいただけると嬉しいです.
なお,本記事は実装編1となっています.実装編は3までを予定しています.1,2,3で扱う内容は以下のとおりです.
- 実装編1(今ココ)
- パラメータを推定するモデルであるUViTの解説・実装.
- 実装編2
- UViTを用いた拡散モデルDDPMの設計.
- 実装編3
- CIFAR10データセットを用いたDDPMの訓練.
UViTとは?
理論編の記事で説明したとおり,拡散過程における時刻$t$の$x_t$から$x_{t-1}$を得るためには,$p(x_{t-1}|x_t)$を推定するモデルを利用しますが,実装上では,ニューラルネットワークは時刻$t$および時刻$t$におけるノイズが加わった画像$x_t$を受け取り,時刻$t$に乗せられたノイズ$\epsilon$を出力します.
ここで,今回使用するモデル内部には非決定的な部分はありません.したがって,$x_t$と$t$が与えられているとき,$\epsilon$は決定的に定まることに注意してください.
一般に,拡散モデルではUNetが使われます.しかし,著者はViTを理解している一方UNetに詳しくないので実装がViTとほぼ同じであり,Vision Transformerアーキテクチャの勉強にもなるので今回はUViT[1]というモデルを扱います.
アーキテクチャ
以下に,[1]より引用したUViTのアーキテクチャ図を示します.
ほとんどVision Transformerと同じですが,以下の違いがあります.
- 入力に画像パッチだけでなく,時刻とクラスもトークンとして与えている
- 長距離スキップコネクションがある
- 最終層に畳み込みがある
ViTのアーキテクチャについてはOmiitaさんの記事がよくまとまっていますので,そちらをご参照ください.
それではUViTを実装していきましょう.MHSA, FFN, TransformerEncoderの部分は折りたたみます.
MHSA, FFN, TransformerEncoderの実装
class MHSA(nn.Module):
def __init__(self, dim, n_heads) -> None:
super().__init__()
assert dim % n_heads == 0, 'dimはn_headsの倍数である必要があります'
self.to_q = nn.Linear(dim, dim)
self.to_k = nn.Linear(dim, dim)
self.to_v = nn.Linear(dim, dim)
self.MHSA = nn.MultiheadAttention(dim, n_heads, batch_first=True)
def forward(self, x):
q, k, v = self.to_q(x), self.to_k(x), self.to_v(x)
return self.MHSA(q, k, v)[0]
class FFN(nn.Module):
def __init__(self, dim, inner_dim) -> None:
super().__init__()
self.net = nn.Sequential(nn.Linear(dim, inner_dim),
nn.GELU(),
nn.Linear(inner_dim, dim))
def forward(self, x):
return self.net(x)
class TransformerEncoder(nn.Module):
def __init__(self, dim, n_heads, inner_dim) -> None:
super().__init__()
self.MHSA = MHSA(dim, n_heads)
self.FFN = FFN(dim, inner_dim)
self.norm1 = nn.LayerNorm(dim)
self.norm2 = nn.LayerNorm(dim)
def forward(self, x):
x = self.MHSA(self.norm1(x)) + x
x = self.FFN(self.norm2(x)) + x
return x
上で折りたたんだMHSA, FFN, TransformerEncoderを利用して,以下のようにUViTを定義します.
import torch
import torch.nn as nn
import torchvision
from einops.layers.torch import Rearrange
from einops import repeat
class UViT(nn.Module):
def __init__(self, height=32, width=32, channels = 3, num_classes = 10, T=1000, patch_size = 4, dim=256, depth=9, n_heads=8, inner_dim=1024):
#パッチ化関係
super().__init__()
self.patching = Rearrange("B C (h ph) (w pw) -> B (h w) (C ph pw)", pw = patch_size, ph = patch_size)
self.patch_to_token = nn.Linear(patch_size*patch_size*channels , dim)
self.num_patches = (height // patch_size) * (width // patch_size)
# トークン関係
self.cls_token = nn.Parameter(torch.randn(num_classes, dim))
self.time_token = nn.Parameter(torch.randn(T, dim))
self.positional_token = nn.Parameter(torch.randn(2+self.num_patches, dim))
# エンコーダ
self.encoders = nn.ModuleList([TransformerEncoder(dim, n_heads, inner_dim) for _ in range(depth)])
self.depth = (depth - 1) // 2
self.reverse_linear = nn.Linear(dim, patch_size*patch_size*channels)
self.unpatching = Rearrange("B (h w) (C ph pw) -> B C (h ph) (w pw) ", pw = patch_size, ph = patch_size, h = height // patch_size)
self.conv = nn.Conv2d(channels, channels, 3, padding=1)
def forward(self, x, t, c):
x = self.patching(x)
x = self.patch_to_token(x)
if c is not None :
x = torch.concat((self.time_token[t].unsqueeze(1), self.cls_token[c].unsqueeze(1), x), dim=1)
x = x + repeat(self.positional_token.unsqueeze(0), '1 n d -> b n d', b=x.shape[0] )
depth = 0
longskip = []
for encoder in self.encoders:
if depth < self.depth:
x = encoder(x)
longskip.append(x)
elif depth > self.depth :
x = encoder(x) + longskip[2*self.depth-depth]
else:
x = encoder(x)
depth += 1
x = x[:, 2:self.num_patches+2]
x = self.reverse_linear(x)
x = self.unpatching(x)
return self.conv(x)
パッチ化・エンコーダの部分は同じですが,クラス用のトークンself.cls_token
と時刻用のトークンself.time_token
があることおよび,long skip connection,最後の畳込み層があることがViTとの大きな相違点です.
if __name__ == '__main__':
# デバッグ用
device = 'cuda:0'
model = UViT().to(device)
input=torch.randn(4, 3, 32, 32).to(device)
cls_token=torch.Tensor([1,2,3,4]).long().to(device)
time_token=torch.Tensor([1,2,3,4]).long().to(device)
output = model(input, time_token, cls_token)
print(output.shape)
これを実行して,[4,3,32,32]
が出てきたら無事に実装できています.
次回予告
次回は,DDPMクラスを実装します.
参考文献
[1] All are Worth Words: A ViT Backbone for Diffusion Models, Bao et al., 2023.