Vision Transformer について公式論文のモデル図では理解が難しいなと思っていました。
今回は、数式やコードは別の機会に譲り、モデルの構造のイメージを直感的に掴めるよう記載していきます。
そもそもTransformerとは?
最初にさらっとですが、モデルの変遷にふれます。興味のない方は読み飛ばしてください。
Transformerという名前は 「変換」 を意味します。これは、ある情報を別の形に変える仕組みです。
Transformerが登場する前は「seq2seq(シーケンス・ツー・シーケンス)」と呼ばれるモデルが使われていました。このモデルでは、文章のような順序が重要なデータ(系列データ)をRNN(再帰型ニューラルネットワーク)やLSTM(長短期記憶ネットワーク)に入力し、一つのまとまった情報(文脈ベクトル)に変換したうえで、少しずつ出力を生成します。
この変換の過程では、入力された情報を整理して意味を抽出する「エンコーダ」と、その情報を使って新たな出力を作り出す「デコーダ」が使われます。しかし、固定長の文脈ベクトルに情報を詰め込むため、長文の情報をすべて保持することが難しく、情報の圧縮が課題でした。
そこで、「重要度を考慮して情報を活用する」仕組みとして「Attention(注意機構)」が提案されました。これは、入力文章のどの部分が特に重要かを判断し、それを強調して情報を処理する方法です。「seq2seq with attention」では、出力のタイミングごとに「どの単語が重要か」を計算し、文全体の重み付き平均を求めることで、より意味を損なわないで出力できるようになりました。
しかし、この方法では依然として「系列の順番に従って処理を進める」必要があり、すべての計算を順次行うため処理速度に限界がありました。
そこで、「Attention」の考え方を一歩進め、文章全体で「どの単語同士が関係しているか」を自動的に学習する「self‑attention(自己注意機構)」が導入されました。これにより、各単語が他の単語とどのように関連しているかを並列処理できるようになりました。
以上が発展の流れです。この self‑attention をどのように画像分類タスクへ応用していくのかを見ていきます。
Vision Transformerの仕組み
まずは全体像を示します。これからVision Transformerを「ViT」と表記します。
ViTは大きく三つに分かれています。
- Input Layer
- Encoder
- MLP Head
画像がモデルに入力され、最後にMLP Headを通った後に分類結果が得られます。今回は例として私のお気に入りの「なまけたろう」というキャラクターの画像を拝借し、ViTがそのラベルを予測するイメージで説明します
Input Layer
Input Layerは画像をモデルが読める形へ変換する層です。
流れは次の四つです。
- ① パッチに分割
- ② 埋め込み
- ③ クラストークン
- ④ 位置埋め込み
① パッチに分割
画像を細かく分ける前処理はCNNにはないステップです。Transformerは自然言語処理で使われるモデルなので、画像も単語に相当する最小単位で扱う必要があります。ViTでは、画像を「パッチ」と呼ばれる小さい単位に分割します。
例として32 pixel × 32 pixelの画像を4分割し、16 pixel × 16 pixelのパッチを得るとします(パッチサイズはハイパーパラメータで設定します)。
RGB画像なら赤・緑・青の3チャネルを持つため、各パッチは16 × 16 × 3のテンソルです。これを一次元ベクトルへ変換すると要素数は16 × 16 × 3 = 768になります。パッチが4個あるので、長さ768のベクトルが4本得られるイメージです。
② 埋め込み
ベクトル化しただけでは、RGBの画素値(0〜255)が並んでいるだけでモデルは意味を理解できません。そのため、モデルが学習可能な表現に変換する埋め込み層を通します。ViTでは1層の線形層(全結合層)を用います。
③ クラストークン
ここで、先ほどのパッチベクトルとは別にクラストークン([CLS]トークン)を新たに作成します。このベクトルはパッチと同じ長さの次元を持ち、最終的に画像全体の表現を要約するものとして分類に用いられます。
④ 位置埋め込み
パッチが画像内のどこに位置するかを示す情報を加えます。この位置埋め込みも学習可能なパラメータとしてモデルが最適値を見つけます。
以上で次のEncoderへの入力準備が整いました。
Encoder
Encoderは複数のEncoder Blockから構成されます。
Encoder Blockの中身は各Blockで共通しており、以下のようになっています。
- ① Layer Norm(1回目)
- ② Multi‑Head Self‑Attention
- ③ Layer Norm(2回目)
- ④ MLP
① Layer Norm(1回目)
深層学習で一般的な正規化手法に Batch Normalization がありますが、これはバッチ全体の平均と分散を使います。自然言語タスクでは文ごとにトークン数が異なり、トークン位置間で統計量が揃いません。Layer Normalization は入力テンソルの最後の次元(特徴次元)で正規化を行うため、文長に依存しません。ViTでもTransformerと同様にLayer Normalizationを採用しています。
② Multi-Head Self Attention
ViTの核心部分です。名前に「Multi」が付く通り、Self‑Attentionを複数ヘッド並列で計算します。
ここで Self-Attention について説明します。
Self-Attention
説明を簡単にするため、これからはクラストークンとパッチ埋め込みを縦に並べた行列で示します。今まで横に並んでいたものが縦に並んだだけなので中身は変わっていません。
この縦型行列には、その画像を構成する全てのパッチ埋め込みとクラストークンがまとめられています。
- 行数 = (パッチ数 + 1) ⇒ 今回は5行
- 最初の1行目はトークン
- 残りの行が各パッチ
- 列数 = 埋め込み次元(768次元)
Self‑Attentionは「あるパッチ(またはクラストークン)が他のパッチとどれくらい関連しているか」を類似度として計算し、その類似度に基づいて各パッチ表現を更新します。流れは次の四段階です。
- ① 埋め込み層(線形層)で入力から 三つの行列を得る
- ② 二つの内積を取り、各要素間の類似度スコア(関連度)を計算
- ③ 類似度スコアにsoftmaxをかけて確率に換算する
- ④ 確率と三つ目のベクトルの重み付き和を求め、新しいパッチ表現を得る
今、一つ目、二つ目、三つ目と行列を表現しましたが、それぞれ呼び方がついています。
一つ目のベクトルは クエリ と呼ばれ、自分が知りたいことを示す質問ベクトルです。
二つ目のベクトルは キー と呼ばれ、照合の手がかりになるベクトルです。
三つ目のベクトルは バリュー と呼ばれ、照合後に相手へと渡す実際の情報ベクトルです。
キーの形が縦型から横型に変形していますが、これは類似度を行列の内積で計算するため、そのように表し直しました。
さて、③で中身が類似度の行列である Attention Weight ができたことが分かります。これが複数個つまり複数の関係を捉えた行列が得られれば、より損失を少なくすることができるのではないでしょうか。
それが、Multi-Head Self Attention と呼ばれる仕組みです(以下、「MHSA」)。
1つのパッチから複数のクエリ、キー、バリューを埋め込むことが可能で、その分割数もパラメータで指定します。結果として、指定した分割数分のAttention Weightが作成できます。
各ヘッドごとにパラメータが異なるので、「色」「形」「配置」など異なる相関パターンに注目するように分かれます。
以下は、頭の情報を見るヘッドのイメージです。
ヘッドごとの結果は結合された後、もう一度線形層で混ぜ合わせ最終的なトークンを更新するようになります。
③ Layer Normalization(2回目)
ここでMHSAが一区切りしました。
実際には、ここで 残差接続 といって、MHSAの出力 → 線形層を経たうえで、Encoderブロックの入力ベクトル(x0とする)と足し合わせます。
そして、残差で加算したベクトルx1 = x0 + MHSAの出力を再び Layer Normalization に入力し、分散を安定させてから次のMLPへ渡します。
④ MLP
2 層の全結合層と活性化関数 GELU から成る小さな Feed‑Forward Network です。
ここで得た特徴ベクトルを、再度先ほどの入力x1と残差加算してx2 = x1 + MLPの出力を作り、次の Encoder Block へ送ります。
残差接続は入力をそのままコピーしたベクトルと、複雑な変換を行った後のベクトルを足し合わせているため、仮にうまく学習できなくても入力がそのままわたるので性能が極端に悪化しにくいというメリットがあります。
MLP Head
Encoder 最終ブロックの出力全トークンは、まず LayerNorm をもう一度経由し、抽出されたクラストークンのベクトルを 線形層(全結合層)に入力します。
ここで出力次元を分類するクラス数に設定し、ソフトマックスなどで確率分布へ変換して予測を得ます。
気を付けたいのが、MLPに入力されるのはパッチ全体ではなく、クラストークンということです。
説明したとおり、各Self AttentionはCLS ⇆ パッチ で相互作用しており、クラストークンはブロックを重ねるごとに画像全体の要約ベクトルへと育っていきます。
一方で、Encoder Block間では常にクラストークン + すべてのパッチ埋め込みを渡しています。
クラストークンのみを抽出するのはあくまで分類タスク向けの設計なので、Segmentationを行う際はパッチ系列をそのままDecoderに渡すことになる等、設計が異なります。
以上がVitの仕組みになります。
実際にVision Transformerの中身にアクセスしてみる
ここから、実際にHugging faceから「vit_base_patch16_224」というモデルをロードし、中身を見てみます。
import timm
# モデルを取得
model = timm.create_model("vit_base_patch16_224", pretrained=True)
len(model.blocks)
出力は12
とでましたので、このEncoder Blockは12個あることが分かります。
Patch埋め込みの設定を確認
model.patch_embed.proj
Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
と出力されました。
パッチサイズは16×16なので、224×224の入力画像が14×14=196個のバッチに分割されること、768次元に埋め込みされることが分かります。
Encoder Blockの中身
model.blocks[0]
Block(
(norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
(attn): Attention(
(qkv): Linear(in_features=768, out_features=2304, bias=True)
(q_norm): Identity()
(k_norm): Identity()
(attn_drop): Dropout(p=0.0, inplace=False)
(norm): Identity()
(proj): Linear(in_features=768, out_features=768, bias=True)
(proj_drop): Dropout(p=0.0, inplace=False)
)
(ls1): Identity()
(drop_path1): Identity()
(norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
(mlp): Mlp(
(fc1): Linear(in_features=768, out_features=3072, bias=True)
(act): GELU(approximate='none')
(drop1): Dropout(p=0.0, inplace=False)
(norm): Identity()
(fc2): Linear(in_features=3072, out_features=768, bias=True)
(drop2): Dropout(p=0.0, inplace=False)
)
(ls2): Identity()
(drop_path2): Identity()
)
個別にアクセスしてみましょう。
Model.blocks[0].attn.num_heads
出力は12
となりました。最初のEncoder BlockのMulti-Head Self-Attentionの数は12個あるということです。
出力次元を変更する
自分がやりたいことに合わせてMLP Headにアクセスする必要があります。
# モデルを定義
import timm
import torch.nn as nn
num_classes = 7 # 例
def create_model(model_name="vit_base_patch16_224", num_classes=7):
model = timm.create_model(model_name, pretrained=True)
model.head = nn.Linear(model.num_features, 7)
return model
modelのhead部分にアクセスして出力次元を7に変えてみました。
Attention Weightの可視化
今後、attention weightを可視化したいと思った際は以下のように、どのBlockから取得するかを決めてrefgister_foeward_hook
関数へと渡すことになります。
# 例:softmax後のattentionを受け取る
hook = model.blocks[0].attn.attn_drop.register_forward_hook(hook_fn)
以上です。読んでいただきありがとうございました。