2
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

Vision Transformerに入門する (1/N)

Last updated at Posted at 2024-01-21

Vision Transformerについて解説する「Vision Transformer入門」という本があったのでそれを読んでコードを触りながら, 自分の思考を整理したいと思う.
今回はViTのInput Layerまで実装してみようと思う.

NNに慣れる

NNをtorch.nnで構築していく. 図にすると, 以下の図である.

NN.png

import torch
import torch.nn as nn


#  シンプルなMLP層を作る.
class SimpleMlp(nn.Module):
    def __init__(
        self,
        vec_length: int = 16,
        hidden_unit_1: int = 8,
        hidden_unit_2: int = 2,
    ):
        """
        引数:
            vec_length:入力ベクトルの長さ
            hidden_unit_1:一つ目の線形層のニューロン数
            hidden_unit_2:二つ目の線形層のニューロン数
        """
        super(SimpleMlp, self).__init__()
        self.layer1 = nn.Linear(vec_length, hidden_unit_1)
        self.relu = nn.ReLU()
        self.layer2 = nn.Linear(hidden_unit_1, hidden_unit_2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = self.layer1(x)
        out = self.relu(out)
        out = self.layer2(x)
        return out


vec_length = 16
hidden_unit_1 = 8
hidden_unit_2 = 2
batch_size = 4

x = torch.randn(batch_size, vec_length)

net = SimpleMlp(
    vec_length=vec_length,
    hidden_unit_1=hidden_unit_1,
    hidden_unit_2=hidden_unit_2,
)
out = net(x)

#出力はtorch.Size([4, 2])になるはず
print(out.shape)

ViTの全体像

ViTは大きく三つの部分で構成されている.

  • Input Layer (パッチに分割+画像埋め込み)
  • Encoder (SelfAttentionがある肝となるアーキテクチャ)
  • MLP Head (Downstreamタスク用のヘッダ)

ViT_全体図.png

順に見ていきましょう.

Input Layer

Input Layerでは次の四つの操作が行われる.
① 画像をパッチに分割
② 各パッチを埋め込む
③ クラストークンのの追加
④ 位置埋め込み

まず全体のコードを貼る.

import torch
import torch.nn as nn


class VitInputLayer(nn.Module):
    def __init__(
        self,
        in_channels: int = 3,
        emb_dim: int = 384,  # 埋め込み後のベクトル長さ
        num_patch_row: int = 2,  # 高さ方向のパッチの数
        image_size: int = 32,
    ):
        super(VitInputLayer, self).__init__()
        self.in_channels = in_channels
        self.emb_dim = emb_dim
        self.num_patch_row = num_patch_row
        self.image_size = image_size

        self.num_patch = self.num_patch_row**2

        # パッチの大きさ. patchの一辺は16
        self.patch_size = int(self.image_size // self.num_patch_row)

        # 入力画像のパッチの分割 & パッチの埋め込みを一気に行う層
        self.patch_emb_layer = nn.Conv2d(
            in_channels=self.in_channels,
            out_channels=self.emb_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
        )

        self.cls_token = nn.Parameter(torch.randn(1, 1, emb_dim))

        self.pos_emb = nn.Parameter(torch.randn(1, self.num_patch + 1, emb_dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        引数 x: 入力画像. 形状は(B, C, H, W). B = バッチ数, C = チャネル数, H = 画像高さ, W = 画像幅.
        返り値: Encoderへの入力. (B, N, D). N = トークン数, D = ベクトルの長さ
        """
        out = self.patch_emb_layer(x)
        out = out.flatten(2)
        out = out.transpose(1, 2)
        out = torch.cat([self.cls_token.repeat(repeats=(x.size(0), 1, 1)), out], dim=1)
        out = out + self.pos_emb
        return out

まず32x32の画像を想定する. なので, image_sizeは32である. 各パッチを384次元のベクトルに変換する. 

ポイントとなる箇所を解説する.

一つ目のポイントとして, nn.Conv2dモジュールがある. これはパッチの分割と埋めこみのどちらもやってくれる.(Convolution層についてわかりやすい記事がありました. )

32x32の画像上をカーネルサイズ16のフィルタをストライド16で走らせるので, 4分割さる.
図解すると以下の通りである. つまり, self.patch_emb_layer(x).shape = torch.Size([2, 384, 2, 2])である.
nnConv2d.png

二つ目のポイントはクラストークン(cls_token)である. クラストークンはそれを使って分類なりのdownstream taskをすることから, 画像全体の情報を圧縮したベクトルと言える. また, nn.Parameterはnn.Parameterは例えばnn.Linear()などのパラメータの型である. (この記事が参考になる)

三つ目のポイントはflatten()の処理である. 引数の2はstart_dimのことで, flattenを開始する次元である. つまり, self.patch_emb_layer(x).shape = torch.Size([2, 384, 4])である. transposeで次元1と次元2を転置する.

四つ目のポイントはrepeat()である. バッチに含まれている画像の数だけクラストークンを用意しなければいけないので複製する必要がある. 似たような処理にexpandがあるが, メモリを共有してはだめなので, repeat()を使う. (参照記事:repeatとexpandの違いについて)

>>> x = torch.tensor([1, 2, 3])
>>> x.repeat(4, 2)
tensor([[ 1,  2,  3,  1,  2,  3],
        [ 1,  2,  3,  1,  2,  3],
        [ 1,  2,  3,  1,  2,  3],
        [ 1,  2,  3,  1,  2,  3]])
>>> x.repeat(4, 2, 1).size()
torch.Size([4, 2, 3])

また, torch.catの引数のdimは結合する次元である.

個人的に特に勉強になったポイント

画像の埋め込みはnn.Conv2dで行っているということ.

時間があるときにチェックすること

CNNの逆伝播についての詳細.(数式など)

2
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?