Vision Transformer(ViT)はCNNより高精度な画像認識が可能なことから注目されています.しかし,ViTは一から学習するためには膨大なデータが必要であるため個人の環境で学習するのは難しいです.そのため,Fine-tuningを用いることが一般的です.
そこで,本記事ではViTをFine-tuningする方法を書いていきます.
今回はtimmというライブラリを使用したViTのFine-tuningの方法と各引数の簡単に説明します.
timmの学習済みモデルを使用
ViTの学習済みモデルを使用するためにはtimmを使用すると簡単にできます.
timmは正式名称Pytorch Image ModelsでSOTA(State of the Art)を達成した画像分類モデルを公開しているライブラリです.
下記はtimmでViTをFine-tuningするコードです.
下記コードでモデルを定義したら後はいつも通り学習するだけです.
import timm # timmをインポート
from timm.models import create_model # timmで公開されているモデルを読み込む関数
# 学習済みのViTを読み込む
# num_classesはデータセットのクラスに合わせて変更
model = create_model("vit_tiny_patch16_224", pretrained=True, num_classes=10)
model.to('cuda')
各引数について
ここでは各引数について説明します.
モデル名
vit_tiny_patch16_224
はモデル名を指します.
今回はvit_tiny_patch16_224
を使用していますが使用できるモデルは他にも複数存在します.
実際に使用できるモデルはここの1792行目以降に定義されています.
tiny
はモデルサイズを指し,tiny
,small
,base
,large
などが存在し,モデルの深さ(Transformer Encoderの数)や埋め込み次元数などが異なります.
patch16
はパッチサイズを指し,patch14
,patch16
,patch32
などが存在します.
224
は画像サイズを指し,224
や384
などが存在します.
学習済みの重み
pretrained
は学習済みの重みを使用するかを選択します.
pretrained=True
にすると学習済みのモデルが読み込まれます.
クラス数
num_classes
は分類器(上の図で言うMLP Head)の出力次元数を指します.使用するデータセットのクラスに合わせて変更します.10クラスであればnum_classes=10
になります.
num_classes
は指定しなければImageNet-1kと同じ1000クラス分類用の分類器が読み込まれます.
num_classes=0
とすると分類器を使用せず特徴量を抽出することが可能になります.
まとめ
今回はtimmというライブラリを使用したViTのFine-tuningの方法について説明しました.
timmを使用することで,簡単に学習済みのViTを使用することが可能です.
ここで実際にtimmを使用してViTでCIFAR-10 / 100を分類するコードを公開しています.
参考文献
ViTの論文:An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale
論文LINK:https://arxiv.org/abs/2010.11929
timmのリポジトリ:https://github.com/huggingface/pytorch-image-models/tree/main