0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

Vision Transformerに入門する (3/N)

Posted at

前回の続きを実装していく.

前回はEncoderのMulti-Head Self-Attention(MHSA)についれ学んだ. 今回は, Encoderの残りの部分について触れる.

EncoderBlock.png

Layer Normalization

正規化と言えばBatch Normalizationがあるが一バッチに含まれているトークン数が違う自然言語処理タスクではそれはできない. 代わりに, 一文中に含まれる単語ベクトルを正規化するLayer Normalizationがよく使われる. 今回もせっかく画像の中の各バッチをうまくベクトル化できたのでそれらの正規化をする方が自然である.

MLPブロック

MLPブロックに関しては,
Linear->GELU->Dropout->Linear->Dropout
という構成らしい. ちなみにGELUはTransformer系のモデルでよく使われる活性化関数とのこと. なぜReLUではなくこちらのstackoverflowの記事によると, ReLUは多くの活性化関数の出力値が0になって何もしなくなってしまうことがあるがGELUは定義域全体(実数全体)で滑らかな関数なのでこうした問題を緩和できるとのこと.

コードは以下の通り.

import torch
import torch.nn as nn
from input_layer import VitInputLayer
from ViT import MultiHeadSelfAttention


class VitEncoderBlock(nn.Module):
    def __init__(
        self,
        emb_dim: int = 384,
        head: int = 8,
        hidden_dim: int = 384 * 4,
        dropout: float = 0,
    ) -> None:
        """
        emb_dim: 埋め込み後のベクトルの長さ
        head: headの数
        hidden_dim: int: 384*4
        dropout: float=0
        """
        super(VitEncoderBlock, self).__init__()
        # 一つ目のLayerNormalization
        self.ln1 = nn.LayerNorm(emb_dim)
        # MHSA(前回実装した)
        self.mhsa = MultiHeadSelfAttention(emb_dim=emb_dim, head=head, dropout=dropout)
        # 二つ目のLayer Normalization
        self.ln2 = nn.LayerNorm(emb_dim)
        # MLP
        self.mlp = nn.Sequential(
            nn.Linear(emb_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, emb_dim),
            nn.Dropout(dropout),
        )

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        """
        引数
        z:Encoder Block への入力. 形状は(B, N, D)
        B: バッチサイズ, N: パッチ数, D:ベクトルの長さ
        返り値
        out: Encoder Blockの出力. 形状は(B, N, D)
        B: バッチサイズ, N: パッチ数, D: ベクトルの長さ
        """
        out = self.mhsa(self.ln1(z)) + z
        out = self.mlp(self.ln2(out)) + out
        return out


0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?