前回の続きを実装する. 前回では,input layerが画像を4つのパッチに分け, それぞれを384次元のベクトル空間に埋め込むことを見た.
Encoderはそのベクトルとclass tokenを受け取り, downstream taskのためにclass tokenを吐き出す.
今回はencoderの中の一機構であるMulti-Head Self-Attentionを見ていく.
Multi-HeadとSelf-Attentionの二つのパートに分かれる.
Self-Attentionについて
Self-Attentionにはquery, key, valueの三つがある.
動画検索に似ている. query = 検索キーワード, key = 動画のタイトルや説明, value = 実際の動画というイメージ. queryとkeyの類似性を測ってそれをもとに動画を「混ぜて」、それを提示するイメージ.
上記の計算からトークン1とトークン4が似ていると判断されていることが分かる. すると直観的には, attentionと内積を取ったトークン1のvalueはトークン4と似ているという情報を保持することになる. これは画像の異なる部分の関係性を捉えていてることになる.
Multi-Head Self-Attentionについて
Self-Attentionの考えが肝であり, Multi-Headは本質的には何も新しいことをやっていない. Self-Attentionでトークン間の関係を得られたので、違う関係性についても学習しようというのがMulti-Headの考えである.
【実装】
view
とreshape
はほとんど一緒だが, viewはshapeを変更する際, Tensorの各要素がメモリ上で要素順に並んでいなければいけないのに対して, reshapeはメモリ上で要素順に並んでいない場合は物理的なコピーを作るらしい.
import torch
import torch.nn as nn
import torch.nn.functional as F
from input_layer import VitInputLayer
class MultiHeadSelfAttention(nn.Module):
def __init__(self, emb_dim: int = 384, head: int = 3, dropout: float = 0):
"""
dropout: ドロップアウト率.
"""
super(MultiHeadSelfAttention, self).__init__()
self.head = head
self.emb_dim = emb_dim
self.head_dim = emb_dim // head
self.sqrt_dh = self.head**0.5
# 入力をq, k, vに埋め込むための線形層
self.w_q = nn.Linear(emb_dim, emb_dim, bias=False)
self.w_k = nn.Linear(emb_dim, emb_dim, bias=False)
self.w_v = nn.Linear(emb_dim, emb_dim, bias=False)
self.attn_drop = nn.Dropout(dropout)
self.w_o = nn.Sequential(
nn.Linear(emb_dim, emb_dim),
nn.Dropout(dropout),
)
def forward(self, z: torch.Tensor) -> torch.Tensor:
"""
引数: 形状は(B, N, D)
Bはバッチ数, Nはトークン数, Dは次元.
"""
batch_size, num_patch, _ = z.size()
# 埋め込み
q = self.w_q(z)
k = self.w_k(z)
v = self.w_v(z)
# q, k, vをヘッドに分ける
# まずベクトルをヘッドの個数に分ける
# (B, N, D) -> (B, N, h, D//h)
q = q.view(batch_size, num_patch, self.head, self.head_dim)
k = k.view(batch_size, num_patch, self.head, self.head_dim)
v = v.view(batch_size, num_patch, self.head, self.head_dim)
# Self-Attentionができるようにする
# (B, N, h, D//h) -> (B, h, N, D//h)
q = q.transpose(1, 2)
k = k.transpose(1, 2)
v = v.transpose(1, 2)
# 内積を取るためにkを転置する
# (B, h, N, D//h) -> (B, h, D//h, N)
k_T = k.transpose(2, 3)
# 内積を取る
# (B, h, N, D//h) x (B, h, D//h, N) -> (B, h, N, N)
dots = (q @ k_T) / self.sqrt_dh
# 列方向にsoftmaxを取る
attn = F.softmax(dots, dim=-1)
# 加重和
# (B, h, N, N) x (B, h, N, D//h) -> (B, h, N, D//h)
out = attn @ v
# (B, h, N, D//h) -> (B, N, h, D//h)
out = out.transpose(1, 2)
# (B, N, h, D//h) -> (B, N, D)
out = out.reshape(batch_size, num_patch, self.emb_dim)
# 出力層
out = self.w_o(out)
return out
batch_size, channel, height, width = 2, 3, 32, 32
x = torch.randn(batch_size, channel, height, width)
# VitInputLayerについては
input_layer = VitInputLayer(num_patch_row=2)
z_0 = input_layer(x)
mhsa = MultiHeadSelfAttention()
out = mhsa(z_0)
# (2, 5, 384)( = (B, N, D))となるはず.
print(out.shape)