0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

ViTにおける特定のTransformer Encoderの学習

Last updated at Posted at 2025-02-02

Vision Transformer(ViT)はCNNより高精度な画像認識が可能なことから注目されています.
しかし,ViTはパラメータが非常に多いため学習するために高性能な計算機器が必要になります.
そこで,本記事ではViTの特定のTransformer EncoderのみをFine-tuningする方法を書いていきます.

ViTのモデル構造
ViT_org.png

特定のTransformer Encoderを学習

ViTの学習済みモデルを使用するためにはtimmを使用すると簡単にできます.
timmは正式名称Pytorch Image ModelsでSOTA(State of the Art)を達成した画像分類モデルを公開しているライブラリです.
timmの使用方法については,こちらを参照してください.

ViTで特定のTransformer Encoderを学習する方法はPartial-k(kは学習する層の数を指します.)と呼ばれ,"Masked Autoencoders Are Scalable Vision Learners
"と"Visual Prompt Tuning"において,省メモリで学習しつつ転移学習に比べて精度が向上することが示されています.
以下の図のように,Transformer Encoderの最終層を学習する場合はPartial-1と呼ばれます.

Partial-1.png

下記はViTのPartial-1で学習するコードです.

import timm # timmをインポート
from timm.models import create_model # timmで公開されているモデルを読み込む関数

model = create_model("vit_tiny_patch16_224", pretrained=True, num_classes=len(class_names)) 
model.to('cuda')

# 全パラメータを固定
for param in model.parameters():
    param.requires_grad = False

# 最後のTransformer Encoderブロックを学習可能にする
for param in model.blocks[-1].parameters():
    param.requires_grad = True

# 分類層(head)も学習可能にする
for param in model.head.parameters():
    param.requires_grad = True

optimizer = torch.optim.AdamW(lambda p: p.requires_grad, model.parameters(), lr=args.lr, weight_decay=args.weight_decay) # 学習対象である最後のTransformer Encoderとheadのパラメータを渡す

精度・パラメータの比較

ここでは精度・パラメータを比較します.
モデルは,ImageNet-1kで学習済みのViTをCifar-10で1 Epochだけ学習します.
比較する学習方法は,MLP Headのみを学習する転移学習(Transfer Learning),全てのパラメータを学習するFine-Tuning,Partial-1の3種類です.

パラメータ数

以下のように,転移学習(Transfer Learning)と比較すれば多いものの,Fine-Tuningと比較するとパラメータ数を1桁以上削減できています.

転移学習(Transfer Learning)

Total number of trainable parameters: 1930

Fine-Tuning

Total number of trainable parameters: 5526346

Partial-1

Total number of trainable parameters: 446794

精度

以下のように,Fine-Tuningと比較すると劣るものの,Fine-Tuningより少ないパラメータ数で転移学習(Transfer Learning)より高い精度を発揮できています.

Fine-Tuning

epoch: 1, train loss: 0.2751163761680707, train accuracy: 0.90644 test loss: 0.12482822295042532, test accuracy: 0.9596

転移学習(Transfer Learning)

epoch: 1, train loss: 1.616312658175444, train accuracy: 0.45186 test loss: 1.1747484222243103, test accuracy: 0.6182

Partial-1

epoch: 1, train loss: 1.0304751597917996, train accuracy: 0.64032 test loss: 0.7293716035311735, test accuracy: 0.7493

まとめ

今回はViTの特定のTransformer EncoderのみをFine-tuningする方法であるPartial-kについて説明しました.
高性能な計算機器が手元にない場合や学習を高速化したい場合などは使用するといいかもしれません.
また,Contrastive Language-Image Pre-training(CLIP)やMasked Autoencoder(MAE)などの大規模なデータ学習したモデルを省メモリで学習する場合にも使用するといいかもしれないです.
ここで実際にtimmを使用してCIFAR-10 / 100を分類するコードを公開しています.

参考文献

・An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale
・Visual Prompt Tuning
・Masked Autoencoders Are Scalable Vision Learners
・Learning Transferable Visual Models From Natural Language Supervision

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?