
[Implementation] Building a Vision Transformer from Scratch

Posted at 2023-10-31

Introduction

In this post, I implemented Vision Transformer, an image recognition model, from scratch.

If you already understand the overview, please skip ahead to the code section.

What is a Vision Transformer?

Vision Transformer (hereafter ViT) is a deep learning model architecture for image recognition tasks, proposed as an alternative to conventional convolutional neural networks (CNNs). ViT applies the ideas behind the Transformer model, which has achieved great success in natural language processing, to image data.


Concretely, multi-class image classification is performed through the following steps (a shape-level sketch in PyTorch follows the list):

  1. Split the image into N×N patches (the original paper uses 16×16 patches).
  2. Feed each patch through a fully connected layer to obtain a per-patch vector, which is treated as a virtual token.
  3. As in BERT, prepend the embedding of a cls token to the tokenized vectors and apply position embedding.
  4. Through this process the input can be reshaped into a tensor of shape batch size × (number of patches + 1) × (channels × patch size × patch size). This tensor is fed into the encoder.
  5. Extract only the cls-token vector from the encoder output, feed it into another fully connected layer, and compute the probability distribution with the classification head.
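
To make the shape flow of steps 1–5 concrete, here is a minimal sketch. It is not the implementation used later in this article: it substitutes PyTorch's built-in nn.TransformerEncoderLayer for the hand-written encoder, and the sizes (256×256 input, 16×16 patches, 768-dimensional embedding, 10 classes) simply mirror the settings used in the code below.

python
import torch
import torch.nn as nn

# Illustrative settings matching the code later in this article
batch_size, channels, image_size, patch_size = 32, 3, 256, 16
embed_dim, num_classes = 768, 10
num_patches = (image_size // patch_size) ** 2             # 256 patches

images = torch.randn(batch_size, channels, image_size, image_size)

# 1. Split the image into patches and flatten each patch: (B, 256, 3*16*16)
patches = images.unfold(2, patch_size, patch_size).unfold(3, patch_size, patch_size)
patches = patches.permute(0, 2, 3, 1, 4, 5).reshape(batch_size, num_patches, -1)

# 2. Project each flattened patch with a fully connected layer -> one token per patch
patch_embed = nn.Linear(channels * patch_size * patch_size, embed_dim)
tokens = patch_embed(patches)                             # (B, 256, 768)

# 3. Prepend the cls token and add position embeddings
cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)).expand(batch_size, -1, -1)
tokens = torch.cat([cls_token, tokens], dim=1)            # (B, 257, 768)
pos_embed = nn.Embedding(num_patches + 1, embed_dim)
tokens = tokens + pos_embed(torch.arange(num_patches + 1))

# 4-5. Feed the token sequence into an encoder, then classify from the cls token
encoder = nn.TransformerEncoderLayer(d_model=embed_dim, nhead=8, batch_first=True)
head = nn.Linear(embed_dim, num_classes)
logits = head(encoder(tokens)[:, 0, :])                   # (B, 10)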

ViT is pretrained on a large-scale dataset and then fine-tuned for a specific task. This gives it the advantage that a high-performing image recognition model can be built with relatively little labeled data.

In addition, it has demonstrated performance competitive with conventional CNN-based models on image recognition tasks, achieving state-of-the-art results particularly on large-scale datasets and high-resolution image recognition.


Code

The implemented code is shown below.

Warning
The code below was developed by an individual and may contain mistakes or redundant parts.

Dataset

Since this is an experimental implementation to check whether ViT works, I used CIFAR10, which has a relatively small number of classes (10), as the dataset.

python
import torchvision
from torch.utils.data import DataLoader

image_size = 256
batch_size = 32
root_dir = "./data"

# Resize CIFAR10 images (32x32) to 256x256 and convert them to tensors
transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize(image_size), torchvision.transforms.ToTensor()
])

train_datasets = torchvision.datasets.CIFAR10(
    root=root_dir, train=True, transform=transform, download=True
)

val_datasets = torchvision.datasets.CIFAR10(
    root=root_dir, train=False, transform=transform, download=True
)

train_dataloader = DataLoader(
    train_datasets, batch_size=batch_size, shuffle=True, drop_last=True
)

val_dataloader = DataLoader(
    val_datasets, batch_size=batch_size, shuffle=False, drop_last=True
)
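
As a quick sanity check (this snippet is my addition, not part of the original code), one batch from the train dataloader should come out as a (32, 3, 256, 256) image tensor and a (32,) label tensor:

python
# Inspect the shape of a single batch from the dataloader
images, labels = next(iter(train_dataloader))
print(images.shape)  # torch.Size([32, 3, 256, 256])
print(labels.shape)  # torch.Size([32])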



Encoder

The Encoder class corresponds to the Transformer Encoder, and the MultiHeadAttention class corresponds to the Multi-Head Attention module inside the encoder.

python
import torch
import torch.nn as nn


class VisonTransformer(nn.Module):
    def __init__(self, num_classes, batch_size, image_size, num_channel, patch_size, embed_hidden_size, num_layer, num_head, device, MultiHeadAttention, Encoder):
        super().__init__()
        self.device = device
        self.batch_size = batch_size
        self.patch_size = patch_size
        self.image_size = image_size
        self.num_channel = num_channel
        self.num_patch = int((image_size / patch_size) * (image_size / patch_size))
        self.num_token = self.num_patch + 1
        self.num_layer = num_layer
        self.num_head = num_head        
        self.cls_id = torch.tensor(0, dtype=torch.long).to(device)
        self.cls_embedding = nn.Embedding(1, embed_hidden_size)
        self.embed_hidden_size = embed_hidden_size
        self.positional_embedding = nn.Embedding(self.num_token, embed_hidden_size)
        # Projects each flattened patch (channels * patch_size * patch_size) to the embedding size
        self.image_embedding = nn.Linear(num_channel * patch_size * patch_size, embed_hidden_size)
        # Normalize over the feature dimension only
        self.layer_norm = nn.LayerNorm(embed_hidden_size)
        self.dropout = nn.Dropout(p=0.9)
        self.fc = nn.Linear(embed_hidden_size, num_classes)
        args = (batch_size, image_size, num_channel, patch_size, embed_hidden_size, num_layer, num_head, device, MultiHeadAttention)
        self.setup_layer(num_layer, Encoder, args)

    def image_to_token(self, images, patch_size):
        batch_size, num_channel, width, height = self.batch_size, self.num_channel, self.image_size, self.image_size

        # Slide a non-overlapping patch_size x patch_size window over the image
        # and collect each patch
        token_list = []
        for row_idx in range(0, width, patch_size):
            for col_idx in range(0, height, patch_size):
                patch = images[:, :, row_idx: row_idx + patch_size, col_idx:col_idx + patch_size]
                token_list.append(patch)

        # (num_patch, batch, C, P, P) -> (batch, num_patch, C * P * P)
        token_list = torch.stack(token_list, dim=0).transpose(0, 1).reshape(batch_size, self.num_patch, -1)

        return token_list
    
    def positional_encoding(self, num_token):
        position_ids = torch.tensor(list(range(num_token)), dtype=torch.long).expand(self.batch_size, -1).to(self.device)
        positional_embeds = self.positional_embedding(position_ids)

        return positional_embeds
    
    def setup_layer(self, num_layer, encoder, args):
        layer_list = []
        for _ in range(num_layer):
            layer_list.append(encoder(*args).to(self.device))

        module_list = nn.ModuleList(layer_list)
        self.layer_list = nn.Sequential(*module_list)

    def forward(self, images):

        # cls token embedding expanded to the batch: (batch, 1, hidden)
        cls_tokens = self.cls_embedding(self.cls_id).unsqueeze(0).expand(self.batch_size, 1, -1)
        # Split the images into patches and project each patch to the embedding size
        image_tokens = self.image_embedding(self.image_to_token(images, self.patch_size))
        tokens = torch.concat([cls_tokens, image_tokens], dim=1)
        positional_embeds = self.positional_encoding(self.num_token)
        embed_tokens = tokens + positional_embeds
        encoded_outputs = self.layer_list(embed_tokens)
        # Classify from the cls token only
        cls = encoded_outputs[:, 0, :]
        layer_norm = self.layer_norm(cls)
        dropout = self.dropout(layer_norm)
        outputs = self.fc(dropout)

        return outputs

        
class Encoder(nn.Module):
    def __init__(self, batch_size, image_size, num_channel, patch_size, embed_hidden_size, num_layer, num_head, device, MultiHeadAttention):
        super().__init__()

        # Normalize over the feature dimension only
        self.layer_norm1 = nn.LayerNorm(embed_hidden_size)
        self.layer_norm2 = nn.LayerNorm(embed_hidden_size)
        self.dropout = nn.Dropout(p=0.9)
        self.mlp = nn.Linear(embed_hidden_size, embed_hidden_size)
        self.attention_layer = MultiHeadAttention(embed_hidden_size, num_head, device)
        self.gelu = nn.GELU().to(device)

    def forward(self, tokens):
        # Pre-norm attention block with a residual connection
        layer_norm1 = self.layer_norm1(tokens)
        dropout1 = self.dropout(layer_norm1)
        skip1 = tokens
        concat_attention = self.attention_layer(dropout1)
        outputs_tmp1 = concat_attention + skip1

        # Pre-norm MLP block with a residual connection
        skip2 = outputs_tmp1
        layer_norm2 = self.layer_norm2(outputs_tmp1)
        dropout2 = self.dropout(layer_norm2)
        mlp = self.gelu(self.mlp(dropout2))
        outputs = mlp + skip2

        return outputs

class MultiHeadAttention(nn.Module):
    def __init__(self, embed_hidden_size, num_head, device):

        super().__init__()
        self.attention_layers = []
        self.query_layers = []
        self.key_layers = []
        self.value_layers = []  
        self.num_head = num_head
        self.embed_hidden_size = embed_hidden_size
        self.multi_embed_hidden_size = int(embed_hidden_size / num_head)

        self.device = device
        self.setup_attention()

    def setup_attention(self):

        for number in range(self.num_head):
            self.query_layers.append(nn.Linear(self.embed_hidden_size, self.multi_embed_hidden_size).to(self.device))
            self.key_layers.append(nn.Linear(self.embed_hidden_size, self.multi_embed_hidden_size).to(self.device))
            self.value_layers.append(nn.Linear(self.embed_hidden_size, self.multi_embed_hidden_size).to(self.device))

        self.query_layers = nn.ModuleList(self.query_layers)
        self.key_layers = nn.ModuleList(self.key_layers)
        self.value_layers = nn.ModuleList(self.value_layers)


    def output_attention(self, tokens):

        for number in range(self.num_head):
            query = self.query_layers[number](tokens)
            key = self.key_layers[number](tokens)
            value = self.value_layers[number](tokens)
            # Scaled dot-product attention for a single head
            attention = nn.Softmax(dim=-1)((query @ torch.transpose(key, 1, 2)) / torch.sqrt(torch.tensor(self.multi_embed_hidden_size, dtype=torch.float32))) @ value

            # Concatenate the per-head outputs along the feature dimension
            if number > 0:
                concat_attention = torch.concat([concat_attention, attention], dim=-1)
            else:
                concat_attention = attention

        return concat_attention
    
    def forward(self, tokens):

        concat_attention = self.output_attention(tokens)

        return concat_attention
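
Before moving on to training, a dummy forward pass is a quick way to confirm that the model returns one logit vector per image. The snippet below is an illustrative addition, not part of the original article, and reuses the same hyperparameters as the training code that follows:

python
# Illustrative shape check: run a random dummy batch through the model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = VisonTransformer(
    num_classes=10, batch_size=32, image_size=256, num_channel=3, patch_size=16,
    embed_hidden_size=768, num_layer=12, num_head=8, device=device,
    MultiHeadAttention=MultiHeadAttention, Encoder=Encoder
).to(device)

dummy_images = torch.randn(32, 3, 256, 256).to(device)
with torch.no_grad():
    logits = model(dummy_images)
print(logits.shape)  # torch.Size([32, 10])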



Training and Evaluation

Since training takes a long time, the number of epochs is set to 1 here. Also, due to the specs of my local GPU, a batch size of 32 was the limit.

python
import os
import numpy as np
import torch
import torch.nn as nn
from torcheval.metrics.functional import (multiclass_accuracy, 
                                          multiclass_precision,
                                          multiclass_recall,
                                          multiclass_f1_score
                                          )


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
kwargs = {
    "num_classes":10, 
    "batch_size":32, 
    "image_size":256, 
    "num_channel":3, 
    "patch_size":16, 
    "embed_hidden_size":768,
    "num_layer":12, 
    "num_head":8,
    "device":device,
    "MultiHeadAttention":MultiHeadAttention, 
    "Encoder":Encoder
}

model = VisonTransformer(**kwargs).to(device)
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
output_dir = "./output"

train_loss_list = []
val_loss_list = []
train_correct_list =[]
val_correct_list = []
presision_list = []
recall_list = []
f1_list = []

epochs = 1

min_loss = np.inf
for epoch in range(epochs):
    for step in ["train", "val"]:
        if step == "train":
            model.train()
            dataloader = train_dataloader
        else:
            model.eval()
            dataloader = val_dataloader

        running_loss = 0.0
        running_correct = 0.0

        for batch, (images, labels) in enumerate(dataloader):
            images = images.to(device)
            labels = labels.to(device)

            with torch.set_grad_enabled(model.training):
                outputs = model(images)
                pred = torch.argmax(outputs, dim=-1)
                loss = criterion(outputs, labels).sum()
                correct = multiclass_accuracy(
                    input=outputs,
                    target=labels,
                    num_classes=kwargs["num_classes"],
                    average="micro"
                )

                if model.training:
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()

            running_loss += loss.item()
            running_correct += correct.item()

            if model.training:
                train_loss_list.append(running_loss)
                train_correct_list.append(running_correct)
                print(f"Step: Train Epoch: {epoch + 1}/{epochs} Iterate: {batch + 1}/{len(dataloader)} train_loss: {running_loss / (batch + 1)}")

            else:
                print(f"Step: Train Epoch: {epoch + 1}/{epochs} val_loss: {running_loss / (batch + 1)}\
                      val_correct: {running_correct / (batch + 1)}")
                val_loss_list.append(running_loss)
                val_correct_list.append(running_correct)
                
                precision = multiclass_precision(
                    input=outputs,
                    target=labels,
                    num_classes=kwargs["num_classes"],
                    average="micro"
                ).item()

                recall = multiclass_recall(
                    input=outputs,
                    target=labels,
                    num_classes=kwargs["num_classes"],
                    average="micro"
                ).item()

                f1_score = multiclass_f1_score(
                    input=outputs,
                    target=labels,
                    num_classes=kwargs["num_classes"],
                    average="micro"
                ).item()

                recall_list.append(recall)
                presision_list.append(precision)
                f1_list.append(f1_score)


    if running_loss < min_loss:
        print("Model Save!")
        min_loss = running_loss
        if not os.path.exists(output_dir):
            os.mkdir(output_dir)
        torch.save(model.state_dict(), os.path.join(output_dir, f"{epoch + 1}.pth"))
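
For later inference, the saved checkpoint can be restored as follows (an illustrative addition; with epochs = 1 the file written above is 1.pth):

python
# Reload the checkpoint saved above and switch to evaluation mode
model = VisonTransformer(**kwargs).to(device)
model.load_state_dict(torch.load(os.path.join(output_dir, "1.pth"), map_location=device))
model.eval()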

Summary

The most instructive part of implementing ViT was building the Transformer encoder, including Multi-Head Attention, from scratch, which further deepened my understanding of large language models. Next time I would like to consider implementing encoder-decoder models such as T5 as well.

References

An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale

画像認識の大革命。AI界で話題爆発中の「Vision Transformer」を解説!
