LoginSignup
4
3

【Pytorch】SwinTransformerでLet's画像分類!

Last updated at Posted at 2022-12-18

はじめに

この記事は、AMBL株式会社 Advent Calendar 2022の17日目の記事になります。是非、他の記事も読んで見て下さい〜:santa_tone2:

SwinTransformerって名前はよく聞くけど中身はあまり理解していない...。
いい機会なので、SwinTransformerの実態を暴いていこうと思います。

時間がない方向け

  • Vision Transformerの問題点を解決する方法を提案
  • 当時Object Detection, Semantic Segmentation タスクでSoTA

SwinTransformerの実態を暴くには、まずVision Transformerの問題点を知る必要があります。

まず、Vision Transformerってなんぞ??

2020年10月に、An Image is Worth 16x16 Words: Transformers for Image Recognition at Scaleのタイトルで論文が発表されました。

Vision Transformerは画像分類のモデルです。

  • 入力画像をパッチに分割します。(自然言語処理の単語に分割する部分)
  • パッチ中のピクセル値を並べてベクトルとみなし、これを線形変換したものを各トークンのベクトル表現として扱います。
    apple.jpg

image.png

Vision Transformerの構造
image.png

Vision Transformerの問題点

  • 認識する対象は画像中で様々な大きさを取る -> パッチは対象物体をぶつ切りにする可能性がある
  • 画像の解像度が高くなると計算量が膨大になる

なんで計算量が膨大になるの?
画像内のすべてのPatchに対してAttentionの計算を行うため、計算コストは画像サイズに対して二乗(h×w)で増加してしまう。

Swin Transformerはこれらの問題をどのように解決した?

image.png

Patch MergingというPoolingのように画像の縦横を小さくする機構を導入

Vision Transformerの問題点の1つ目である「認識する対象は画像中で様々な大きさを取る -> パッチは対象物体をぶつ切りにする可能性がある」
を解決するためにSwinTransformerでは、Patch MergingというPoolingのように画像の縦横を小さくする機構を導入している。

隣接2×2のパッチ(チャネル数C)を1つにまとめ、チャネル数が4倍になったあと(チャネル数4C)線形変換を行うことでチャネル数を半分に減らす。CNNで言うPoolingのようなことを行っている。
→入力チャネル数の2倍のチャネル数がPatch Mergingの出力になる。これにより、ステージが進むたびにより広い領域の情報をマージしつつ、高次の特徴抽出処理が行われる。
→階層が深くなるにつれて特徴マップが小さくなるように設計し、階層的な特徴量を取得している。

Swin Transformer Block

Vision Transformerの問題点の2つ目である「画像の解像度が高くなると計算量が膨大になる」を解決したSwin Transformer Blockについて解説する。

ほとんどTransformerと同じ
違うのは

  • Window-based Multi-head Self-attention(W-MSA)
  • Shifted Window-based Multi-head Self-attention(SW-MSA)
    W-MSAとSW-MSAはBlock毎に交互に適応される点。

W-MSA

特徴マップをサイズがM×Mのwindowに区切り、window内でのみself-attentionを求める

  • h×w個のパッチが存在する特徴マップにおいて、(hw)×(hw)= h^2×w^2の計算量がM^2×M^2×(h/M)×(w/M) = M^2hwに削減
  • Mは定数なので、計算コストが画像サイズの2乗から画像サイズの線形に緩和

SW-MSA

入力された特徴マップの解像度が8×8で、64個のパッチからなるとする(下図左)。
M=4として、1つのWindowにはM^2=16個のパッチを含む、重なりの無いWindowに分割することができる。

このWindowを(M/2, M/2)(2,2)シフトさせると、上図右のようになる。
元のWindowの大きさを保っているのは1つのみ。
この状態でそのまま処理を進めるのはやや複雑な処理となる。

image.png

そのため、SW-MSAにおいては、cyclic shiftという工夫をしている。
下図のように、Windowをシフトしてはみ出した部分(薄く表示されているA, B, C)をWindowの反対位置へ移動させる。
→Windowの数やWindowあたりのパッチ数は一定となる。
標準的なMSAでよく使用されるマスクを設定してあげることで、無関係なパッチ間でAttentionが生じないようにできるので、複雑な実装を回避することができる。

image.png

こうして得られた特徴マップの移動していた部分を本来の位置へと戻す(reverse cyclic shift)ことで、SW-MSAの処理が完了する。

Swin Transformerを実装してみた。

dataloderでdatasetを読み込んで、Trainigの実装部分はたったこれだけ。

# timmを使用してswin transformerの事前学習済みモデルを読み込む
model = timm.create_model('swin_base_patch4_window7_224_in22k', pretrained=True, num_classes=2)

# Training

# loss function
criterion = nn.CrossEntropyLoss()
# optimizer
optimizer = optim.Adam(model.parameters(), lr=lr)
# scheduler
scheduler = StepLR(optimizer, step_size=1, gamma=gamma)

for epoch in range(epochs):
    for data, label in tqdm(train_loader):
        
        data = data.to(device)
        label = label.to(device)

        output = model(data)
        loss = criterion(output, label)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        acc = (output.argmax(dim=1) == label).float().mean()
        epoch_accuracy += acc / len(train_loader)
        epoch_loss += loss / len(train_loader) 

簡単ですね~。

image.png

精度もそこそこでてるみたいです。

最後に

画像分類タスクはCNNだけでなく、Transformerも優秀ですね。
Transformerの得意な分野。CNNの得意な分野を知るのが大切ですね。

4
3
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
4
3