はじめに
huggingface / transformersを使えば簡単に画像分類系で(今のところ)最先端なVision Transformer(以降ViTと略します)が使えるようなので、手元に用意したデータセットに対してファインチューニングして画像分類タスクを解いてみました。
本記事はあくまでtransformersのライブラリを使ってViTを動かすことが目的なので、ViTの細かな理論的な話には触れませんが、
- 入力画像をパッチに分割してシーケンスとして扱う(ベースはTransformerなので)
- 自然言語処理と同様に先頭にCLSトークンを差し込んでいる
くらいは知っておいたほうが良いです。
理論的な説明は以下が大変参考になります。ViT知らないって方は一読されることをおすすめします。
- https://qiita.com/omiita/items/0049ade809c4817670d7
- https://deepsquare.jp/2020/10/vision-transformer/
huggingfaceのViTに関するリファレンスは以下になります。
本記事はGoogle Colab Proで動かしています。(多分私の処理の仕方が下手くそなのか、通常のcolabだと本記事で紹介する実装はDataSet作成のところでメモリオーバーすると思います...)
ViTはtransformersライブラリの中に含まれており、transformersは以下のようにpipで簡単にインストールできます。本記事を投稿した時点ではtransformersのバージョンは4.9.1
でした。
!pip install transformers
!pip list | grep transformers
# transformers 4.9.1
ViTの基本的な使い方を確認する
まずはどんな事前学習済モデルが使えるのかって話ですが、huggingfaceのリポジトリにたくさんの事前学習済のViTモデルが公開されています。素晴らしい。
ネーミングルールもある程度統一されているようで、例えば
google/vit-base-patch16-224-in21k
は入力の画像サイズが224x224、パッチサイズは16x16であることを意味します。-in21k
はImageNet-21kで事前学習したことを意味するようです。
他にもgoogle/vit-base-patch16-224
というのもあって、こちらは上のモデルに対して、ImageNet2012のデータ(100万枚の画像データで1000クラス)でファインチューニングしたモデルのようです。各モデルのページを見ればどんなデータで学習されたモデルなのかを確認することができますし、簡単な呼び出し方も書いてくれています。
今回は自分で手元に用意したデータをファインチューニングしたいので、google/vit-base-patch16-224-in21k
のモデルを使おうと思います。(BERTでいうところのbert-base
に相当すると思ってていいのかな?)
まずはサンプルの画像をViTで順伝播させてみようと思います。
以下のような画像を用意します。
from PIL import Image
import requests
import matplotlib.pyplot as plt
url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
image = Image.open(requests.get(url, stream=True).raw)
image
huggingfaceが提供しているTransformerベースの事前学習済モデルは、基本的にその事前学習済モデルで使われた前処理用のクラスもセットで配布してくれています。BERTでいうところのBertTokenizer
とBertModel
みたいな感じですかね。通常の使い方であれば、モデルをロードするとき、この前処理用のクラスも一緒にロードします。これはViTも同じで、以下のようにViTFeatureExtractor
とViTModel
をロードすることができます。
from transformers import ViTFeatureExtractor, ViTModel
# ファインチューニングされたモデルをロードして使う場合はViTForImageClassificationですぐに分類問題に適用できるようですが、
# 今回はファインチューニングの実装のところからも行いたいので、こちらは使いません。
# from transformers import ViTForImageClassification
# 前処理用クラス
feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224-in21k')
# モデル本体
# 順伝播時の出力にAttentionの結果もほしいときはoutput_attentions=Trueを指定する。
vit_model = ViTModel.from_pretrained('google/vit-base-patch16-224-in21k', output_attentions=True)
ViTFeatureExtractor
feature_extractor
はViTの事前学習で行われた画像に対する前処理を施す機能を提供します。google/vit-base-patch16-224-in21k
のモデルの前処理はモデルのページを見ると以下のように記載されています。
Images are resized/rescaled to the same resolution (224x224) and normalized across the RGB channels with mean (0.5, 0.5, 0.5) and standard deviation (0.5, 0.5, 0.5).
上の2匹の猫画像をfeature_extractor
に通してみましょう。
# PyTorchのテンソルで返したいときはreturn_tensor='pt'を指定します。
# imagesはPILで読み込んだimageオブジェクトでもnumpyやtensorに変換した画像でもOK
# 配列で指定すると、バッチとしてまとめて処理してくれます。
input_ids = feature_extractor(images=image, return_tensors="pt")
print(input_ids)
# {'pixel_values': tensor([[[[ 0.1137, 0.1686, 0.1843, ..., -0.1922, -0.1843, -0.1843],
# [ 0.1373, 0.1686, 0.1843, ..., -0.1922, -0.1922, -0.2078],
# [ 0.1137, 0.1529, 0.1608, ..., -0.2314, -0.2235, -0.2157],
# 〜省略〜
# [ 0.5686, 0.5529, 0.4510, ..., 0.4431, 0.3882, 0.3255],
# [ 0.5451, 0.4902, 0.5137, ..., 0.3020, 0.2078, 0.1294],
# [ 0.5686, 0.5608, 0.5137, ..., -0.2000, -0.4275, -0.5294]]]])}
print(input_ids['pixel_values'].size())
# torch.Size([1, 3, 224, 224])
# バッチサイズ , チャネルサイズ , 高さ , 幅
# 上のpixel_valuesの値は以下の処理を施した結果と同様です。
# from torchvision import transforms
# transform = transforms.Compose([
# transforms.Resize((224, 224)),
# transforms.ToTensor(),
# transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
# ])
# input_ids = transform(image)
feature_extractor
の戻り値は辞書型でpixel_values
に前処理後のデータが格納されています。どんな画像データに変換されたか確認してみましょう。
plt.imshow(input_ids['pixel_values'].squeeze(0).numpy().transpose(1,2,0))
plt.show()
ViTModel
feature_extractor
で前処理されたデータをそのままViTModel
でロードしたモデル本体にそのままぶち込むことができます。
がその前に一度モデルをプリントして中身を見てみましょう。Conv2d
を使って16枚のパッチを作成していたり、Transformerブロックが12層あったり、といったことが確認できます。
print(vit_model)
出力は長いので閉じておきます
ViTModel(
(embeddings): ViTEmbeddings(
(patch_embeddings): PatchEmbeddings(
(projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
)
(dropout): Dropout(p=0.0, inplace=False)
)
(encoder): ViTEncoder(
(layer): ModuleList(
(0): ViTLayer(
(attention): ViTAttention(
(attention): ViTSelfAttention(
(query): Linear(in_features=768, out_features=768, bias=True)
(key): Linear(in_features=768, out_features=768, bias=True)
(value): Linear(in_features=768, out_features=768, bias=True)
(dropout): Dropout(p=0.0, inplace=False)
)
(output): ViTSelfOutput(
(dense): Linear(in_features=768, out_features=768, bias=True)
(dropout): Dropout(p=0.0, inplace=False)
)
)
(intermediate): ViTIntermediate(
(dense): Linear(in_features=768, out_features=3072, bias=True)
)
(output): ViTOutput(
(dense): Linear(in_features=3072, out_features=768, bias=True)
(dropout): Dropout(p=0.0, inplace=False)
)
(layernorm_before): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
(layernorm_after): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
)
(1): ViTLayer(
(attention): ViTAttention(
(attention): ViTSelfAttention(
(query): Linear(in_features=768, out_features=768, bias=True)
(key): Linear(in_features=768, out_features=768, bias=True)
(value): Linear(in_features=768, out_features=768, bias=True)
(dropout): Dropout(p=0.0, inplace=False)
)
(output): ViTSelfOutput(
(dense): Linear(in_features=768, out_features=768, bias=True)
(dropout): Dropout(p=0.0, inplace=False)
)
)
(intermediate): ViTIntermediate(
(dense): Linear(in_features=768, out_features=3072, bias=True)
)
(output): ViTOutput(
(dense): Linear(in_features=3072, out_features=768, bias=True)
(dropout): Dropout(p=0.0, inplace=False)
)
(layernorm_before): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
(layernorm_after): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
)
(2): ViTLayer(
(attention): ViTAttention(
(attention): ViTSelfAttention(
(query): Linear(in_features=768, out_features=768, bias=True)
(key): Linear(in_features=768, out_features=768, bias=True)
(value): Linear(in_features=768, out_features=768, bias=True)
(dropout): Dropout(p=0.0, inplace=False)
)
(output): ViTSelfOutput(
(dense): Linear(in_features=768, out_features=768, bias=True)
(dropout): Dropout(p=0.0, inplace=False)
)
)
(intermediate): ViTIntermediate(
(dense): Linear(in_features=768, out_features=3072, bias=True)
)
(output): ViTOutput(
(dense): Linear(in_features=3072, out_features=768, bias=True)
(dropout): Dropout(p=0.0, inplace=False)
)
(layernorm_before): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
(layernorm_after): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
)
(3): ViTLayer(
(attention): ViTAttention(
(attention): ViTSelfAttention(
(query): Linear(in_features=768, out_features=768, bias=True)
(key): Linear(in_features=768, out_features=768, bias=True)
(value): Linear(in_features=768, out_features=768, bias=True)
(dropout): Dropout(p=0.0, inplace=False)
)
(output): ViTSelfOutput(
(dense): Linear(in_features=768, out_features=768, bias=True)
(dropout): Dropout(p=0.0, inplace=False)
)
)
(intermediate): ViTIntermediate(
(dense): Linear(in_features=768, out_features=3072, bias=True)
)
(output): ViTOutput(
(dense): Linear(in_features=3072, out_features=768, bias=True)
(dropout): Dropout(p=0.0, inplace=False)
)
(layernorm_before): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
(layernorm_after): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
)
(4): ViTLayer(
(attention): ViTAttention(
(attention): ViTSelfAttention(
(query): Linear(in_features=768, out_features=768, bias=True)
(key): Linear(in_features=768, out_features=768, bias=True)
(value): Linear(in_features=768, out_features=768, bias=True)
(dropout): Dropout(p=0.0, inplace=False)
)
(output): ViTSelfOutput(
(dense): Linear(in_features=768, out_features=768, bias=True)
(dropout): Dropout(p=0.0, inplace=False)
)
)
(intermediate): ViTIntermediate(
(dense): Linear(in_features=768, out_features=3072, bias=True)
)
(output): ViTOutput(
(dense): Linear(in_features=3072, out_features=768, bias=True)
(dropout): Dropout(p=0.0, inplace=False)
)
(layernorm_before): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
(layernorm_after): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
)
(5): ViTLayer(
(attention): ViTAttention(
(attention): ViTSelfAttention(
(query): Linear(in_features=768, out_features=768, bias=True)
(key): Linear(in_features=768, out_features=768, bias=True)
(value): Linear(in_features=768, out_features=768, bias=True)
(dropout): Dropout(p=0.0, inplace=False)
)
(output): ViTSelfOutput(
(dense): Linear(in_features=768, out_features=768, bias=True)
(dropout): Dropout(p=0.0, inplace=False)
)
)
(intermediate): ViTIntermediate(
(dense): Linear(in_features=768, out_features=3072, bias=True)
)
(output): ViTOutput(
(dense): Linear(in_features=3072, out_features=768, bias=True)
(dropout): Dropout(p=0.0, inplace=False)
)
(layernorm_before): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
(layernorm_after): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
)
(6): ViTLayer(
(attention): ViTAttention(
(attention): ViTSelfAttention(
(query): Linear(in_features=768, out_features=768, bias=True)
(key): Linear(in_features=768, out_features=768, bias=True)
(value): Linear(in_features=768, out_features=768, bias=True)
(dropout): Dropout(p=0.0, inplace=False)
)
(output): ViTSelfOutput(
(dense): Linear(in_features=768, out_features=768, bias=True)
(dropout): Dropout(p=0.0, inplace=False)
)
)
(intermediate): ViTIntermediate(
(dense): Linear(in_features=768, out_features=3072, bias=True)
)
(output): ViTOutput(
(dense): Linear(in_features=3072, out_features=768, bias=True)
(dropout): Dropout(p=0.0, inplace=False)
)
(layernorm_before): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
(layernorm_after): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
)
(7): ViTLayer(
(attention): ViTAttention(
(attention): ViTSelfAttention(
(query): Linear(in_features=768, out_features=768, bias=True)
(key): Linear(in_features=768, out_features=768, bias=True)
(value): Linear(in_features=768, out_features=768, bias=True)
(dropout): Dropout(p=0.0, inplace=False)
)
(output): ViTSelfOutput(
(dense): Linear(in_features=768, out_features=768, bias=True)
(dropout): Dropout(p=0.0, inplace=False)
)
)
(intermediate): ViTIntermediate(
(dense): Linear(in_features=768, out_features=3072, bias=True)
)
(output): ViTOutput(
(dense): Linear(in_features=3072, out_features=768, bias=True)
(dropout): Dropout(p=0.0, inplace=False)
)
(layernorm_before): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
(layernorm_after): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
)
(8): ViTLayer(
(attention): ViTAttention(
(attention): ViTSelfAttention(
(query): Linear(in_features=768, out_features=768, bias=True)
(key): Linear(in_features=768, out_features=768, bias=True)
(value): Linear(in_features=768, out_features=768, bias=True)
(dropout): Dropout(p=0.0, inplace=False)
)
(output): ViTSelfOutput(
(dense): Linear(in_features=768, out_features=768, bias=True)
(dropout): Dropout(p=0.0, inplace=False)
)
)
(intermediate): ViTIntermediate(
(dense): Linear(in_features=768, out_features=3072, bias=True)
)
(output): ViTOutput(
(dense): Linear(in_features=3072, out_features=768, bias=True)
(dropout): Dropout(p=0.0, inplace=False)
)
(layernorm_before): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
(layernorm_after): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
)
(9): ViTLayer(
(attention): ViTAttention(
(attention): ViTSelfAttention(
(query): Linear(in_features=768, out_features=768, bias=True)
(key): Linear(in_features=768, out_features=768, bias=True)
(value): Linear(in_features=768, out_features=768, bias=True)
(dropout): Dropout(p=0.0, inplace=False)
)
(output): ViTSelfOutput(
(dense): Linear(in_features=768, out_features=768, bias=True)
(dropout): Dropout(p=0.0, inplace=False)
)
)
(intermediate): ViTIntermediate(
(dense): Linear(in_features=768, out_features=3072, bias=True)
)
(output): ViTOutput(
(dense): Linear(in_features=3072, out_features=768, bias=True)
(dropout): Dropout(p=0.0, inplace=False)
)
(layernorm_before): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
(layernorm_after): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
)
(10): ViTLayer(
(attention): ViTAttention(
(attention): ViTSelfAttention(
(query): Linear(in_features=768, out_features=768, bias=True)
(key): Linear(in_features=768, out_features=768, bias=True)
(value): Linear(in_features=768, out_features=768, bias=True)
(dropout): Dropout(p=0.0, inplace=False)
)
(output): ViTSelfOutput(
(dense): Linear(in_features=768, out_features=768, bias=True)
(dropout): Dropout(p=0.0, inplace=False)
)
)
(intermediate): ViTIntermediate(
(dense): Linear(in_features=768, out_features=3072, bias=True)
)
(output): ViTOutput(
(dense): Linear(in_features=3072, out_features=768, bias=True)
(dropout): Dropout(p=0.0, inplace=False)
)
(layernorm_before): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
(layernorm_after): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
)
(11): ViTLayer(
(attention): ViTAttention(
(attention): ViTSelfAttention(
(query): Linear(in_features=768, out_features=768, bias=True)
(key): Linear(in_features=768, out_features=768, bias=True)
(value): Linear(in_features=768, out_features=768, bias=True)
(dropout): Dropout(p=0.0, inplace=False)
)
(output): ViTSelfOutput(
(dense): Linear(in_features=768, out_features=768, bias=True)
(dropout): Dropout(p=0.0, inplace=False)
)
)
(intermediate): ViTIntermediate(
(dense): Linear(in_features=768, out_features=3072, bias=True)
)
(output): ViTOutput(
(dense): Linear(in_features=3072, out_features=768, bias=True)
(dropout): Dropout(p=0.0, inplace=False)
)
(layernorm_before): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
(layernorm_after): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
)
)
)
(layernorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
(pooler): ViTPooler(
(dense): Linear(in_features=768, out_features=768, bias=True)
(activation): Tanh()
)
)
それではモデル本体にデータを順伝播させてみます。
out = vit_model(input_ids['pixel_values'])
print(out.keys())
# odict_keys(['last_hidden_state', 'pooler_output', 'attentions'])
print(out['last_hidden_state'].size())
# torch.Size([1, 197, 768])
こちらも辞書形式で出力されます。デフォルトではlast_hidden_state
とpooler_output
が返ってきますが、今回はモデルのロード時にoutput_attentions=True
を指定していたので、attentions
も取得できています。
基本的にファインチューニングのときに使うのはlast_hidden_state
(Transformerブロックの最終層の出力結果)だと思いますが、このテンソルのサイズ[1, 197, 768]
の意味だけ簡単に補足しておきます。
これは(batch_size, sequence_length, hidden_size)
を表しています。
-
batch_size
は1枚の画像だけ使用していたので1が格納されています。 -
hidden_size
はモデルの中身を見れば分かるようにモデルの隠れ層の次元数が768だからです。 - 注意が必要なのは
sequence_length
ですかね。197って何だ?って感じですが、現在使っているViTのモデルの入力画像サイズは224x224でパッチサイズが16x16なので、パッチの数は$(224÷16)^2=196$枚あります。あとは先頭に挿入されているCLSトークンを考えればシーケンスの長さは$1+196=197$ということになります。
ファインチューニング時に一番欲しいのはTransformerブロックの最終層のCLSトークンなのかなーと考えると、以下のようにしてCLSトークンのベクトルを取得できます。これが画像全体を表す特徴ベクトルとみなすことができます。
# CLSトークンはパッチの列の先頭に挿入されている
cls_vec = out['last_hidden_state'][:, 0, :]
print(cls_vec.size())
# torch.Size([1, 768])
print(cls_vec)
# tensor([[ 1.5587e-01, 9.1426e-02, 1.5177e-01, -2.0727e-02, -5.3335e-02,
# -1.5654e-01, 6.0954e-02, -9.6690e-02, 4.3038e-02, -1.6288e-01,
# 〜省略〜
# 9.1331e-02, -3.2680e-01, 5.5741e-02, 1.7970e-01, -7.1189e-02,
# -3.1798e-01, -8.5894e-02, -9.0293e-02]], grad_fn=<SliceBackward>)
ファインチューニングして画像分類タスクを解かせる
上までで基本的なViTの使い方がわかったところで、この事前学習済のViTモデルを個別のタスクを解けるようにファインチューニングしてみようと思います。
そこで、ファインチューニングに使うデータセットを以下からダウンロードしました。120カテゴリの犬を分類するタスクです。
DataLoaderの準備
ダウンロードしたデータセットをGoogle Driveに格納しておいて、以下のようにして学習用のDataLoaderと検証用のDataLoaderを作成しました。
# colabにgoogle driveをマウント
from google.colab import drive
drive.mount('/content/drive')
# データセットの格納先
drive_dir = "drive/MyDrive/ColabNotebooks/dog_datasets/"
from glob import glob
import pandas as pd
filename_list = glob(drive_dir + 'images/Images/*/*.jpg')
tmp = []
for filename in filename_list:
category = filename.split("/")[-2].split("-")[1]
tmp.append([filename, category])
# 1レコードがファイルパスとカテゴリー(正解ラベル)になるようにDataFrameにまとめる
dog_df = pd.DataFrame(tmp, columns=['path', 'category'])
# カテゴリーをID(数値)に変換した列を追加する
categories = dog_df['category'].unique().tolist()
dog_df['category_id'] = dog_df['category'].map(lambda x: categories.index(x))
Dataset作成時に画像データ1件1件をfeature_extractor
でtensor型に変換しており、めちゃくちゃ時間かかります。
import numpy as np
from sklearn.model_selection import train_test_split
import torch
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
class DogData(Dataset):
def __init__(self, df):
self.images = []
self.categories = []
for row in tqdm(df.itertuples(), total=df.shape[0]):
path = row.path
category = row.category_id
image = Image.open(path)
# channelサイズが4のデータが紛れ込んでおり、feature_extractorでエラーになってしまうため、
# feature_extractorでエラーになるデータは除外する
try:
feature_ids = feature_extractor(image, return_tensors='pt')['pixel_values'].squeeze(0)
self.images.append(feature_ids)
except:
pass
self.categories.append(category)
def __len__(self):
return len(self.images)
def __getitem__(self, idx):
return self.images[idx], self.categories[idx]
# 学習と検証を8:2に分ける
train_df, val_df = train_test_split(dog_df, train_size=0.8)
print(train_df.shape, val_df.shape)
# (16464, 3) (4116, 3)
# 1件1件処理しちゃっているので、これがめっちゃ時間かかる...
# もっとよい方法ありましたらご教示いただけると幸いです...
train_data = DogData(train_df)
val_data = DogData(val_df)
# DataLoaderを取得する
train_loader = DataLoader(train_data, batch_size=64, shuffle=True)
val_loader = DataLoader(val_data, batch_size=64, shuffle=False)
モデル定義
ファインチューニング用のモデルはシンプルに最終層のCLSトークンのベクトルを全結合に一度通してクラス分類できるように変換するだけ、としました。
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
class ViTNet(nn.Module):
def __init__(self, pretrained_vit_model, class_num):
super(ViTNet, self).__init__()
self.vit = pretrained_vit_model
self.fc = nn.Linear(768, class_num)
def _get_cls_vec(self, states):
return states['last_hidden_state'][:, 0, :]
def forward(self, input_ids):
states = self.vit(input_ids)
states = self._get_cls_vec(states)
states = self.fc(states)
return states
# 今回のデータは120カテゴリ
CLASS_NUM = len(categories)
# 上のほうでViTModel.from_pretrained()でロードした事前学習済モデルを引数で渡すようにしています。
net = ViTNet(vit_model, CLASS_NUM)
# GPU使う
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
net.to(device)
criterion = nn.CrossEntropyLoss()
ファインチューニングの設定
以下の記事のBERTモデルのファインチューニングの方法を参考にしながら、ViTではクラス分類用に追加した層だけパラメータ更新をONにするようにしました。
# まず全パラメータを勾配計算Falseにする
for param in net.parameters():
param.requires_grad = False
# 追加したクラス分類用の全結合層を勾配計算ありに変更
for param in net.fc.parameters():
param.requires_grad = True
optimizer = optim.Adam([
{'params': net.fc.parameters(), 'lr': 1e-4}
])
# 損失関数
criterion = nn.CrossEntropyLoss()
学習&検証
とりあえず10エポック回しました。
from sklearn.metrics import f1_score
train_losses = []
val_losses = []
train_fscores = []
val_fscores = []
for epoch in range(10):
# 学習
train_loss = 0.0
train_predict = []
train_answer = []
net.train()
for batch in train_loader:
optimizer.zero_grad()
input_ids = batch[0].to(device)
y = batch[1].to(device)
out = net(input_ids)
loss = criterion(out, y)
loss.backward()
optimizer.step()
train_predict += out.argmax(dim=1).cpu().detach().numpy().tolist()
train_answer += y.cpu().detach().numpy().tolist()
train_loss += loss.item()
# エポックごとの損失の合計とF1-scoreを計算する
train_losses.append(train_loss)
train_fscore = f1_score(train_answer, train_predict, average='macro')
train_fscores.append(train_fscore)
# 検証
val_loss = 0.0
val_predict = []
val_answer = []
net.eval()
for batch in val_loader:
with torch.no_grad():
input_ids = batch[0].to(device)
y = batch[1].to(device)
out = net(input_ids)
loss = criterion(out, y)
val_loss += loss.item()
_, y_pred = torch.max(out, 1)
val_predict += out.argmax(dim=1).cpu().detach().numpy().tolist()
val_answer += y.cpu().detach().numpy().tolist()
# エポックごとの損失の合計とF1-scoreを計算する
val_losses.append(val_loss)
val_fscore = f1_score(val_answer, val_predict, average='macro')
val_fscores.append(val_fscore)
print('epoch', epoch,
'\ttrain loss', round(train_loss, 4), '\ttrain fscore', round(train_fscore, 4),
'\tval loss', round(val_loss, 4), '\tval fscore', round(val_fscore, 4)
)
エポック毎の損失とF1-scoreは以下のようになりました。1エポック目から検証データの精度がかなり良いですが、ホントか?ってちょっと思っちゃってます。(kaggleのカーネル見る限り、ResNet等でFスコア0.8くらいでているようなのですが、1エポック目からこんな数値がでると自分の実装がどこかおかしかったのか疑ってしまう...)
epoch 0 train loss 1145.6883 train fscore 0.504 val loss 259.0798 val fscore 0.8329
epoch 1 train loss 968.949 train fscore 0.7527 val loss 210.4485 val fscore 0.8674
epoch 2 train loss 807.7338 train fscore 0.7663 val loss 166.8859 val fscore 0.8793
epoch 3 train loss 669.7677 train fscore 0.7723 val loss 130.5843 val fscore 0.8864
epoch 4 train loss 561.8881 train fscore 0.7758 val loss 102.7415 val fscore 0.8907
epoch 5 train loss 484.4582 train fscore 0.7813 val loss 82.7449 val fscore 0.8953
epoch 6 train loss 432.3182 train fscore 0.7822 val loss 68.8519 val fscore 0.8958
epoch 7 train loss 398.4146 train fscore 0.7844 val loss 59.2286 val fscore 0.8983
epoch 8 train loss 375.4528 train fscore 0.786 val loss 52.4938 val fscore 0.9001
epoch 9 train loss 360.0181 train fscore 0.7876 val loss 47.6859 val fscore 0.9009
一応、学習曲線も表示してみます。学習データ、検証データともにきれいに損失も減ってFスコアの精度も上がっているようです。
import matplotlib.pyplot as plt
plt.figure(figsize=(15,5))
plt.subplot(1,2,1)
plt.plot(train_losses, '-o' label='train loss')
plt.plot(val_losses, '-^', label='val loss')
plt.title('loss')
plt.legend()
plt.grid()
plt.subplot(1,2,2)
plt.plot(train_fscores, '-o', label='train fscore')
plt.plot(val_fscores, '-^', label='val fscore')
plt.title('fscore')
plt.legend()
plt.grid()
plt.show()
おわりに
huggingface /transformers におけるViTの基本的な使い方はわかったような気がしますが、画像系の処理全般に不慣れでセオリーな実装方法じゃないかもしれません。なにか変なことしてたらぜひご指摘ください。
おわり