0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

CoAtNet: モデル容量と汎化性能を両立する畳み込みとTransformerのハイブリッドアーキテクチャ

Last updated at Posted at 2024-02-15

論文とコード

挿入している画像は、特に言及していない限り、本論文からの引用となります。

はじめに

画像認識分野において、畳み込み層とトランスフォーマー層を組み合わせたハイブリッド構造を持つVision Transformerモデルが注目されています。その代表例として、MaxViTとCoAtNetが挙げられます。MaxViTは各ステージで畳み込み層とトランスフォーマー層を併用しますが、CoAtNetは各ステージでどちらか一方を使用します。これらの異なるアプローチを持つモデルの性能を比較すると、ImageNet-21Kの画像分類(下図左)では、同じモデルサイズに対してMaxViTが優れた性能を示します。しかし、モデルが大きくなるにつれ、CoAtNetが追い越す勢いを示しています。さらに大規模なデータセットであるJFT-300M(下図右)でも、CoAtNetはMaxViTを追い越す傾向を見られます。これらの結果から、CoAtNetは大規模なモデルにおいて高い性能を発揮する可能性があります。そこで、本記事ではCoAtNetについて解説します。
imagenet21k_jft300m.png
図: ImageNet-21K(左)とJFT-300M(右)による画像分類(Zhengzhong Tu, Hossein Talebi, Han Zhang, Feng Yang, Peyman Milanfar, Alan Bovik, Yinxiao Li, MaxViT: Multi-Axis Vision Transformer, arXiv:2204.01697, 2022より)

モデル 過去の記事
MaxViT https://qiita.com/kinkalow/items/aa7508d3d34a2c827d40

CoAtNet

CoAtNetの全体的なアーキテクチャを以下に示します。
architecture.png
CoAtNetは、S0のStem、S1とS2のMBConv、S3とS4のTransformer Blockで構成されています。S1からS4までは複数回実行されます。各ステージの最初のブロックでは、ダウンサンプリングが行われます。

Stem(S0)

S0は、2つの3x3の畳み込み層で構成されています。各畳み込み層の後には、BatchNormとGELU活性化関数が適用されます。最初の畳み込み層ではストライドが2で、ダウンサンプリングが行われます。次の畳み込み層ではストライドが1となります。

MBConv(S1とS2)

MBConvは、MaxViTの記事で紹介したものと類似しています。各ステージの初めのブロックでは、ダウンサンプリングが適用されるため、MBConvの最初のブロックとそれ以降のブロックでは構造がわずかに異なります。MBConvの構造は以下のように表現できます。

 MBConv(First time)         MBConv(>=Second times)    SE
 Input                      Input                     Input
    |---------------+          |---------------+         |-------------+
 BatchNorm          |       BatchNorm          |      AdaptiveAvgPool  |
    |               |          |               |         |             |
 Conv1x1(stride=2)  |       Conv1x1(stride=1)  |      FC(reduction)    |
 BatchNorm          |       BatchNorm          |      GELU             |
 GELU               |       GELU               |         |             |
    |               |          |               |      FC(expansion)    |
 DepthConv3x3    MaxPool    DepthConv3x3       |         |             |
 BatchNorm       Conv1x1    BatchNorm          |      Sigmoid          |
 GELU               |       GELU               |         |             |
    |               |          |               |        Mul------------+
   SE               |         SE               |         |
    |               |          |               |      Output
 Conv1x1            |       Conv1x1            |
 BatchNorm          |       BatchNorm          |
    |               |          |               |
   Add--------------+         Add--------------+
    |                          |
 Output                     Output

MBConvは、1x1、3x3、1x1の畳み込みを組み合わせ、最初の1x1レイヤーでチャンネル次元を4倍増加させ、その後2回目の1x1レイヤーで縮小します1。これは、ResNetの縮小から拡大へのボトルネック設計とは逆の構造を有しています。そのため、MBConvは逆ボトルネックの構造を利用していると言われています。また、SE(Squeeze-and-Excitation)は、MaxViTの記事で取り上げた手法とほぼ同様であり、チャンネルの重要度に基づいてチャンネル次元を変換する効果があります。なお、SEでは2つのFCが使用され、最初にチャンネル次元が4分の1に縮小され、その後に元の次元に戻ります。
通常のMBConvでは、ダウンサンプリングにおいては、depthwise convolutionにおけるストライド2が使用されますが、CoAtNetのMBConvでは、最初のConv1x1でダウンサンプリングが行われます。この違いによる性能差は以下の表で示されています。Conv1x1でダウンサンプリングを行うと、FLOPsが低減し、モデルが大きくなるほど精度への影響が小さくなります。
downsample_conv_depthwise.png

Transformer Block(S3とS4)

Transformer Blockの構造は以下のようになります。

 Transformer Block(First time)    Transformer Block(>=Second times)
  Input                           Input
     |---------------+               |---------------+
  MaxPool         MaxPool            |               |
     |            Conv1x1            |               |
  LayerNorm          |            LayerNorm          |
  RelativeAttention  |            RelativeAttention  |
     |               |               |               |
    Add--------------+              Add--------------+
     |---------------+               |---------------+
  LayerNorm          |            LayerNorm          |
  FFN(expansion=4)   |            FFN(expansion=4)   |
     |               |               |               |
    Add--------------+              Add--------------+
     |                               |
  Output                          Output

Transformer Blockは、各ステージの最初のブロックでのダウンサンプリングを除けば、通常のTransformer Blockとほぼ同じ構造です。具体的には、相対位置エンコーディングを含むグローバルなマルチヘッドセルフアテンションと、2層のFCとGELUを組み合わせたFFNから構成されています。最初のブロックでは、残差ブランチと恒等ブランチにマックスプーリングによるダウンサンプリングが追加されます。また、恒等ブランチにはConv1x1を追加しています。これは、アテンションの最後の全結合層において、出力チャンネル数が常にアテンションへの入力チャンネル数と同じになるように設定されていないためです。ただし、2回目以降は同じになります。

モデル構築のための探索的実験

CoAtNetのネットワークを構築する前に、MBConvとTransformer Blockをどのように組み合わせると、優れた性能が得られるかについての実験が行われています。5つの異なるモデルを導出しています。

  • 4つのモデルは、CoAtNetと同様に、S0からS4までの5段階のネットワークを構築します。S0は2層畳み込みStemであり、S1はMBConvを使用します。S2からS4までは、畳み込み層が最初に現れなければならないという制約の下で、MBConvまたはTransformer Blockのどちらかを使用します。これにより、$\text{C-C-C-C}$、$\text{C-C-C-T}$、$\text{C-C-T-T}$、$\text{C-T-T-T}$の4つのモデルが得られます。$\text{C}$と$\text{T}$はそれぞれ畳み込みとトランスフォーマーを表します。
  • 残り1つのモデルは、オリジナルのVision Transformerと同じように、大きいストライド(例えば、ストライド16)を持つ畳み込みStemを使用します。その後、相対的アテンションを持つTransformer BlockをL個直接積み重ねます。このモデルを$\text{ViT}_\text{REL}$と呼びます。

これらの5つのモデルは、同じモデルサイズを使用しており、ImageNet-1K(130万枚)とJFT(3億枚以上)のデータセットで訓練します。各モデルの訓練損失と評価精度は以下の通りです。
comparison_model_generalization_capacity.png
ImageNet-1Kは、モデルの汎化能力、つまり未知データに対しても高い精度で予測を行える能力を評価するための指標です。具体的には、以下の順で優れています。
$\text{C-C-C-C} \approx \text{C-C-C-T} \geq \text{C-C-T-T} \gt \text{C-T-T-T} \gg \text{ViT}_\text{REL}$
特に、$\text{ViT}_\text{REL}$は、他のモデルと比較して大幅に劣っています。これは、大幅なダウンサンプリングにより局所的な特徴量の処理が不十分であることが原因と考えられます。また、$\text{ViT}_\text{REL}$を除いたモデルでは、畳み込み段数が増えるほど、汎化能力は高くなる傾向があります。
JFTの比較では、大規模な訓練データセットに適合する能力(モデル容量)を評価しています。その結果、以下の順位となります。
$\text{C-C-T-T} \approx \text{C-T-T-T} \gt \text{ViT}_\text{REL} \gt \text{C-C-C-T} \gt \text{C-C-C-C}$
単にTransformer Blockを増やすだけが、必ずしも性能向上につながらないことが示されています。$\text{ViT}_\text{REL}$は最初は劣っているものの、最終的にはMBConv段数が多い2つのモデルに追いついています。しかし、$\text{C-C-T-T}$と$\text{C-T-T-T}$のどちらも$\text{ViT}_\text{REL}$を上回っていることから、$\text{ViT}_\text{REL}$が大きなストライドによって情報が失わた可能性があります。興味深いことに、$\text{C-C-T-T}$と$\text{C-T-T-T}$はほぼ同等の性能を示します。これは、畳み込みの局所操作がグローバルアテンションと同等に機能を果たすことを意味し、計算量とメモリ使用量を削減できるという大きな利点があります。
最後に、$\text{C-C-T-T}$と$\text{C-T-T-T}$の選択を決定するために、転移学習テストを実施しています。JFTで事前学習されたこれらの2つのモデルをImageNet-1Kでファインチューニングし、その転移性能を比較します。以下の表から分かる通り、事前学習性能が同等であるにも関わらず、$\text{C-C-T-T}$の方が高い転移精度を示しています。
transferability_test_results.png
CoAtNetでは、汎化能力、モデル容量、転移学習、効率を考慮し、$\text{C-C-T-T}$のレイアウトを採用しています。

画像分類

実験では、CoAtNetモデルの性能を評価するために、主に画像分類を対象としています。具体的には、ImageNet-1K(128万画像)、ImageNet-21K(1270万画像)、JFT-300M(3億画像)とJFT-3B(30億画像)の各データセットに対して事前学習が行われます。得られた事前学習済みモデルは、目的の解像度でファインチューニングされ、最終的にはImageNet-1Kで評価精度が計測されます。ただし、ImageNet-1Kで解像度224のパフォーマンスは、事前学習が完了した時点で直接取得されます。結果は以下の通りです。

imagenet1k_imagenet21k.png
図: 224x224のImageNet-1K(左)とImageNet-21K(右)による事前学習

jft300M_jft3B.png
表: JFT-300MとJFT-3B(下から3行)による事前学習

ImageNet-1KとImageNet-21Kの事前学習では、既存のモデルを凌駕する結果が得られています。特に注目すべきは、ImageNet21-Kで事前学習したCoAtNet変種が88.56%のトップ1精度を達成し、これはJFT-300Mで事前学習したViT-H/14の88.55%に匹敵します。ただし、ViT-H/14はCoAtNet変種よりも、データセットが23倍大きく、さらにパラメータが2.3倍多く、事前学習時間も2.2倍を要しています。これにより、データ効率と計算効率の両面で改善が見られます。JFT-300Mの比較では、CoAtNet-4はNFNet-F4+に対して、TPUトレーニング時間(TPUv3-core-days)とパラメータ数の両方で2倍の効率性を誇りながら、最高性能にほぼ追いついています。さらに、大規模なJFT-3Bを使用すると、CoAtNet-6はViT-G/14と同等の精度を維持しながら計算時間を4.5倍短縮し、CoAtNet-7は1.5倍少ない計算時間で90.88%という新たな最先端の精度を達成しています。

アブレーションスタディ

CoAtNetに関するアブレーションスタディも実施されています。
(a) ablation_relative_attention.png
(b) ablation_architecture_layout.png
(c) ablation_headsize_normalization.png
図(a)は、相対位置エンコーディングの効果を検証したアブレーションスタディの結果です。転移学習の有無に関わらず、相対位置エンコーディングを導入することで、モデルの性能が向上することが示されています。図(b)は、計算の主要な部分を占めるMBConvのS2とTransformerのS3のブロック数を変化させた場合の性能比較を示しています。S2とS3のブロックの総数は一定であり、各ステージのブロック数が変更されています。S3のブロック数を増やすと性能が向上し、一定の大きさを超えると性能が低下する可能性があることが示唆されています。図(c)は、各アテンションヘッドの次元をデフォルトの32から64に変更した場合と、MBConvで使用するバッチノルムをレイヤーノルムに変更した場合の性能を比較しています。ヘッド数を変更すると性能が低下しますが、TPU速度が大幅に向上するため、ヘッド数の調整は精度と速度のトレードオフが生じます。バッチノルムとレイヤーノルムはほぼ同等の性能を示しますが、バッチノルムはTPU上で速くなります。

おわりに

CoAtNetは、畳み込みとトランスフォーマーを効果的に組み合わせたモデルです。広範な実験により、CoAtNetは優れた汎化能力とモデル容量を備え、さまざまなデータサイズと計算コストの条件下で最先端のパフォーマンスを実現することが示されています。

  1. 厳密に言えば、各ステージの最初のブロックでは、2回目のConv1x1によってチャンネル数を必ずしも縮小させるとは限りません。これは、Conv1x1の出力チャンネル数がユーザーによって指定可能であり、その値に依存して縮小するかどうかが決まります。ただし、2回目以降のブロックでは、必ず4分の1に縮小されます。

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?