論文とコード
- Zhuofan Xia, Xuran Pan, Shiji Song, Li Erran Li, Gao Huang, Vision Transformer with Deformable Attention, arXiv:2201.00520, 2022
- コード(https://github.com/LeapLabTHU/DAT)
掲載している画像は、本論文から引用となります。
概要
トランスフォーマーは画像認識タスクにおいて優れた性能を発揮しますが、グローバルアテンションは計算コストが高く、不要な領域に過度な注目を払うという問題があります。一方、Swin Transformerで採用されているローカルアテンションは、アテンション領域がデータに依存せず固定されているため、長距離の関係をモデル化する能力が不足しています。これらの問題に対処するために、論文ではデータに依存してキーとバリューの位置を調整する「変形可能アテンション」モジュールを提案しています。このモジュールは、下図(d)に示すように、クエリ(星印)に対して、キーとバリューが重要な情報を含む領域(犬全体を指す紫色の丸)から生成されます。このアプローチにより、より有益な特徴を捉えることが可能となり、このモジュールを搭載したDeformable Attention Transformer(DAT)は、画像分類と高密度予測タスクの両方で、競合他モデルを上回る性能を示しています。なお、DATには拡張版であるDAT++も存在しますが、本記事ではDATについて解説します。
モデル | 過去の記事 |
---|---|
Swin Transformer | https://qiita.com/kinkalow/items/cb1024c2c9856ee1afca |
Deformable Attention Transformer
DATは、データに基づいてサンプリング位置を柔軟に変化できる特徴があります。同様の特性を持つモデルとして、DCN(上図(c))やDeformable DETRなどが知られています。しかし、DCNは畳み込みに基づくため、これを単純にアテンションモジュールに拡張すると、特徴マップのサイズ$HW$に対して2乗の計算複雑度となり問題が発生します。Deformable DETRではサンプリング点を減らすことで計算コストを削減していますが、少数のサンプリングにより表現力が制限されます。DATは、これらの課題を改善し、さらにGCNetなどの研究において、異なるクエリに対してほぼ同じアテンションマップが得られるという観察から、各クエリに対して共有されたキーとバリューを使用するアイディアを採用しています。これは、上図(d)のように、赤と青の異なるクエリに対して、同じサンプリング点(紫色の丸)を使用することを意味します。これらのサンプリング位置は、クエリを入力とするオフセットネットワークによって学習されます。学習されたサンプリング位置に基づいてキーとバリューが生成され、そしてクエリを用いて、マルチヘッドセルフアテンションが行われます。以下に詳細を記載します。
Deformable attention module
以下の図は、データに応じてサンプリング位置を動的に変更し、その結果に基いてアテンションが行われるメカニズムが示されています。これを変形可能アテンション(Deformable attention)と呼びます。
入力特徴マップ$x$が与えられると、格子状に配置された参照点$p$が生成されます。この参照点は、$[-1,1]\times [-1,1]$の範囲をとり、入力特徴マップ上の各位置を参照するために利用されます。次に、参照点$p$にオフセット$\Delta p$を加えて変形点$p+\Delta p$を得ます。このオフセットは、特徴マップをクエリ$q=xW_q$に変換し、そのクエリをオフセットネットワーク$\theta_{\text{offset}}$に入力することで得られます($\Delta p=\theta_{\text{offset}}(q)$)。ただし、オフセットが大きくなりすぎないように、$\Delta p \leftarrow s \tanh(\Delta p)$の変換を行います。ここで、$s$はスケール調整です。その後、変形点($p+\Delta p$)でバイリニア補間を用いて特徴マップをサンプリングします。得られた特徴マップ$\tilde{x}$は線形変換に適用され、キーとバリューを生成します($\tilde{k}=\tilde{x}W_k$、$\tilde{v}=\tilde{x}W_k$)。最後に、$q$、$\tilde{k}$、$\tilde{v}$を用いて、相対位置エンコーディングを含むマルチヘッドセルフアテンションを適用します。相対位置エンコーディングを計算する際には、まず、$q$の位置と変形点($p+\Delta p$)を同じ座標系に合わせるために、$q$の位置を$[-1,1]\times [-1,1]$の範囲に変換します。次に、変換した位置と変形点の相対位置を計算し、最後にその相対位置に対するバイアスを求めます。このバイアスは、有限個の相対位置インデックスを持つテーブルに格納されているため、連続的な相対位置のずれに対応するには、テーブルにない中間的な値を補完する必要があります。この補完には、バイリニア補間が使われています。
参照点の数は、サンプリングする個数、つまりキーとバリューの空間サイズに等しいため、計算コストに大きく関与します。具体的には、参照点の数は、入力特徴マップの縦$H$と横$W$を$r$倍縮小した$H_GW_G$個となります。ここで、$H_G=\frac{H}{r}$、$W_G=\frac{W}{r}$、$r\geq1$です。$r$を大きくすることで、参照点の数が減り、アテンションの計算コストを軽減することができます。
オフセットネットワーク$\theta_{\text{offset}}$を使用してオフセットを取得する際も、解像度を$r$倍縮小する必要があります。これには、ストライド$r$の畳み込みを使用します。具体的なネットワーク構成は上図(b)に示されています。最初の5x5depthwise convolution(ストライド=r)で解像度を縮小し、局所特徴を抽出します。その後、GELU非線形関数によってネットワークがより複雑なパターンを学習できるようにし、最後にConv1x1で次元削減($C\rightarrow 2$)を行います。なお、Conv1x1ではバイアス項を削減することで、ネットワークが位置に対して過度なシフトを学習しないようにしています。
オフセットネットワークでは、マルチヘッドセルフアテンションの発想に基づき、入力$q$を複数のグループに分割しています。これにより、各グループごとに異なる変形点が生成され、サンプリングもグループごとに異なり、その結果として多様な生成結果を得られるようになります。
計算複雑度
変形可能アテンション(DA)の計算複雑度は、以下の式で表されます。
\displaylines{
\begin{align}
&\Omega(\text{attention}) = 2HWN_sC + 2HWC^2 + 2N_sC^2 \\
&\Omega(\text{offset network}) = (k^2+2)N_sC \\
&\Omega(\text{DA}) = \Omega(\text{attention}) + \Omega(\text{offset network})
\end{align}
}
ここで、$N_s=H_GW_G=\frac{HW}{r^2}$はサンプリング点、$k$はdepthwise convolutionのカーネルです。オフセットネットワークの計算コストは、アテンション計算のコストと比較すると非常に小さいです。
全体的なアーキテクチャ
DATの全体的なアーキテクチャを以下に示します。
入力画像はまず4x4の畳み込み層(stride=4)とレイヤーノルムによって処理され、4つのステージから構成されるバックボーンに入力されます。2つの連続するステージの間には2x2の畳み込み層(stride=2)とレイヤーノルムが挿入され、空間サイズは半分、チャンネル数は2倍になります。各ステージには、通常のTransformerと同様に、アテンションとMLPが配置され、その前後にはレイヤーノルムと残差結合が適用されます。ただし、各ステージでは2つのTransformer Blockが複数回実行され、それぞれ異なるアテンションが用いられます。最初の2つのステージでは、Swin Transformerと同様なLocal AttentionとShift-Window Attentionが配置され、主に局所的な特徴が学習されます。3番目と4番目のステージでは、Local AttentionとDeformable Attentionが配置され、局所情報と大域情報を交互に集めます。
実験
DATの有効性を検証するために、ImageNet-1Kによる画像分類、COCOによる物体検出とインスタンスセグメンテーション、ADE20Kによるセマンティックセグメンテーションの実験を行っています。さらに、アブレーションスタディや可視化の実験を通して、有効性を詳細に検証しています。
画像認識タスク
DATは、Swin Transformerのステージ3と4を主に改良したモデルであるため、両者を比較することは重要です。その結果、DATは同程度のパラメータ数とFLOPsにもかかわらず、Swin Transformerよりも高い精度を達成しています。特に、画像分類よりも高密度予測タスクにおいて精度の差が大きく、例えば、画像分類では+0.7向上しているのに対し、インスタンスセグメンテーションでは+2.1というより顕著な性能向上が確認されています。これらの結果から、DATは複雑なタスクになるほど、より優れた性能を発揮する可能性を示唆しています。以下の表は画像分類の結果です。
アブレーションスタディ
以下の表は、アテンションモジュール、オフセット、位置エンコーディングの変更が性能に与える影響を比較した結果です。表の最下行がDATの性能を表し、これらの要素をDATから変更すると、性能が低下します。Attn列のPは最初の2つのステージでPyramid Vision TransformerのSRAを使用し、SはSwin Transformerのシフトウィンドウアテンションを使用することを意味します。Offsets列の$\checkmark$はオフセットを採用することを表します。Pos. Embed列のFixedとDWConvは、それぞれ固定学習可能な位置バイアスと、CSWin Transformerで使用されたLePE(depthwise convolutionに基づくモジュール)を表します。
以下の表は、各ステージにおけるシフトウィンドウアテンションを変形可能アテンションに置き換えた場合の性能比較を示します。Stage3とStage4で変形可能アテンションを使用した場合(DAT)が、最も高い精度を実現します。
以下の図は、最大オフセット($s \tanh(\Delta p)$の$s$のこと)を0から16まで変更させたときの精度を示しています。$s$を1から12に変化させても精度に大きな変化が見られないことから、DATがこのハイパーパラメータに対して頑健であることが示されています。論文では$s=2$を選択しています。
可視化
以下の図は、COCOのデータセットにおいて、複数のヘッドによる高いアテンションスコアを持つキー点を円で示しています。円の半径が大きいほど、スコアが高くなります。右下の画像は、テニスボールを打つためにラケットを振る人物を示しており、人物、ラケット、ボール周辺に多くのキー点が配置されています。これは、DATが重要なオブジェクトに焦点を当てていることを示しています。
結論
変形可能アテンションは、従来の固定位置でのアテンション機構とは異なり、データに基づいて可変的な位置にアテンションを向けることができる、スパースなアテンション機構です。可視化実験によって、変形可能アテンションは、重要な部分に多くのアテンションを集中していることが示されています。この柔軟性と効果的なアテンション能力により、変形可能アテンションを搭載したDATは、従来のモデルと比較して、優れた性能を発揮しています。