はじめに
今回はVision Transformerという画像認識のモデルをゼロから実装してみました。
概要について理解している方は、コード部分からお読みください。
Vision Transformerとは
Vision Transformer(以下ViT)は、画像認識タスクのための深層学習モデルアーキテクチャであり、従来の畳み込みニューラルネットワーク(CNN)に代わる方法として提案されました。ViTは自然言語処理タスクで成功を収めたTransformerモデルの考え方を、画像データに適用するものです。
具体的には以下のような手順で画像の多クラス分類を行っています。
- 画像をN×Nのパッチに分割する。(原論文では16×16のパッチとしている)
- 各パッチを全結合層に入力しパッチ単位のベクトルを作成する。これを仮想的なトークンと見なす。
- トークン化したベクトルに、Bertと同じようにclsトークンの埋め込み行列を結合し、Position Embeddingを行う。
- この過程でバッチサイズ×トークン数 + 1 ×(チャンネル数×パッチサイズ×パッチサイズ)に変形することができる。このベクトルをエンコーダに入力する。
- エンコーダの出力からclsトークンのベクトルのみを抽出し、再度全結合層に入力、分類ヘッドから確率分布を計算する。
ViTは大規模なデータセットで事前学習され、特定のタスクに合わせてファインチューニングされます。そのため比較的少ないラベル付きデータで高性能な画像認識モデルを構築できるという利点があります。
加えて、画像認識タスクで従来のCNNベースのモデルに対抗する性能を示し、特に大規模なデータセットや高解像度の画像認識においてSoTAを達成し優れた成果を示しました。
コード
以下が実装したコードになります。
警告
以下は個人で開発したコードです。誤った部分や、冗長な部分がある可能性があります。
データセット
今回は、ViTが動作するかどうかの実験的な実装なので、データセットには10クラスと比較的少ないCIFAR10を使用しました。
import torchvision
image_size = 256
batch_size = 32
root_dir = "./data"
transform =torchvision.transforms.Compose([
torchvision.transforms.Resize(image_size), torchvision.transforms.ToTensor()
])
train_datasets = torchvision.datasets.CIFAR10(
root=root_dir, train=True, transform=transform, download=True
)
val_datasets = torchvision.datasets.CIFAR10(
root=root_dir, train=False, transform=transform, download=True
)
train_dataloader = DataLoader(
train_datasets, batch_size=batch_size, shuffle=True, drop_last=True
)
val_dataloader = DataLoader(
val_datasets, batch_size=batch_size, shuffle=False, drop_last=True
)
エンコーダ
Encoder
クラスはTransformer Encoder、MultiHeadAttention
クラスはエンコーダの内部にあるMulti-HeadAttentionに対応します。
import torch
import torch.nn as nn
class VisonTransformer(nn.Module):
def __init__(self, num_classes, batch_size, image_size, num_channel, patch_size, embed_hidden_size, num_layer, num_head, device, MultiHeadAttention, Encoder):
super().__init__()
self.device = device
self.batch_size = batch_size
self.patch_size = patch_size
self.image_size = image_size
self.num_channel = num_channel
self.num_patch = int((image_size / patch_size) * (image_size / patch_size))
self.num_token = self.num_patch + 1
self.num_layer = num_layer
self.num_head = num_head
self.cls_id = torch.tensor(0, dtype=torch.long).to(device)
self.cls_embedding = nn.Embedding(1, embed_hidden_size)
self.embed_hidden_size = embed_hidden_size
self.positional_embedding = nn.Embedding(self.num_token, embed_hidden_size)
self.image_embedding = nn.Linear(patch_size * patch_size * self.num_patch, embed_hidden_size)
self.layer_norm = nn.LayerNorm((batch_size, embed_hidden_size))
self.dropout = nn.Dropout(p=0.9)
self.fc = nn.Linear(embed_hidden_size, num_classes)
args = (batch_size, image_size, num_channel, patch_size, embed_hidden_size, num_layer, num_head, device, MultiHeadAttention)
self.setup_layer(num_layer, Encoder, args)
def image_to_token(self, images, patch_size):
batch_size, num_channel, width, height = self.batch_size, self.num_channel, self.image_size, self.image_size
patch_window = torch.ones((patch_size, patch_size), dtype=torch.long)
patch_window = patch_window.unsqueeze(0).expand(num_channel, patch_size, patch_size)\
.unsqueeze(0).expand(batch_size, num_channel, patch_size, patch_size)
token_list = []
for row_idx in range(0, width, patch_size):
for col_idx in range(0, height, patch_size):
patch = images[:, :, row_idx: row_idx + patch_size, col_idx:col_idx + patch_size]
token_list.append(patch)
token_list = torch.stack(token_list, dim=0).transpose(0, 1).view(batch_size, 256, -1)
return token_list
def positional_encoding(self, num_token):
position_ids = torch.tensor(list(range(num_token)), dtype=torch.long).expand(self.batch_size, -1).to(self.device)
positional_embeds = self.positional_embedding(position_ids)
return positional_embeds
def setup_layer(self, num_layer, encoder, args):
layer_list = []
for _ in range(num_layer):
layer_list.append(encoder(*args).to(self.device))
module_list = nn.ModuleList(layer_list)
self.layer_list = nn.Sequential(*module_list)
def forward(self, images):
cls_tokens = self.cls_embedding(self.cls_id).unsqueeze(0).expand(self.batch_size, 1, -1)
image_tokens = self.image_to_token(images, self.patch_size)
tokens = torch.concat([cls_tokens, image_tokens], dim=1)
positional_embeds = self.positional_encoding(self.num_token)
embed_tokens = (tokens + positional_embeds)
encoded_outputs = self.layer_list(embed_tokens)
cls = encoded_outputs[:, 0, :]
layer_norm = self.layer_norm(cls)
dropout = self.dropout(layer_norm)
outputs = self.fc(dropout)
return outputs
class Encoder(nn.Module):
def __init__(self, batch_size, image_size, num_channel, patch_size, embed_hidden_size, num_layer, num_head, device, MultiHeadAttention):
super().__init__()
num_patch = int((image_size / patch_size) * (image_size / patch_size))
num_token = num_patch + 1
self.layer_norm1 = nn.LayerNorm((batch_size, num_token, embed_hidden_size))
self.layer_norm2 = nn.LayerNorm((batch_size, num_token, embed_hidden_size))
self.dropout = nn.Dropout(p=0.9)
self.mlp = nn.Linear(embed_hidden_size, embed_hidden_size)
self.attention_layer = MultiHeadAttention(embed_hidden_size, num_head, device)
self.gelu = nn.GELU().to(device)
def forward(self, tokens):
layer_norm1 = self.layer_norm1(tokens)
dropout1 = self.dropout(layer_norm1)
skip1 = tokens
concat_attention = self.attention_layer(dropout1)
outputs_tmp1 = concat_attention + skip1
skip2 = outputs_tmp1
layer_norm2 = self.layer_norm2(outputs_tmp1)
dropout2 = self.dropout(layer_norm2)
mlp = self.gelu(self.mlp(dropout2))
outputs = mlp + skip2
return outputs
class MultiHeadAttention(nn.Module):
def __init__(self, embed_hidden_size, num_head, device):
super().__init__()
self.attention_layers = []
self.query_layers = []
self.key_layers = []
self.value_layers = []
self.num_head = num_head
self.embed_hidden_size = embed_hidden_size
self.multi_embed_hidden_size = int(embed_hidden_size / num_head)
self.device = device
self.setup_attention()
def setup_attention(self):
for number in range(self.num_head):
self.query_layers.append(nn.Linear(self.embed_hidden_size, self.multi_embed_hidden_size).to(self.device))
self.key_layers.append(nn.Linear(self.embed_hidden_size, self.multi_embed_hidden_size).to(self.device))
self.value_layers.append(nn.Linear(self.embed_hidden_size, self.multi_embed_hidden_size).to(self.device))
self.query_layers = nn.ModuleList(self.query_layers)
self.key_layers = nn.ModuleList(self.key_layers)
self.value_layers = nn.ModuleList(self.value_layers)
def output_attention(self, tokens):
for number in range(self.num_head):
query = self.query_layers[number](tokens)
key = self.key_layers[number](tokens)
value = self.value_layers[number](tokens)
attention = nn.Softmax(dim=-1)((query@torch.transpose(key, 1, 2)) / torch.sqrt(torch.tensor(self.multi_embed_hidden_size)))@value
if number > 0: concat_attention = torch.concat([concat_attention, attention], dim=-1)
else:concat_attention = attention
return concat_attention
def forward(self, tokens):
concat_attention = self.output_attention(tokens)
return concat_attention
学習・評価
学習に時間がかかるため、ここでエポック数を1としています。またローカル環境のGPUのスペック上バッチサイズは32が限界でした。
import torch
from torcheval.metrics.functional import (multiclass_accuracy,
multiclass_precision,
multiclass_recall,
multiclass_f1_score
)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
kwargs = {
"num_classes":10,
"batch_size":32,
"image_size":256,
"num_channel":3,
"patch_size":16,
"embed_hidden_size":768,
"num_layer":12,
"num_head":8,
"device":device,
"MultiHeadAttention":MultiHeadAttention,
"Encoder":Encoder
}
model = VisonTransformer(**kwargs).to(device)
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
output_dir = "./output"
train_loss_list = []
val_loss_list = []
train_correct_list =[]
val_correct_list = []
presision_list = []
recall_list = []
f1_list = []
epochs = 1
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
min_loss = np.inf
for epoch in range(epochs):
for step in ["train", "val"]:
if step == "train":
model.train()
dataloader = train_dataloader
else:
model.eval()
dataloader = val_dataloader
running_loss = 0.0
running_correct = 0.0
for batch, (images, labels) in enumerate(dataloader):
images = images.to(device)
labels = labels.to(device)
with torch.set_grad_enabled(model.training):
outputs = model(images)
pred = torch.argmax(outputs, dim=-1)
loss = criterion(outputs, labels).sum()
correct = multiclass_accuracy(
input=outputs,
target=labels,
num_classes=kwargs["num_classes"],
average="micro"
)
if model.training:
optimizer.zero_grad()
loss.backward()
optimizer.step()
running_loss += loss.item() / batch_size
running_correct += correct.item() / batch_size
if model.training:
train_loss_list.append(running_loss)
train_correct_list.append(running_correct)
print(f"Step: Train Epoch: {epoch + 1}/{epochs} Iterate: {batch + 1}/{len(dataloader)} train_loss: {running_loss / (batch + 1)}")
else:
print(f"Step: Train Epoch: {epoch + 1}/{epochs} val_loss: {running_loss / (batch + 1)}\
val_correct: {running_correct / (batch + 1)}")
val_loss_list.append(running_loss)
val_correct_list.append(running_correct)
precision = multiclass_precision(
input=outputs,
target=labels,
num_classes=kwargs["num_classes"],
average="micro"
).item()
recall = multiclass_recall(
input=outputs,
target=labels,
num_classes=kwargs["num_classes"],
average="micro"
).item()
f1_score = multiclass_f1_score(
input=outputs,
target=labels,
num_classes=kwargs["num_classes"],
average="micro"
).item()
recall_list.append(recall)
presision_list.append(precision)
f1_list.append(f1_score)
if running_loss < min_loss:
print("Model Save!")
min_loss = running_loss
if not os.path.exists(output_dir):
os.mkdir()
torch.save(model.state_dict(), os.path.join(output_dir, f"{epoch + 1}.pth"))
まとめ
ViTの実装を通して、一番ためになったのはMulti-HeadAttentionを含めたTransformerエンコーダをゼロから作ることで、大規模言語モデルの理解がさらに深まりました。次回からはT5などのエンコーダ・デコーダモデルの実装も視野に入れたいと思います。
参考文献
An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale