Vision Transformer から Attentionを取り除くとどうなるのか検証した論文"Do You Even Need Attention?"解説
画像:"Do You Even Need Attention? A Stack of Feed-Forward Layers Does Surprisingly Well on ImageNet" Luke Melas-Kyriazi (2021)
論文データ
arxiv : "Do You Even Need Attention? A Stack of Feed-Forward Layers Does Surprisingly Well on ImageNet" Luke Melas-Kyriazi (2021)
github : https://github.com/lukemelas/do-you-even-need-attention
概要
画像分類などの視覚タスクにおいてVision Transformerが優れた性能を発揮するのは、Multi-Head Attentionレイヤーの設計によるものが多い。しかしAttentionがどの程度パフォーマンスに寄与しているのかは明らかではない。この論文ではAttentionをパッチ方向の次元に適応される全結合層に置き換えて性能を確かめた。
ViTとの違い
画像 : "An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale",(2021)
Vitとの違いはMulti-Head Attentionを単純な全結合層に変更しただけである。
コード
論文にPytorchのコードが載っているので見た方がイメージしやすいと思います。
from torch import nn
class LinearBlock(nn.Module):
def __init__(self, dim, mlp_ratio=4., drop=0., drop_path=0., act=nn.GELU,
norm=nn.LayerNorm, n_tokens=197): # 197 = 16**2 + 1
super().__init__()
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
# FF over features
self.mlp1 = Mlp(in_features=dim, hidden_features=int(dim*mlp_ratio), act=act, drop=drop)
self.norm1 = norm(dim)
# FF over patches
self.mlp2 = Mlp(in_features=n_tokens, hidden_features=int(n_tokens*mlp_ratio), act=act, drop=drop)
self.norm2 = norm(n_tokens)
def forward(self, x):
x = x + self.drop_path(self.mlp1(self.norm1(x)))
x = x.transpose(-2, -1) #特徴量次元とパッチ次元を入れ替えている
x = x + self.drop_path(self.mlp2(self.norm2(x)))
x = x.transpose(-2, -1) #入れ替えたのをもとに戻す
return x
class Mlp(nn.Module):
def __init__(self, in_features, hidden_features, act_layer=nn.GELU, drop=0.):
super().__init__()
self.fc1 = nn.Linear(in_features, hidden_features)
self.act = act_layer()
self.fc2 = nn.Linear(hidden_features, in_features)
self.drop = nn.Dropout(drop)
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x
非常に単純ですね... (実際はPositional Encodingなどが必要なのでもう少し複雑ですが)
結果
TinyではDeiTほどの性能は出ていないもののBaseにおいてはViTに近い性能が出ています。Largeの場合にはViT,FF Onlyともに性能は下がっているもののViTを超える性能を出しています。
また、Attentionのみの場合、Tinyモデルで28.2%の性能しか出なかった。
まとめ
- ViTはAttentionなしでも驚くべき画像分類能力を持っている
- ViTiのパフォーマンスの凄さはAttentionよりもpatch embeddingや学習法によるものが多いのかもしれない