LoginSignup
2
3

timmを利用したVision Transformerの学習

Last updated at Posted at 2024-03-15

Vision Transformer(ViT)はCNNより高精度な画像認識が可能なことから注目されています.しかし,ViTは一から学習するためには膨大なデータが必要であるため個人の環境で学習するのは難しいです.そのため,Fine-tuningを用いることが一般的です.
そこで,本記事ではViTをFine-tuningする方法を書いていきます.
今回はtimmというライブラリを使用したViTのFine-tuningの方法と各引数の簡単に説明します.

ViTのモデル構造
ViT.png

timmの学習済みモデルを使用

ViTの学習済みモデルを使用するためにはtimmを使用すると簡単にできます.
timmは正式名称Pytorch Image ModelsでSOTA(State of the Art)を達成した画像分類モデルを公開しているライブラリです.

下記はtimmでViTをFine-tuningするコードです.
下記コードでモデルを定義したら後はいつも通り学習するだけです.

import timm # timmをインポート
from timm.models import create_model # timmで公開されているモデルを読み込む関数

# 学習済みのViTを読み込む
# num_classesはデータセットのクラスに合わせて変更
model = create_model("vit_tiny_patch16_224", pretrained=True, num_classes=10) 
model.to('cuda')

各引数について

ここでは各引数について説明します.

モデル名

vit_tiny_patch16_224はモデル名を指します.
今回はvit_tiny_patch16_224を使用していますが使用できるモデルは他にも複数存在します.
実際に使用できるモデルはここの1792行目以降に定義されています.

tinyはモデルサイズを指し,tinysmallbaselargeなどが存在し,モデルの深さ(Transformer Encoderの数)や埋め込み次元数などが異なります.
patch16はパッチサイズを指し,patch14patch16patch32などが存在します.
224は画像サイズを指し,224384などが存在します.

学習済みの重み

pretrainedは学習済みの重みを使用するかを選択します.
pretrained=Trueにすると学習済みのモデルが読み込まれます.

クラス数

num_classesは分類器(上の図で言うMLP Head)の出力次元数を指します.使用するデータセットのクラスに合わせて変更します.10クラスであればnum_classes=10になります.
num_classesは指定しなければImageNet-1kと同じ1000クラス分類用の分類器が読み込まれます.
num_classes=0とすると分類器を使用せず特徴量を抽出することが可能になります.

まとめ

今回はtimmというライブラリを使用したViTのFine-tuningの方法について説明しました.
timmを使用することで,簡単に学習済みのViTを使用することが可能です.
ここで実際にtimmを使用してViTでCIFAR-10 / 100を分類するコードを公開しています.

参考文献

ViTの論文:An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale
論文LINK:https://arxiv.org/abs/2010.11929
timmのリポジトリ:https://github.com/huggingface/pytorch-image-models/tree/main

2
3
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
3