1
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

Vision Transformerに入門する (2/N)

Posted at

前回の続きを実装する. 前回では,input layerが画像を4つのパッチに分け, それぞれを384次元のベクトル空間に埋め込むことを見た.

InputLayer.png

Encoderはそのベクトルとclass tokenを受け取り, downstream taskのためにclass tokenを吐き出す.

ViT_全体図.png

今回はencoderの中の一機構であるMulti-Head Self-Attentionを見ていく.
Multi-HeadとSelf-Attentionの二つのパートに分かれる.

Self-Attentionについて

Self-Attentionにはquery, key, valueの三つがある.
動画検索に似ている. query = 検索キーワード, key = 動画のタイトルや説明, value = 実際の動画というイメージ. queryとkeyの類似性を測ってそれをもとに動画を「混ぜて」、それを提示するイメージ.

query_key.png
上記の計算からトークン1とトークン4が似ていると判断されていることが分かる. すると直観的には, attentionと内積を取ったトークン1のvalueはトークン4と似ているという情報を保持することになる. これは画像の異なる部分の関係性を捉えていてることになる.

values_compute.png

Multi-Head Self-Attentionについて

Self-Attentionの考えが肝であり, Multi-Headは本質的には何も新しいことをやっていない. Self-Attentionでトークン間の関係を得られたので、違う関係性についても学習しようというのがMulti-Headの考えである.

【実装】

viewreshapeはほとんど一緒だが, viewはshapeを変更する際, Tensorの各要素がメモリ上で要素順に並んでいなければいけないのに対して, reshapeはメモリ上で要素順に並んでいない場合は物理的なコピーを作るらしい.

import torch
import torch.nn as nn
import torch.nn.functional as F
from input_layer import VitInputLayer


class MultiHeadSelfAttention(nn.Module):
    def __init__(self, emb_dim: int = 384, head: int = 3, dropout: float = 0):
        """
        dropout: ドロップアウト率.
        """
        super(MultiHeadSelfAttention, self).__init__()
        self.head = head
        self.emb_dim = emb_dim
        self.head_dim = emb_dim // head
        self.sqrt_dh = self.head**0.5

        # 入力をq, k, vに埋め込むための線形層
        self.w_q = nn.Linear(emb_dim, emb_dim, bias=False)
        self.w_k = nn.Linear(emb_dim, emb_dim, bias=False)
        self.w_v = nn.Linear(emb_dim, emb_dim, bias=False)

        self.attn_drop = nn.Dropout(dropout)
        self.w_o = nn.Sequential(
            nn.Linear(emb_dim, emb_dim),
            nn.Dropout(dropout),
        )

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        """
        引数: 形状は(B, N, D)
        Bはバッチ数, Nはトークン数, Dは次元.
        """
        batch_size, num_patch, _ = z.size()

        # 埋め込み
        q = self.w_q(z)
        k = self.w_k(z)
        v = self.w_v(z)

        # q, k, vをヘッドに分ける
        # まずベクトルをヘッドの個数に分ける
        # (B, N, D) -> (B, N, h, D//h)
        q = q.view(batch_size, num_patch, self.head, self.head_dim)
        k = k.view(batch_size, num_patch, self.head, self.head_dim)
        v = v.view(batch_size, num_patch, self.head, self.head_dim)

        # Self-Attentionができるようにする
        # (B, N, h, D//h) -> (B, h, N, D//h)

        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        # 内積を取るためにkを転置する
        # (B, h, N, D//h) -> (B, h, D//h, N)
        k_T = k.transpose(2, 3)

        # 内積を取る
        # (B, h, N, D//h) x (B, h, D//h, N) -> (B, h, N, N)
        dots = (q @ k_T) / self.sqrt_dh

        # 列方向にsoftmaxを取る
        attn = F.softmax(dots, dim=-1)

        # 加重和
        # (B, h, N, N) x (B, h, N, D//h) -> (B, h, N, D//h)
        out = attn @ v
        # (B, h, N, D//h) -> (B, N, h, D//h)
        out = out.transpose(1, 2)
        # (B, N, h, D//h) -> (B, N, D)
        out = out.reshape(batch_size, num_patch, self.emb_dim)

        # 出力層
        out = self.w_o(out)
        return out


batch_size, channel, height, width = 2, 3, 32, 32
x = torch.randn(batch_size, channel, height, width)
# VitInputLayerについては
input_layer = VitInputLayer(num_patch_row=2)
z_0 = input_layer(x)

mhsa = MultiHeadSelfAttention()
out = mhsa(z_0)

# (2, 5, 384)( = (B, N, D))となるはず.
print(out.shape)
1
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?