31
35

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

huggingface / transformersを使ってVision Transformer(ViT)で画像分類タスクをファインチューニングで解いてみた

Last updated at Posted at 2021-07-29

はじめに

huggingface / transformersを使えば簡単に画像分類系で(今のところ)最先端なVision Transformer(以降ViTと略します)が使えるようなので、手元に用意したデータセットに対してファインチューニングして画像分類タスクを解いてみました。
本記事はあくまでtransformersのライブラリを使ってViTを動かすことが目的なので、ViTの細かな理論的な話には触れませんが、

  • 入力画像をパッチに分割してシーケンスとして扱う(ベースはTransformerなので)
  • 自然言語処理と同様に先頭にCLSトークンを差し込んでいる

くらいは知っておいたほうが良いです。

理論的な説明は以下が大変参考になります。ViT知らないって方は一読されることをおすすめします。

huggingfaceのViTに関するリファレンスは以下になります。

本記事はGoogle Colab Proで動かしています。(多分私の処理の仕方が下手くそなのか、通常のcolabだと本記事で紹介する実装はDataSet作成のところでメモリオーバーすると思います...)

ViTはtransformersライブラリの中に含まれており、transformersは以下のようにpipで簡単にインストールできます。本記事を投稿した時点ではtransformersのバージョンは4.9.1でした。

!pip install transformers
!pip list | grep transformers
# transformers                  4.9.1

ViTの基本的な使い方を確認する

まずはどんな事前学習済モデルが使えるのかって話ですが、huggingfaceのリポジトリにたくさんの事前学習済のViTモデルが公開されています。素晴らしい。

ネーミングルールもある程度統一されているようで、例えば
google/vit-base-patch16-224-in21kは入力の画像サイズが224x224、パッチサイズは16x16であることを意味します。-in21kはImageNet-21kで事前学習したことを意味するようです。
他にもgoogle/vit-base-patch16-224というのもあって、こちらは上のモデルに対して、ImageNet2012のデータ(100万枚の画像データで1000クラス)でファインチューニングしたモデルのようです。各モデルのページを見ればどんなデータで学習されたモデルなのかを確認することができますし、簡単な呼び出し方も書いてくれています。

今回は自分で手元に用意したデータをファインチューニングしたいので、google/vit-base-patch16-224-in21kのモデルを使おうと思います。(BERTでいうところのbert-baseに相当すると思ってていいのかな?)

まずはサンプルの画像をViTで順伝播させてみようと思います。
以下のような画像を用意します。

from PIL import Image
import requests
import matplotlib.pyplot as plt

url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
image = Image.open(requests.get(url, stream=True).raw)
image

image.png

huggingfaceが提供しているTransformerベースの事前学習済モデルは、基本的にその事前学習済モデルで使われた前処理用のクラスもセットで配布してくれています。BERTでいうところのBertTokenizerBertModelみたいな感じですかね。通常の使い方であれば、モデルをロードするとき、この前処理用のクラスも一緒にロードします。これはViTも同じで、以下のようにViTFeatureExtractorViTModelをロードすることができます。

from transformers import ViTFeatureExtractor, ViTModel
# ファインチューニングされたモデルをロードして使う場合はViTForImageClassificationですぐに分類問題に適用できるようですが、
# 今回はファインチューニングの実装のところからも行いたいので、こちらは使いません。
# from transformers import ViTForImageClassification

# 前処理用クラス
feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224-in21k')

# モデル本体
# 順伝播時の出力にAttentionの結果もほしいときはoutput_attentions=Trueを指定する。
vit_model = ViTModel.from_pretrained('google/vit-base-patch16-224-in21k', output_attentions=True)

ViTFeatureExtractor

feature_extractorはViTの事前学習で行われた画像に対する前処理を施す機能を提供します。google/vit-base-patch16-224-in21kのモデルの前処理はモデルのページを見ると以下のように記載されています。

Images are resized/rescaled to the same resolution (224x224) and normalized across the RGB channels with mean (0.5, 0.5, 0.5) and standard deviation (0.5, 0.5, 0.5).

上の2匹の猫画像をfeature_extractorに通してみましょう。

# PyTorchのテンソルで返したいときはreturn_tensor='pt'を指定します。
# imagesはPILで読み込んだimageオブジェクトでもnumpyやtensorに変換した画像でもOK
# 配列で指定すると、バッチとしてまとめて処理してくれます。
input_ids = feature_extractor(images=image, return_tensors="pt")
print(input_ids)
# {'pixel_values': tensor([[[[ 0.1137,  0.1686,  0.1843,  ..., -0.1922, -0.1843, -0.1843],
#          [ 0.1373,  0.1686,  0.1843,  ..., -0.1922, -0.1922, -0.2078],
#          [ 0.1137,  0.1529,  0.1608,  ..., -0.2314, -0.2235, -0.2157],
#          〜省略〜
#          [ 0.5686,  0.5529,  0.4510,  ...,  0.4431,  0.3882,  0.3255],
#          [ 0.5451,  0.4902,  0.5137,  ...,  0.3020,  0.2078,  0.1294],
#          [ 0.5686,  0.5608,  0.5137,  ..., -0.2000, -0.4275, -0.5294]]]])}

print(input_ids['pixel_values'].size())
# torch.Size([1, 3, 224, 224])
# バッチサイズ , チャネルサイズ , 高さ , 幅

# 上のpixel_valuesの値は以下の処理を施した結果と同様です。
# from torchvision import transforms
# transform = transforms.Compose([
#     transforms.Resize((224, 224)),
#     transforms.ToTensor(),
#     transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
# ])
# input_ids = transform(image)

feature_extractorの戻り値は辞書型でpixel_valuesに前処理後のデータが格納されています。どんな画像データに変換されたか確認してみましょう。

plt.imshow(input_ids['pixel_values'].squeeze(0).numpy().transpose(1,2,0))
plt.show()

image.png

ViTModel

feature_extractorで前処理されたデータをそのままViTModelでロードしたモデル本体にそのままぶち込むことができます。
がその前に一度モデルをプリントして中身を見てみましょう。Conv2dを使って16枚のパッチを作成していたり、Transformerブロックが12層あったり、といったことが確認できます。

print(vit_model)
出力は長いので閉じておきます
ViTModel(
  (embeddings): ViTEmbeddings(
    (patch_embeddings): PatchEmbeddings(
      (projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
    )
    (dropout): Dropout(p=0.0, inplace=False)
  )
  (encoder): ViTEncoder(
    (layer): ModuleList(
      (0): ViTLayer(
        (attention): ViTAttention(
          (attention): ViTSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.0, inplace=False)
          )
          (output): ViTSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.0, inplace=False)
          )
        )
        (intermediate): ViTIntermediate(
          (dense): Linear(in_features=768, out_features=3072, bias=True)
        )
        (output): ViTOutput(
          (dense): Linear(in_features=3072, out_features=768, bias=True)
          (dropout): Dropout(p=0.0, inplace=False)
        )
        (layernorm_before): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
        (layernorm_after): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      )
      (1): ViTLayer(
        (attention): ViTAttention(
          (attention): ViTSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.0, inplace=False)
          )
          (output): ViTSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.0, inplace=False)
          )
        )
        (intermediate): ViTIntermediate(
          (dense): Linear(in_features=768, out_features=3072, bias=True)
        )
        (output): ViTOutput(
          (dense): Linear(in_features=3072, out_features=768, bias=True)
          (dropout): Dropout(p=0.0, inplace=False)
        )
        (layernorm_before): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
        (layernorm_after): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      )
      (2): ViTLayer(
        (attention): ViTAttention(
          (attention): ViTSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.0, inplace=False)
          )
          (output): ViTSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.0, inplace=False)
          )
        )
        (intermediate): ViTIntermediate(
          (dense): Linear(in_features=768, out_features=3072, bias=True)
        )
        (output): ViTOutput(
          (dense): Linear(in_features=3072, out_features=768, bias=True)
          (dropout): Dropout(p=0.0, inplace=False)
        )
        (layernorm_before): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
        (layernorm_after): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      )
      (3): ViTLayer(
        (attention): ViTAttention(
          (attention): ViTSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.0, inplace=False)
          )
          (output): ViTSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.0, inplace=False)
          )
        )
        (intermediate): ViTIntermediate(
          (dense): Linear(in_features=768, out_features=3072, bias=True)
        )
        (output): ViTOutput(
          (dense): Linear(in_features=3072, out_features=768, bias=True)
          (dropout): Dropout(p=0.0, inplace=False)
        )
        (layernorm_before): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
        (layernorm_after): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      )
      (4): ViTLayer(
        (attention): ViTAttention(
          (attention): ViTSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.0, inplace=False)
          )
          (output): ViTSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.0, inplace=False)
          )
        )
        (intermediate): ViTIntermediate(
          (dense): Linear(in_features=768, out_features=3072, bias=True)
        )
        (output): ViTOutput(
          (dense): Linear(in_features=3072, out_features=768, bias=True)
          (dropout): Dropout(p=0.0, inplace=False)
        )
        (layernorm_before): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
        (layernorm_after): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      )
      (5): ViTLayer(
        (attention): ViTAttention(
          (attention): ViTSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.0, inplace=False)
          )
          (output): ViTSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.0, inplace=False)
          )
        )
        (intermediate): ViTIntermediate(
          (dense): Linear(in_features=768, out_features=3072, bias=True)
        )
        (output): ViTOutput(
          (dense): Linear(in_features=3072, out_features=768, bias=True)
          (dropout): Dropout(p=0.0, inplace=False)
        )
        (layernorm_before): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
        (layernorm_after): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      )
      (6): ViTLayer(
        (attention): ViTAttention(
          (attention): ViTSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.0, inplace=False)
          )
          (output): ViTSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.0, inplace=False)
          )
        )
        (intermediate): ViTIntermediate(
          (dense): Linear(in_features=768, out_features=3072, bias=True)
        )
        (output): ViTOutput(
          (dense): Linear(in_features=3072, out_features=768, bias=True)
          (dropout): Dropout(p=0.0, inplace=False)
        )
        (layernorm_before): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
        (layernorm_after): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      )
      (7): ViTLayer(
        (attention): ViTAttention(
          (attention): ViTSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.0, inplace=False)
          )
          (output): ViTSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.0, inplace=False)
          )
        )
        (intermediate): ViTIntermediate(
          (dense): Linear(in_features=768, out_features=3072, bias=True)
        )
        (output): ViTOutput(
          (dense): Linear(in_features=3072, out_features=768, bias=True)
          (dropout): Dropout(p=0.0, inplace=False)
        )
        (layernorm_before): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
        (layernorm_after): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      )
      (8): ViTLayer(
        (attention): ViTAttention(
          (attention): ViTSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.0, inplace=False)
          )
          (output): ViTSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.0, inplace=False)
          )
        )
        (intermediate): ViTIntermediate(
          (dense): Linear(in_features=768, out_features=3072, bias=True)
        )
        (output): ViTOutput(
          (dense): Linear(in_features=3072, out_features=768, bias=True)
          (dropout): Dropout(p=0.0, inplace=False)
        )
        (layernorm_before): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
        (layernorm_after): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      )
      (9): ViTLayer(
        (attention): ViTAttention(
          (attention): ViTSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.0, inplace=False)
          )
          (output): ViTSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.0, inplace=False)
          )
        )
        (intermediate): ViTIntermediate(
          (dense): Linear(in_features=768, out_features=3072, bias=True)
        )
        (output): ViTOutput(
          (dense): Linear(in_features=3072, out_features=768, bias=True)
          (dropout): Dropout(p=0.0, inplace=False)
        )
        (layernorm_before): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
        (layernorm_after): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      )
      (10): ViTLayer(
        (attention): ViTAttention(
          (attention): ViTSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.0, inplace=False)
          )
          (output): ViTSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.0, inplace=False)
          )
        )
        (intermediate): ViTIntermediate(
          (dense): Linear(in_features=768, out_features=3072, bias=True)
        )
        (output): ViTOutput(
          (dense): Linear(in_features=3072, out_features=768, bias=True)
          (dropout): Dropout(p=0.0, inplace=False)
        )
        (layernorm_before): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
        (layernorm_after): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      )
      (11): ViTLayer(
        (attention): ViTAttention(
          (attention): ViTSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.0, inplace=False)
          )
          (output): ViTSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.0, inplace=False)
          )
        )
        (intermediate): ViTIntermediate(
          (dense): Linear(in_features=768, out_features=3072, bias=True)
        )
        (output): ViTOutput(
          (dense): Linear(in_features=3072, out_features=768, bias=True)
          (dropout): Dropout(p=0.0, inplace=False)
        )
        (layernorm_before): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
        (layernorm_after): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      )
    )
  )
  (layernorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
  (pooler): ViTPooler(
    (dense): Linear(in_features=768, out_features=768, bias=True)
    (activation): Tanh()
  )
)

それではモデル本体にデータを順伝播させてみます。

out = vit_model(input_ids['pixel_values'])

print(out.keys())
# odict_keys(['last_hidden_state', 'pooler_output', 'attentions'])

print(out['last_hidden_state'].size())
# torch.Size([1, 197, 768])

こちらも辞書形式で出力されます。デフォルトではlast_hidden_statepooler_outputが返ってきますが、今回はモデルのロード時にoutput_attentions=Trueを指定していたので、attentionsも取得できています。
基本的にファインチューニングのときに使うのはlast_hidden_state(Transformerブロックの最終層の出力結果)だと思いますが、このテンソルのサイズ[1, 197, 768]の意味だけ簡単に補足しておきます。

これは(batch_size, sequence_length, hidden_size)を表しています。

  • batch_sizeは1枚の画像だけ使用していたので1が格納されています。
  • hidden_sizeはモデルの中身を見れば分かるようにモデルの隠れ層の次元数が768だからです。
  • 注意が必要なのはsequence_lengthですかね。197って何だ?って感じですが、現在使っているViTのモデルの入力画像サイズは224x224でパッチサイズが16x16なので、パッチの数は$(224÷16)^2=196$枚あります。あとは先頭に挿入されているCLSトークンを考えればシーケンスの長さは$1+196=197$ということになります。

ファインチューニング時に一番欲しいのはTransformerブロックの最終層のCLSトークンなのかなーと考えると、以下のようにしてCLSトークンのベクトルを取得できます。これが画像全体を表す特徴ベクトルとみなすことができます。

# CLSトークンはパッチの列の先頭に挿入されている
cls_vec = out['last_hidden_state'][:, 0, :]

print(cls_vec.size())
# torch.Size([1, 768])

print(cls_vec)
# tensor([[ 1.5587e-01,  9.1426e-02,  1.5177e-01, -2.0727e-02, -5.3335e-02,
#          -1.5654e-01,  6.0954e-02, -9.6690e-02,  4.3038e-02, -1.6288e-01,
#          〜省略〜
#           9.1331e-02, -3.2680e-01,  5.5741e-02,  1.7970e-01, -7.1189e-02,
#          -3.1798e-01, -8.5894e-02, -9.0293e-02]], grad_fn=<SliceBackward>)

ファインチューニングして画像分類タスクを解かせる

上までで基本的なViTの使い方がわかったところで、この事前学習済のViTモデルを個別のタスクを解けるようにファインチューニングしてみようと思います。

そこで、ファインチューニングに使うデータセットを以下からダウンロードしました。120カテゴリの犬を分類するタスクです。

DataLoaderの準備

ダウンロードしたデータセットをGoogle Driveに格納しておいて、以下のようにして学習用のDataLoaderと検証用のDataLoaderを作成しました。

# colabにgoogle driveをマウント
from google.colab import drive
drive.mount('/content/drive')

# データセットの格納先
drive_dir = "drive/MyDrive/ColabNotebooks/dog_datasets/"

from glob import glob
import pandas as pd

filename_list = glob(drive_dir + 'images/Images/*/*.jpg')
tmp = []
for filename in filename_list:
    category = filename.split("/")[-2].split("-")[1]
    tmp.append([filename, category])

# 1レコードがファイルパスとカテゴリー(正解ラベル)になるようにDataFrameにまとめる
dog_df = pd.DataFrame(tmp, columns=['path', 'category'])

# カテゴリーをID(数値)に変換した列を追加する
categories = dog_df['category'].unique().tolist()
dog_df['category_id'] = dog_df['category'].map(lambda x: categories.index(x))

Dataset作成時に画像データ1件1件をfeature_extractorでtensor型に変換しており、めちゃくちゃ時間かかります。

import numpy as np
from sklearn.model_selection import train_test_split
import torch
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm


class DogData(Dataset):
    def __init__(self, df):
        self.images = []
        self.categories = []

        for row in tqdm(df.itertuples(), total=df.shape[0]):
            path = row.path
            category = row.category_id
            image = Image.open(path)
            # channelサイズが4のデータが紛れ込んでおり、feature_extractorでエラーになってしまうため、
            # feature_extractorでエラーになるデータは除外する
            try:
                feature_ids = feature_extractor(image, return_tensors='pt')['pixel_values'].squeeze(0)
                self.images.append(feature_ids)
            except:
                pass
            self.categories.append(category)

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        return self.images[idx], self.categories[idx]

# 学習と検証を8:2に分ける
train_df, val_df = train_test_split(dog_df, train_size=0.8)
print(train_df.shape, val_df.shape)
# (16464, 3) (4116, 3)

# 1件1件処理しちゃっているので、これがめっちゃ時間かかる...
# もっとよい方法ありましたらご教示いただけると幸いです...
train_data = DogData(train_df)
val_data = DogData(val_df)

# DataLoaderを取得する
train_loader = DataLoader(train_data, batch_size=64, shuffle=True)
val_loader = DataLoader(val_data, batch_size=64, shuffle=False)

モデル定義

ファインチューニング用のモデルはシンプルに最終層のCLSトークンのベクトルを全結合に一度通してクラス分類できるように変換するだけ、としました。

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

class ViTNet(nn.Module):
    def __init__(self, pretrained_vit_model, class_num):
        super(ViTNet, self).__init__()
        self.vit = pretrained_vit_model
        self.fc = nn.Linear(768, class_num)
    
    def _get_cls_vec(self, states):
        return states['last_hidden_state'][:, 0, :]
    
    def forward(self, input_ids):
        states = self.vit(input_ids)
        states = self._get_cls_vec(states)
        states = self.fc(states)
        return states

# 今回のデータは120カテゴリ
CLASS_NUM = len(categories)

# 上のほうでViTModel.from_pretrained()でロードした事前学習済モデルを引数で渡すようにしています。
net = ViTNet(vit_model, CLASS_NUM)
# GPU使う
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
net.to(device)

criterion = nn.CrossEntropyLoss()

ファインチューニングの設定

以下の記事のBERTモデルのファインチューニングの方法を参考にしながら、ViTではクラス分類用に追加した層だけパラメータ更新をONにするようにしました。

# まず全パラメータを勾配計算Falseにする
for param in net.parameters():
    param.requires_grad = False

# 追加したクラス分類用の全結合層を勾配計算ありに変更
for param in net.fc.parameters():
    param.requires_grad = True

optimizer = optim.Adam([
    {'params': net.fc.parameters(), 'lr': 1e-4}
])

# 損失関数
criterion = nn.CrossEntropyLoss()

学習&検証

とりあえず10エポック回しました。

from sklearn.metrics import f1_score

train_losses = []
val_losses = []
train_fscores = []
val_fscores = []

for epoch in range(10):

    # 学習
    train_loss = 0.0
    train_predict = []
    train_answer = []
    net.train()
    for batch in train_loader:
        optimizer.zero_grad()

        input_ids = batch[0].to(device)
        y = batch[1].to(device)
        out = net(input_ids)
        loss = criterion(out, y)
        loss.backward()
        optimizer.step()

        train_predict += out.argmax(dim=1).cpu().detach().numpy().tolist()
        train_answer += y.cpu().detach().numpy().tolist()
        train_loss += loss.item()

    # エポックごとの損失の合計とF1-scoreを計算する
    train_losses.append(train_loss)
    train_fscore = f1_score(train_answer, train_predict, average='macro')
    train_fscores.append(train_fscore)

    # 検証
    val_loss = 0.0
    val_predict = []
    val_answer = []
    net.eval()
    for batch in val_loader:
        with torch.no_grad():
            
            input_ids = batch[0].to(device)
            y = batch[1].to(device)
            out = net(input_ids)
            loss = criterion(out, y)

            val_loss += loss.item()
            _, y_pred = torch.max(out, 1)            
            val_predict += out.argmax(dim=1).cpu().detach().numpy().tolist()
            val_answer += y.cpu().detach().numpy().tolist()

    # エポックごとの損失の合計とF1-scoreを計算する
    val_losses.append(val_loss)
    val_fscore = f1_score(val_answer, val_predict, average='macro')
    val_fscores.append(val_fscore)

    print('epoch', epoch,
          '\ttrain loss', round(train_loss, 4), '\ttrain fscore', round(train_fscore, 4),
          '\tval loss', round(val_loss, 4), '\tval fscore', round(val_fscore, 4)
          )

エポック毎の損失とF1-scoreは以下のようになりました。1エポック目から検証データの精度がかなり良いですが、ホントか?ってちょっと思っちゃってます。(kaggleのカーネル見る限り、ResNet等でFスコア0.8くらいでているようなのですが、1エポック目からこんな数値がでると自分の実装がどこかおかしかったのか疑ってしまう...)

epoch 0 	train loss 1145.6883 	train fscore 0.504 	val loss 259.0798 	val fscore 0.8329
epoch 1 	train loss 968.949 	train fscore 0.7527 	val loss 210.4485 	val fscore 0.8674
epoch 2 	train loss 807.7338 	train fscore 0.7663 	val loss 166.8859 	val fscore 0.8793
epoch 3 	train loss 669.7677 	train fscore 0.7723 	val loss 130.5843 	val fscore 0.8864
epoch 4 	train loss 561.8881 	train fscore 0.7758 	val loss 102.7415 	val fscore 0.8907
epoch 5 	train loss 484.4582 	train fscore 0.7813 	val loss 82.7449 	val fscore 0.8953
epoch 6 	train loss 432.3182 	train fscore 0.7822 	val loss 68.8519 	val fscore 0.8958
epoch 7 	train loss 398.4146 	train fscore 0.7844 	val loss 59.2286 	val fscore 0.8983
epoch 8 	train loss 375.4528 	train fscore 0.786 	val loss 52.4938 	val fscore 0.9001
epoch 9 	train loss 360.0181 	train fscore 0.7876 	val loss 47.6859 	val fscore 0.9009

一応、学習曲線も表示してみます。学習データ、検証データともにきれいに損失も減ってFスコアの精度も上がっているようです。

import matplotlib.pyplot as plt

plt.figure(figsize=(15,5))

plt.subplot(1,2,1)
plt.plot(train_losses, '-o' label='train loss')
plt.plot(val_losses, '-^', label='val loss')
plt.title('loss')
plt.legend()
plt.grid()

plt.subplot(1,2,2)
plt.plot(train_fscores, '-o', label='train fscore')
plt.plot(val_fscores, '-^', label='val fscore')
plt.title('fscore')
plt.legend()
plt.grid()

plt.show()

image.png

おわりに

huggingface /transformers におけるViTの基本的な使い方はわかったような気がしますが、画像系の処理全般に不慣れでセオリーな実装方法じゃないかもしれません。なにか変なことしてたらぜひご指摘ください。

おわり

31
35
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
31
35

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?