論文とコード
- Zilong Huang, Youcheng Ben, Guozhong Luo, Pei Cheng, Gang Yu, Bin Fu, Shuffle Transformer: Rethinking Spatial Shuffle for Vision Transformer, arXiv:2106.03650, 2021
- コード(https://github.com/mulinmeng/Shuffle-Transformer)
挿入している画像は、本論文からの引用となります。
はじめに
Transformerの中核をなすのはアテンション機構であり、Shuffle Transformerはそのアテンションにおいて局所的な情報と大域的な情報を組み合わせる巧妙なメカニズムを持っています。関連する過去の記事では、AxWin TransformerとMaxViTも局所的と大域的な情報を取り扱っており、共通して局所的なアテンションはSwin Transformerと同様なウィンドウベースのアテンションを使用しています。一方、大域的なアテンションは、AxWin Transformerが軸ベース上で、MaxViTがグリッド上でアテンションを行います。Shuffle Transformerはその名の通り、Shuffleを用いたアテンションメカニズムを取り入れていますが、実際の機構はグリッドアテンションと同様です。異なる点として、階層構造を形成する際のモジュール内で、MaxViTは畳み込みネットワークを活用しており、これがShuffle Transformerの単一の畳み込みとは大きく異なるアプローチであることが挙げられます。この違いが性能差に影響を与えた可能性があり、ImageNetの画像分類実験によれば、Shuffle TransformerはMaxViTよりも性能が低いですが、FLOPsとパラメータ数は小さいです。発表時期を考慮すると、Shuffle Transformerが2021年6月、MaxViTが2022年4月、AxWin Transformerが2023年5月に発表されたことから、Shuffle Transformerはその後の研究によって性能が追い越されたと考えられます。本記事では、Shuffle Transformerの構造について解説し、MaxViTとの違いを説明します。また、前回のMaxViTに関する記事で充分に触れられなかったグリッド分割に対応する「空間シャッフル」についても、具体的な説明を追加します。
Model | Size | Params | FLOPs | IN-1K Top-1 Acc.(%) |
---|---|---|---|---|
Shuffle-T | 224 | 29M | 4.6G | 82.5 |
MaxViT-T | 224 | 31M | 5.6G | 83.62 |
AxWin-T | 224 | 22M | 3.5G | 83.9 |
Shuffle-S | 224 | 50M | 8.9G | 83.5 |
MaxViT-S | 224 | 69M | 11.7G | 84.45 |
AxWin-S | 224 | 48M | 7.6G | 84.6 |
Shuffle-S | 224 | 88M | 15.6G | 84.0 |
MaxViT-B | 224 | 120M | 23.4G | 84.95 |
AxWin-B | 224 | 84M | 12.7G | 85.1 |
モデル | 過去の記事 |
---|---|
AxWin Transformer | https://qiita.com/kinkalow/items/e93621582a2af4446e98 |
MaxViT | https://qiita.com/kinkalow/items/aa7508d3d34a2c827d40 |
Swin Transformer | https://qiita.com/kinkalow/items/cb1024c2c9856ee1afca |
Shuffle Transformer
以下は、Shuffle Transformerの全体的なアーキテクチャです。
Shuffle Transformerは、Token Embedding、Token Merging、Shuffle Transformer Blockで構成されています。Token Embeddingでは、2層の畳み込み3x3(ストライド=2)と1層の畳み込み1x1(ストライド=1)が使用されています。階層的な表現を生成するために、Token Mergingでは、ストライド2の畳み込み2x2が使用されています。Shuffle Transformer Blockは、各ステージで複数回実行され、特徴量マップの形状を保持します。
Shuffle Window-based Multi-head Self-Attention
ウィンドウベースのマルチヘッドセルフアテンション(WMSA)は、完全なセルフアテンションの計算コストを大幅に削減するために提案された手法です。WMSAは、入力特徴マップを複数のウィンドウに分割し、各ウィンドウ内でセルフアテンションを計算することで、計算複雑度を線形に抑えます。ウィンドウ分割と逆変換(ウィンドウから特徴マップへの変換)の操作が必要ですが、その計算コストは無視できるほど小さいです。
WMSAは、局所的なウィンドウ内で要素同士の相互作用を行いますが、Shuffle WMSAでは、異なる分割方式を採用し、大域的な要素同士の相互作用を実現します。この手法では、Spatial Shuffleと呼ばれる操作を導入し、遠く離れた要素同士をグループ化します。そして、このグループ内でセルフアテンションを適用することで、長距離の依存関係を効率的に捕捉することができます。Shuffle WMSAは、WMSAの分割と逆変換の操作が異なるだけであり、さらにその計算コストも非常に小さいです。このため、この手法は計算複雑度を線形に抑えることができます。
Spatial ShuffleとSpatial Alignment
次の図は、2つの連続するセルフアテンション、WMSA1とWMSA2を適用したときの状況を示しています。
図(a)は、シャッフルなしの状態を表しており、この場合、ウィンドウ内の情報しか関係せず、受容野がウィンドウ内に制限されています。その結果、ウィンドウ間の情報の流れが遮断され、表現力が低下します。この問題に対処するため、単純な解決策として、図(b)に示すように、WMSA2が異なるウィンドウからの入力データを取得できるようにします。これにより、異なるウィンドウのトークンが関連付けられます。この改善を効果的に実現する手法が、図(c)で示されています。Spatial Shuffleを使用してトークンをシャッフルし、異なるウィンドウのトークンを導入します。その後、ウィンドウ内でセルフアテンションを適用し、最後にSpatial Alignmentでシャッフルしたトークンを元の状態に戻します1。
空間1次元におけるSpatial ShuffleとSpatial Alignmentの具体的な操作について説明します。ここでは、ウィンドウサイズが$M$で、$N$個のトークンが存在すると仮定します。Spatial Shuffleの手順は、空間次元を($M, \frac{N}{M}$)の形状に分割し、その後転置してから結合します。例えば、以下の図で示した$N=9$かつ$M=3$の場合、操作を施すと、window1には離れたウィンドウのトークン、具体的には4と7が含まれ、これにより長距離のウィンドウ接続が構築されます。window2とwindow3にも同じことが言えます。Spatial Alignmentの手順は、シャッフルされた状態を元に戻すため、Spatial Shuffleの逆の操作を実施します。具体的な手順は、($\frac{N}{M}, M$)に分割し、その後転置してから結合します。以下の図では、Spatial Alignmentの分割、転置、結合の出力データが、それぞれ、Spatial Shuffleの結合、転置、分割への入力データに対応していることが示されています。したがって、逆の操作手順が行われています。
Spatial Shuffle 分割 転置 結合
+-+-+-+-+-+-+-+-+-+ +-+-+-+ +-+-+-+ +-+-+-+-+-+-+-+-+-+
|1|2|3|4|5|6|7|8|9| --> |1|2|3| --> |1|4|7| --> |1|4|7|2|5|8|3|6|9|
+-+-+-+-+-+-+-+-+-+ +-+-+-+ +-+-+-+ +-+-+-+-+-+-+-+-+-+
|<--->|<--->|<--->| |4|5|6| |2|5|8| |<--->|<--->|<--->|
+-+-+-+ +-+-+-+ | | |
|7|8|9| |3|6|9| | | window3
+-+-+-+ +-+-+-+ | window2
window1
Spatial Alignment 分割 転置 結合
+-+-+-+-+-+-+-+-+-+ +-+-+-+ +-+-+-+ +-+-+-+-+-+-+-+-+-+
|1|4|7|2|5|8|3|6|9| --> |1|4|7| --> |1|2|3| --> |1|2|3|4|5|6|7|8|9|
+-+-+-+-+-+-+-+-+-+ +-+-+-+ +-+-+-+ +-+-+-+-+-+-+-+-+-+
|<--->|<--->|<--->| |2|5|8| |4|5|6| |<--->|<--->|<--->|
+-+-+-+ +-+-+-+ | | |
|3|6|9| |7|8|9| | | window3
+-+-+-+ +-+-+-+ | window2
window1
空間2次元$(h,w)$の場合も同様な操作でシャッフルと逆変換が可能です。具体的には、Spatial Shuffleは、入力を$(M_h, \frac{h}{M_h}, M_w, \frac{w}{M_w})$に分割し、その後順序を変換して$(\frac{h}{M_h}, \frac{w}{M_w}, M_h, M_w)$とし、最後に$(\frac{h w}{M_h M_w}, M_h M_w)$に結合します。ここで、$M_h=M_w=M$であり、出力は(ウィンドウ数、トークン数)を表します。具体的な例として、$h=w=6$、$M=3$の場合を考えると、以下のコードで示すように、Spatial Shuffle後のwindow1には、異なるウィンドウのトークン、4、16、24、26、28が混ざり、等間隔に離れた要素が同じウィンドウ内に配置されます。Spatial AlignmentはSpatial Shuffleの逆操作で、入力を$(\frac{h}{M_h}, \frac{w}{M_w}, M_h, M_w)$に分割し、その後順序を変換して$(M_h, \frac{h}{M_h}, M_w, \frac{w}{M_w})$とし、最後に$(h, w)$に結合します。
Spatial ShuffleとSpatial Alignmentは、MaxViTのグリッド分割と結合操作と同等の機能を果たします。これは、Shuffle TransformerとMaxViTのアテンション機構は、本質的に同じものになることを意味します。
import torch
from einops import rearrange
h = w = 6
M = 3
x = torch.arange(h * w).view(h, w)
# Spatial Shuffle
shuffles = rearrange(x, '(Mh hh) (Mw ww) -> (hh ww) (Mh Mw)', Mh=M, Mw=M)
# Spatial Alignment
x_restored = rearrange(shuffles, '(hh ww) (Mh Mw) -> (Mh hh) (Mw ww)', hh=h // M, Mh=M, Mw=M)
print(x)
# tensor([[ 0, 1, 2, 3, 4, 5], |
# [ 6, 7, 8, 9, 10, 11], window1 | window2
# [12, 13, 14, 15, 16, 17], --------+--------
# [18, 19, 20, 21, 22, 23], window3 | window4
# [24, 25, 26, 27, 28, 29], |
# [30, 31, 32, 33, 34, 35]])
print(shuffles)
# tensor([[ 0, 2, 4, 12, 14, 16, 24, 26, 28], <--- window1
# [ 1, 3, 5, 13, 15, 17, 25, 27, 29], <--- window2
# [ 6, 8, 10, 18, 20, 22, 30, 32, 34], <--- window3
# [ 7, 9, 11, 19, 21, 23, 31, 33, 35]]) <--- window4
print(torch.equal(x, x_restored))
# True
Neighbor-Window Connection
空間シャッフルのセルフアテンションは、遠く離れた要素の結びつきを促進しますが、高解像度画像を処理する際に懸念が生じます。画像サイズがウィンドウサイズと比較して非常に大きい場合、各ウィンドウが狭くなり、情報の長距離伝播が効果的に行われない可能性があります。この問題に対処するため、近傍ウィンドウ接続(Neighbor-Window Connection)による情報伝播の改善が提案されています。具体的なアプローチとして、(1)ウィンドウサイズの拡大、(2)シフトウィンドウと連携、(3)畳み込みの導入が挙げられます。論文では、残差接続付きdepthwise convolutionを提案しています。この畳み込みのカーネルサイズはウィンドウサイズと同じです。実験結果によると、depthwise convolutionはWMSAとMLPの間に組み込むことで、精度が向上することが報告されています。
Shuffle Transformer Block
Shuffle Transformer Blockは、3つの主要なコンポーネントで構成されています:シャッフルの有無に基づくウィンドウベースのマルチヘッドセルフアテンション(WMSA)、近傍ウィンドウ接続(NWC)、および多層パーセプトロン(MLP)。以下の図は、2つの連続するShuffle Transformer Blockを示しています。
連続するShuffle Transformer Blockでは、ウィンドウ間の相互接続を確保するために、WMSAとShuffle-WMSAを交互に使用します。最初のブロックでは通常のWMSAが使用され、次のブロックではShuffle-WMSAが採用されます。その上、隣接するウィンドウ間の接続性を高めるためにNWCが追加されています2。Normは、通常のTransformer Blockで使われるレイヤーノルムではなく、バッチノルムを指します。
Shuffle TransformerとMaxViTの比較
Shuffle TransformerとMaxViTは、相対位置エンコーディングを含むアテンションを同じ分割方式で実施し、局所的なアテンションと大域的なアテンションを連続するTransformerブロック内で交互に呼び出すという共通点があります。また、Stem層によるダウンサンプリングにおいて、両者とも重なる畳み込み層を利用しています。
一方、異なる点としては、まずTransformer Blockの単位で比較すると、Shuffle Transformerは近傍ウィンドウ接続(NWC)を使用していますが、MaxViTではこれを利用していません。ステージ単位では、Shuffle TransformerのPatch Mergingは単一の畳み込み層を使用していますが、MaxViTでは畳み込みネットワークが多層構造で構築されています。さらに、MaxViTでは各ステージでTransformer Block以外にも畳み込みネットワークを複数回実行しますが、Shuffle TransformerではTransformer Blockのみを繰り返します。その他、細かい違いとして、MLPやTransformer Blockなどで異なる活性化層や正規化層が使用されています。最後に、実装面では、MaxViTはTensorFlowを、Shuffle TransformerはPyTorchを採用している点が挙げられます。
実験
Shuffle Transformerの有効性を検証するために、画像分類、物体検出、セマンティック/インスタンスセグメンテーションの実験を実施しています。画像分類タスクにおいては、Swin Transformerに比べて優れた速度と精度のトレードオフを実現しています。その他のタスクにおいても、Swin Transformerと同等またはそれ以上のパフォーマンスを達成しています。
Shuffle Transformerにおける各モジュールの重要性を検証するためのアブレーションスタディも実施されています。下図(a)は、空間シャッフルと近傍ウィンドウ接続(NWC)の使用が性能向上に有効であることを示しています。下図(b)は、異なる空間シャッフル手法における性能比較を示しています。長い範囲(通常の方法)、短い範囲、およびランダムなシャッフルの条件下で、長い範囲でのシャッフルが最も良い結果を出しています。下図(c)は、Transformer Blockの異なる位置(A、B、C)にNWCを挿入する影響について示しています。NWCをWMSAとMLPの間に挿入すること(B)が、最良のパフォーマンスを達成しています。
(a)
(b)
(c)
おわりに
Shuffle Transformerは、効率的なモデリングを実現するために、ウィンドウベースのセルフアテンションを取り入れ、さらにセルフアテンションに空間シャッフルを導入することでウィンドウ間の大域的な結合を確立しています。加えて、隣接するウィンドウ間の結合を強化するために、残差接続を備えたdepthwise convolutionを導入しています。これらのモジュールが組み合わせることで、Shuffle Transformerは情報を効果的に全てのウィンドウ間で伝達し、これまでのモデルよりも優れたパフォーマンスを発揮しています。