原論文
- Going deeper with Image Transformers
https://arxiv.org/abs/2103.17239
関連研究
Vison Transformerの解説
https://qiita.com/wakanomi/items/55bba80338615c7cce73
結論
ViTの問題点としてself-attentionの類似度計算で特徴の全く異なるクラストークンも含めて処理を行うと,そのクラストークンが悪さをしてしまい精度低下につながる可能性がある.この問題をクラストークンを途中から追加する方法で解決する.
概要
Class-Attention in Image Transformers (CaiT) とはViTの派生モデルで,学習可能な対角行列パラメータを加える.これをLayerScaleと呼ぶ.また,Class Tokenを最初に追加するのではなく,途中から追加する.この途中からClass Tokenを追加した以降はSelf-Attentionを使用する代わりにClass-Attentionを使用する.Class-AttentionはClass TokenをQueryに使用し,他のTokenをKeyとValueに使用したアテンション機構である.また,Class Token適用後はclass tokenのみにMLPを行う.これにより,ImageNetの画像分類タスクにおいて精度を上回る精度を達成する.
モデル構造
CaiTのモデル構造としてViTと比較してクラストークンの追加を遅らせていることがわかる.また,クラストークンの追加後にSA(Self-Attention)ではなく,CA(Class Attention)を採用している.また,クラストークンの追加後はクラストークンのみをFFNを通して処理を行う.最後に,クラストークンの出力を使用してクラス識別する.
LayerScale
LayerScaleとはTransformer Encoderの処理内でTransformerの層が深くなるにつれ学習が不安定になる問題を改善すべく順接続の部分に学習可能な対角行列を導入したものである.
下図はTransformer Encoderの構造を示し,1週目にNormとSA,2週目にNormとFFNの処理を行う.スキップ接続の目的として,認識によって急激に変化する表現を認識を行う前の状態と足し合わせることで急激な変化を抑制し,過学習しない大きいモデルを実現している.CaiTではその抑制を強めるために学習可能な青色の行列の値で積をとる(Caitでは学習可能な行列を採用しているが,人手で決めたスカラー値(例:0.1)などでもよい).
通常のTransformer Encoder
x'_l = x_l + SA(\eta(x_l))\\
x_{l+1} = x'_l + FFN(\eta(x'_l))
LayerScaleを追加したTransformer Encoder
x'_l = x_l + diag(\lambda_{l,1},...,\lambda_{l,d}) \times SA(\eta(x_l))\\
x_{l+1} = x'_l + diag(\lambda'_{l,1},...,\lambda_{l,d}) \times FFN(\eta(x'_l))
LayerScaleで使用される対角行列の$\lambda$の重みは0に近い小さい値で初期値を設定する.0に近い値に設定する理由として,SAとFFNの処理の更新の割合が小さくなり,ほとんど割合で処理がスキップ接続で計算される.これにより,より深い層の学習が可能になる.また,スカラー値でなく対角行列を使用する理由として,チャンネルごとの重み付けが可能になり,学習可能なスカラー値で層全体を調整するよりも多様な最適化が可能になる.
Class-Attention
途中から追加したClass Tokenはself-attentionではなくClass-Attentionを採用する.Class-Attentionはクエリの入力をクラストークンのみとして類似度計算を行うこと,Q,K,Vに変換する際にそれに対応するバイアスを追加すること以外はself-attentionと同じである.self-attentionはクエリの入力サイズに依存するため,出力はクラストークンのサイズのみが出力される.
Class-Attentionの式
A = \mathrm{Softmax}(Q \cdot K^T/\sqrt{d/h})\\
\mathrm{out}_{CA} = W_oAV+b_o
Q,K,Vはそれぞれ以下の式で求める
Q = W_qx_{class}+b_q\\
K = W_kz+b_k\\
V = W_vz+b_v
ここで,ここで,$d$は次元数,$h$はヘッド数,$W$は入力を変換させるための行列,$b$はバイアスを示す.
実験
画像分類
CaiTはImageNetの他にImageNetより小さいデータセットにおいても従来手法を上回る精度を達成した.
Class Tokenの追加場所
Class Tokenの追加場所として,最後の2層をClass Tokenを追加して,Class-Attentionの処理を行うことが計算コストと精度の面から有効である.
下図の左は層数を示す.例に最下段は15(全部で15層):12(SAの層数)+3(CAの層数)で構成される.
考察
ViTの派生モデルには,そもそもクラストークンを使用せずに各ヘッドのGlobal Average Pooling(GAP)でクラス識別しても性能は大きく変化しないらしい.元々は自然言語モデルのBERTの名残りで残っているだけなのか.
まとめ
今回は,Class-Attention in Image Transformers (CaiT) について説明した.クラストークンを途中から追加し,LayerScaleで急激な表現変化を抑制することで,ViTの精度を上回る画像分類精度を達成した.