Vision Transformerについて解説する「Vision Transformer入門」という本があったのでそれを読んでコードを触りながら, 自分の思考を整理したいと思う.
今回はViTのInput Layerまで実装してみようと思う.
NNに慣れる
NNをtorch.nn
で構築していく. 図にすると, 以下の図である.
import torch
import torch.nn as nn
# シンプルなMLP層を作る.
class SimpleMlp(nn.Module):
def __init__(
self,
vec_length: int = 16,
hidden_unit_1: int = 8,
hidden_unit_2: int = 2,
):
"""
引数:
vec_length:入力ベクトルの長さ
hidden_unit_1:一つ目の線形層のニューロン数
hidden_unit_2:二つ目の線形層のニューロン数
"""
super(SimpleMlp, self).__init__()
self.layer1 = nn.Linear(vec_length, hidden_unit_1)
self.relu = nn.ReLU()
self.layer2 = nn.Linear(hidden_unit_1, hidden_unit_2)
def forward(self, x: torch.Tensor) -> torch.Tensor:
out = self.layer1(x)
out = self.relu(out)
out = self.layer2(x)
return out
vec_length = 16
hidden_unit_1 = 8
hidden_unit_2 = 2
batch_size = 4
x = torch.randn(batch_size, vec_length)
net = SimpleMlp(
vec_length=vec_length,
hidden_unit_1=hidden_unit_1,
hidden_unit_2=hidden_unit_2,
)
out = net(x)
#出力はtorch.Size([4, 2])になるはず
print(out.shape)
ViTの全体像
ViTは大きく三つの部分で構成されている.
- Input Layer (パッチに分割+画像埋め込み)
- Encoder (SelfAttentionがある肝となるアーキテクチャ)
- MLP Head (Downstreamタスク用のヘッダ)
順に見ていきましょう.
Input Layer
Input Layerでは次の四つの操作が行われる.
① 画像をパッチに分割
② 各パッチを埋め込む
③ クラストークンのの追加
④ 位置埋め込み
まず全体のコードを貼る.
import torch
import torch.nn as nn
class VitInputLayer(nn.Module):
def __init__(
self,
in_channels: int = 3,
emb_dim: int = 384, # 埋め込み後のベクトル長さ
num_patch_row: int = 2, # 高さ方向のパッチの数
image_size: int = 32,
):
super(VitInputLayer, self).__init__()
self.in_channels = in_channels
self.emb_dim = emb_dim
self.num_patch_row = num_patch_row
self.image_size = image_size
self.num_patch = self.num_patch_row**2
# パッチの大きさ. patchの一辺は16
self.patch_size = int(self.image_size // self.num_patch_row)
# 入力画像のパッチの分割 & パッチの埋め込みを一気に行う層
self.patch_emb_layer = nn.Conv2d(
in_channels=self.in_channels,
out_channels=self.emb_dim,
kernel_size=self.patch_size,
stride=self.patch_size,
)
self.cls_token = nn.Parameter(torch.randn(1, 1, emb_dim))
self.pos_emb = nn.Parameter(torch.randn(1, self.num_patch + 1, emb_dim))
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
引数 x: 入力画像. 形状は(B, C, H, W). B = バッチ数, C = チャネル数, H = 画像高さ, W = 画像幅.
返り値: Encoderへの入力. (B, N, D). N = トークン数, D = ベクトルの長さ
"""
out = self.patch_emb_layer(x)
out = out.flatten(2)
out = out.transpose(1, 2)
out = torch.cat([self.cls_token.repeat(repeats=(x.size(0), 1, 1)), out], dim=1)
out = out + self.pos_emb
return out
まず32x32の画像を想定する. なので, image_sizeは32である. 各パッチを384次元のベクトルに変換する.
ポイントとなる箇所を解説する.
一つ目のポイントとして, nn.Conv2d
モジュールがある. これはパッチの分割と埋めこみのどちらもやってくれる.(Convolution層についてわかりやすい記事がありました. )
32x32の画像上をカーネルサイズ16のフィルタをストライド16で走らせるので, 4分割さる.
図解すると以下の通りである. つまり, self.patch_emb_layer(x).shape = torch.Size([2, 384, 2, 2])
である.
二つ目のポイントはクラストークン(cls_token
)である. クラストークンはそれを使って分類なりのdownstream taskをすることから, 画像全体の情報を圧縮したベクトルと言える. また, nn.Parameterはnn.Parameterは例えばnn.Linear()などのパラメータの型である. (この記事が参考になる)
三つ目のポイントはflatten()
の処理である. 引数の2はstart_dim
のことで, flatten
を開始する次元である. つまり, self.patch_emb_layer(x).shape = torch.Size([2, 384, 4])
である. transpose
で次元1と次元2を転置する.
四つ目のポイントはrepeat()
である. バッチに含まれている画像の数だけクラストークンを用意しなければいけないので複製する必要がある. 似たような処理にexpand
があるが, メモリを共有してはだめなので, repeat()を使う. (参照記事:repeatとexpandの違いについて)
>>> x = torch.tensor([1, 2, 3])
>>> x.repeat(4, 2)
tensor([[ 1, 2, 3, 1, 2, 3],
[ 1, 2, 3, 1, 2, 3],
[ 1, 2, 3, 1, 2, 3],
[ 1, 2, 3, 1, 2, 3]])
>>> x.repeat(4, 2, 1).size()
torch.Size([4, 2, 3])
また, torch.cat
の引数のdimは結合する次元である.
個人的に特に勉強になったポイント
画像の埋め込みはnn.Conv2d
で行っているということ.
時間があるときにチェックすること
CNNの逆伝播についての詳細.(数式など)