Vision Transformer(ViT)はCNNより高精度な画像認識が可能なことから注目されています.
しかし,ViTはパラメータが非常に多いため学習するために高性能な計算機器が必要になります.
そこで,本記事ではViTの特定のTransformer EncoderのみをFine-tuningする方法を書いていきます.
特定のTransformer Encoderを学習
ViTの学習済みモデルを使用するためにはtimmを使用すると簡単にできます.
timmは正式名称Pytorch Image ModelsでSOTA(State of the Art)を達成した画像分類モデルを公開しているライブラリです.
timmの使用方法については,こちらを参照してください.
ViTで特定のTransformer Encoderを学習する方法はPartial-k(kは学習する層の数を指します.)と呼ばれ,"Masked Autoencoders Are Scalable Vision Learners
"と"Visual Prompt Tuning"において,省メモリで学習しつつ転移学習に比べて精度が向上することが示されています.
以下の図のように,Transformer Encoderの最終層を学習する場合はPartial-1と呼ばれます.
下記はViTのPartial-1で学習するコードです.
import timm # timmをインポート
from timm.models import create_model # timmで公開されているモデルを読み込む関数
model = create_model("vit_tiny_patch16_224", pretrained=True, num_classes=len(class_names))
model.to('cuda')
# 全パラメータを固定
for param in model.parameters():
param.requires_grad = False
# 最後のTransformer Encoderブロックを学習可能にする
for param in model.blocks[-1].parameters():
param.requires_grad = True
# 分類層(head)も学習可能にする
for param in model.head.parameters():
param.requires_grad = True
optimizer = torch.optim.AdamW(lambda p: p.requires_grad, model.parameters(), lr=args.lr, weight_decay=args.weight_decay) # 学習対象である最後のTransformer Encoderとheadのパラメータを渡す
精度・パラメータの比較
ここでは精度・パラメータを比較します.
モデルは,ImageNet-1kで学習済みのViTをCifar-10で1 Epochだけ学習します.
比較する学習方法は,MLP Headのみを学習する転移学習(Transfer Learning),全てのパラメータを学習するFine-Tuning,Partial-1の3種類です.
パラメータ数
以下のように,転移学習(Transfer Learning)と比較すれば多いものの,Fine-Tuningと比較するとパラメータ数を1桁以上削減できています.
転移学習(Transfer Learning)
Total number of trainable parameters: 1930
Fine-Tuning
Total number of trainable parameters: 5526346
Partial-1
Total number of trainable parameters: 446794
精度
以下のように,Fine-Tuningと比較すると劣るものの,Fine-Tuningより少ないパラメータ数で転移学習(Transfer Learning)より高い精度を発揮できています.
Fine-Tuning
epoch: 1, train loss: 0.2751163761680707, train accuracy: 0.90644 test loss: 0.12482822295042532, test accuracy: 0.9596
転移学習(Transfer Learning)
epoch: 1, train loss: 1.616312658175444, train accuracy: 0.45186 test loss: 1.1747484222243103, test accuracy: 0.6182
Partial-1
epoch: 1, train loss: 1.0304751597917996, train accuracy: 0.64032 test loss: 0.7293716035311735, test accuracy: 0.7493
まとめ
今回はViTの特定のTransformer EncoderのみをFine-tuningする方法であるPartial-kについて説明しました.
高性能な計算機器が手元にない場合や学習を高速化したい場合などは使用するといいかもしれません.
また,Contrastive Language-Image Pre-training(CLIP)やMasked Autoencoder(MAE)などの大規模なデータ学習したモデルを省メモリで学習する場合にも使用するといいかもしれないです.
ここで実際にtimmを使用してCIFAR-10 / 100を分類するコードを公開しています.
参考文献
・An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale
・Visual Prompt Tuning
・Masked Autoencoders Are Scalable Vision Learners
・Learning Transferable Visual Models From Natural Language Supervision