LoginSignup
1
1

More than 3 years have passed since last update.

Vision Transformer - Attention = !? 全結合層だけで高精度を出した論文3分解説

Posted at

Vision Transformer から Attentionを取り除くとどうなるのか検証した論文"Do You Even Need Attention?"解説

Do-You-Even-Need-Attention.png
画像:"Do You Even Need Attention? A Stack of Feed-Forward Layers Does Surprisingly Well on ImageNet" Luke Melas-Kyriazi (2021)

論文データ

arxiv : "Do You Even Need Attention? A Stack of Feed-Forward Layers Does Surprisingly Well on ImageNet" Luke Melas-Kyriazi (2021)
github : https://github.com/lukemelas/do-you-even-need-attention

概要

画像分類などの視覚タスクにおいてVision Transformerが優れた性能を発揮するのは、Multi-Head Attentionレイヤーの設計によるものが多い。しかしAttentionがどの程度パフォーマンスに寄与しているのかは明らかではない。この論文ではAttentionをパッチ方向の次元に適応される全結合層に置き換えて性能を確かめた。

ViTとの違い

image.png
画像 : "An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale",(2021)
Vitとの違いはMulti-Head Attentionを単純な全結合層に変更しただけである。

コード

論文にPytorchのコードが載っているので見た方がイメージしやすいと思います。

from torch import nn

class LinearBlock(nn.Module):
    def __init__(self, dim, mlp_ratio=4., drop=0., drop_path=0., act=nn.GELU,
        norm=nn.LayerNorm, n_tokens=197): # 197 = 16**2 + 1
        super().__init__()
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        # FF over features
        self.mlp1 = Mlp(in_features=dim, hidden_features=int(dim*mlp_ratio), act=act, drop=drop)
        self.norm1 = norm(dim)
        # FF over patches
        self.mlp2 = Mlp(in_features=n_tokens, hidden_features=int(n_tokens*mlp_ratio), act=act, drop=drop)
        self.norm2 = norm(n_tokens)
    def forward(self, x):
        x = x + self.drop_path(self.mlp1(self.norm1(x)))
        x = x.transpose(-2, -1) #特徴量次元とパッチ次元を入れ替えている
        x = x + self.drop_path(self.mlp2(self.norm2(x)))
        x = x.transpose(-2, -1) #入れ替えたのをもとに戻す
        return x

class Mlp(nn.Module):
    def __init__(self, in_features, hidden_features, act_layer=nn.GELU, drop=0.):
        super().__init__()
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, in_features)
        self.drop = nn.Dropout(drop)
    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
return x

非常に単純ですね... (実際はPositional Encodingなどが必要なのでもう少し複雑ですが)

結果

image.png
TinyではDeiTほどの性能は出ていないもののBaseにおいてはViTに近い性能が出ています。Largeの場合にはViT,FF Onlyともに性能は下がっているもののViTを超える性能を出しています。
また、Attentionのみの場合、Tinyモデルで28.2%の性能しか出なかった。

まとめ

  • ViTはAttentionなしでも驚くべき画像分類能力を持っている
  • ViTiのパフォーマンスの凄さはAttentionよりもpatch embeddingや学習法によるものが多いのかもしれない
1
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
1