Kaggle Advent Calendar 3日目の記事です。
今回はKaggleなどのコンペで Metric Learning を試すときにとりあえず最初に実装するコードをまとめました。
UMAPを使ったembeddingの可視化とか faiss を使った検索とかはこの記事で扱ってないです。
1. Metric Learning って何?
予測値じゃなくて特徴量間の距離に注目して学習する方法
- 同じクラス内ではなるべく近い距離になるように
- 違うクラス間ではなるべく遠い距離になるように
もっと詳しくしたい人は Qiita 内でもいい記事たくさんあるのでどうぞ。
- モダンな深層距離学習 (deep metric learning) 手法: SphereFace, CosFace, ArcFace #DeepLearning - Qiita
- Softmax関数をベースにした Deep Metric Learning が上手くいく理由 #Python - Qiita
2. どんなときに採用するの?
例えば下記みたいなケースで採用の検討を加速します。
- Other クラスが存在するとき
- 異常検知みたいなことが必要なとき
- 類似したサンプルを複数個取得したいとき
- 画像検索とか
Kaggle で代表的なのは Happywhale - Whale and Dolphin Identification | Kaggle などがあります。
最近の Kaggle では単純な Classification モデルのみで取り組めるコンペが気持ち少なくなっており、 UBC-OCEAN のように Other クラスを用意するなど、 クソ舐めた 挑戦的なコンペが生まれています。
2.1 Classification モデルじゃダメなの?
特徴量を取得するだけなら Classification モデルを学習するだけでも実現できます。
この特徴量と他の特徴量を組み合わせて後段の予測をするという処理も見かけます。
一方で下記の例のような特徴量の距離に応じた処理を考える場合、Classification モデルだとうまくいかないことがあります。
- 距離が近いものをN個取得したい
- 距離が一定以上離れている場合 Other(とかNew)クラスとしたい
というのも Classification の場合予測値に対する Loss が最小になるように学習しているので特徴量間の距離は考慮していないからです(それはそう)。実際 Classification で学習させて距離を計算すると極端な値(0 か max)になることが多く、しきい値を設定して Other クラスを検出するには向いていないと個人的に感じました。
3. Kaggler側の要求(要件)
自分がコンペでとりあえず試すとなったときは以下を満たしていると嬉しかったりします。
-
気軽に試せてェ〜
- → 一応下記のコードなら
image, label_id
を返すような dataloader 作れば動く(はず) - → 今回は ArcFace(Loss) を採用したのでほぼ Classification の延長線として扱える
- → 一応下記のコードなら
-
色々変更できてェ〜
- → 一応下記のコードなら色々変更できる
-
できれば精度も良くてェ〜
- → それは頑張れ
-
いい感じのボードで結果も見れてェ〜
- → 必要なら
WandbLogger
呼び出して適用できる
- → 必要なら
4. 最近のMetric Learningの始め方(本題)
というわけで上記の要件を満たせるよう、今回は下記のライブラリを使用して実装します。
MetricLearning はいろんな手法がありますが、今回は ArcFace(Loss) を採用します。理由としては Classification モデルと大体同じ使い方で学習できるためです。
今回の実装例は CIFAR-10 を使って学習しますがスクリプト実行時にデータがダウンロードされるので事前ダウンロードは不要です。
4.1 各ライブラリ紹介
- pytorch
-
pytorch lightning(lightning)
- pytorch のループ周りの wrapper として使う
- 概要とか他の使い方は以前記事書いたのでどうぞ。でもバージョンアップにより細かい書き方が変わってるので注意してください
- (余談) 2.x 系列になるにあたり名前が lightning になったので install 時の名前も変える必要があるので注意
- kaggle notebook 環境だと pytorch lightning のままの場合もある
-
lightning bolts
- pytorch lightning の拡張機能を使えるライブラリ
- 今回は lr scheduler の呼び出しや CIFAR-10 datasetの作成に使う
-
torchmetrics
- 名前は違うけど pytorch lightning プロジェクトの一つ
- log 用 metrics の集計に使う
- 今回は関係ないが DDP 学習時の結果集約などで役に立つ
- 今回は train embeddings の集計にも使う
-
timm
- 言わずと知れた Classification モデルライブラリ
- 今回は feature extractor として使う
-
PyTorch Metric learning
- Metric Learning周りのツールがまとまってる
- 今回は ArcFaceLoss や距離計算用の関数を呼び出すために使ってる
- 今回使用しないもの
-
wandb(Weights & Biases)
- ナウでヤングな logger
- もし使いたい場合は
from lightning.pytorch.loggers import WandbLogger
-
OmefaConf: Flexible Python configuration system.
- config ファイル読み込むやつ
- pyyaml(yaml) よりも python の型に対応してるので楽
-
wandb(Weights & Biases)
4.1.1 timm の使い方
- 今回は feature extractor として使います
-
timm.create_model()
時にnum_classes=0
とすると (last) feature extractor となります -
timm.create_model(~~~, num_classes=0, global_pool='')
とすると pool 前の last feature が出力されます - 詳しい使い方は下記を参照してください
-
- feature size が ArcFaceLoss の
__init__()
実行時に必要ですのでfeature_info
から取得します- 適当に input 作って出力から size を確認することもできます
In [1]: import torch
In [2]: import timm
In [3]: timm.__version__
Out[3]: '0.9.7'
In [4]: m = timm.create_model('resnet18d', pretrained=False, num_classes=0)
In [5]: m.feature_info[-1]["num_chs"]
Out[5]: 512
In [6]: m(torch.rand(1, 3, 256, 256)).size()
Out[6]: torch.Size([1, 512])
4.1.2 torchmetrics の使い方
- Accuracy — PyTorch-Metrics 1.1.0 documentation
- 結果のappendとmetricsの計算をまとめたもの
- 今回は batch 毎に
update()
で結果を格納していき、最後にcompute()
で結果を集約するようにしています
In [1]: import torch
In [2]: from torch import tensor
In [3]: from torchmetrics import Accuracy, MeanMetric
In [4]: n_class = 4
In [5]: bs = 4
In [6]: acc = Accuracy(task="multiclass", num_classes=n_class)
# 何もデータを格納していないときに compute() すると warning が出る
In [7]: acc.compute()
/usr/local/lib/python3.10/site-packages/torchmetrics/utilities/prints.py:43: UserWarning: The ``compute`` method of metric MulticlassAccuracy was called before the ``update`` method which may lead to errors, as metric states have not yet been updated.
warnings.warn(*args, **kwargs) # noqa: B028
Out[7]: tensor(0.)
# 一気に計算するパターン
In [8]: target = tensor([0, 1, 2, 3])
In [9]: preds = tensor([0, 2, 1, 3])
In [10]: acc(preds, target)
Out[10]: tensor(0.5000)
# 一度計算すると acc 内部に target と preds が蓄積される
In [11]: acc.compute()
Out[11]: tensor(0.5000)
# 内部データをリセット
In [12]: acc.reset()
In [13]: acc.compute()
/usr/local/lib/python3.10/site-packages/torchmetrics/utilities/prints.py:43: UserWarning: The ``compute`` method of metric MulticlassAccuracy was called before the ``update`` method which may lead to errors, as metric states have not yet been updated.
warnings.warn(*args, **kwargs) # noqa: B028
Out[13]: tensor(0.)
# preds は (bs, n_class) でも渡せる
In [14]: preds = torch.rand(bs, n_class).softmax(dim=-1)
In [15]: preds
Out[15]:
tensor([[0.2458, 0.2113, 0.3080, 0.2348],
[0.2733, 0.3262, 0.1925, 0.2079],
[0.1881, 0.4060, 0.2464, 0.1594],
[0.3179, 0.1644, 0.2027, 0.3151]])
In [16]: acc(preds, target)
Out[16]: tensor(0.2500)
In [17]: acc.reset()
# preds は logits でも良い
In [18]: preds = torch.rand(bs, n_class) - 0.5
In [19]: preds
Out[19]:
tensor([[ 0.0214, -0.0699, -0.1490, -0.0610],
[ 0.0845, -0.3965, 0.0996, 0.0343],
[-0.3506, -0.0357, 0.0570, -0.0718],
[ 0.1613, -0.4284, 0.0166, 0.3271]])
In [20]: acc(preds, target)
Out[20]: tensor(0.7500)
In [21]: acc.reset()
# update() で結果を蓄積して最後に compute() で結果を算出することも可能
In [22]: for _ in range(10):
...: acc.update(preds=preds, target=target)
...:
In [23]: acc.compute()
Out[23]: tensor(0.7500)
In [24]: acc.reset()
4.2 実行環境
各ライブラリのバージョンは以下の通りです。
torch==2.0.0+cu118
torchvision==0.15.1+cu118
pytorch-metric-learning==2.3.0
lightning==2.0.8
lightning-bolts==0.7.0
torchmetrics==1.2.0
timm==0.9.7
4.3 実装全体
key | value |
---|---|
Dataset | CIFAR-10 |
Model | Resnet18d (pretrain=True) |
Loss | ArcFaceLoss |
Optimizer | AdamW |
Scheduler | CosineAnnealing + warmup |
Data augmentation | RandomCrop + RandomHflip |
How to pred class | L2 distance from train mean embedding |
import os
from pathlib import Path
import lightning as L
import timm
import torch
import torch.nn.functional as F
import torchvision
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers import CSVLogger
from pl_bolts.datamodules import CIFAR10DataModule
from pl_bolts.optimizers.lr_scheduler import LinearWarmupCosineAnnealingLR
from pl_bolts.transforms.dataset_normalizations import cifar10_normalization
from pytorch_metric_learning.distances import LpDistance
from pytorch_metric_learning.losses import ArcFaceLoss
from torch.optim import AdamW
from torchmetrics import Accuracy, MeanMetric, MetricCollection
from torchmetrics.aggregation import CatMetric
class MyLightningModule(L.LightningModule):
def __init__(
self,
num_classes: int = 10,
max_epochs: int = 30,
init_lr: float = 3e-4,
arcface_margin: float = 28.6,
arcface_scale: int = 64,
):
super().__init__()
self.max_epochs = max_epochs
self.init_lr = init_lr
# model と loss
# Create model(num_classes=0) にすると feature extractor となる
self.net = timm.create_model("resnet18d", pretrained=True, num_classes=0)
emb_size = self.net.feature_info[-1]["num_chs"]
# margin: The paper uses 0.5 radians, which is 28.6 degrees.
self.loss = ArcFaceLoss(
num_classes=num_classes,
embedding_size=emb_size,
margin=arcface_margin,
scale=arcface_scale,
)
# log に残す metrics 用
self.metrics = MetricCollection(
dict(
train_loss=MeanMetric(),
val_loss=MeanMetric(),
val_acc_macro=Accuracy(
task="multiclass", num_classes=num_classes, average="macro"
),
)
)
# validation data のクラス予測時に使う用
self.train_data = MetricCollection(
dict(embeddings=CatMetric(), labels=CatMetric())
)
self.train_mean_embeddings = torch.rand(num_classes, emb_size)
self.dist_func = LpDistance(p=2) # L2 distance
def forward(self, x):
embeddings = self.net(x)
return embeddings
def training_step(self, batch, batch_nb):
x, y = batch
embeddings = self.forward(x)
loss = self.loss(embeddings, y)
self.metrics["train_loss"].update(loss.detach())
self.train_data["embeddings"].update(F.normalize(embeddings.detach()))
self.train_data["labels"].update(y.detach())
return loss
def on_train_epoch_end(self):
result = self.train_data.compute()
self.train_data.reset()
labels = result["labels"]
embeddings = result["embeddings"]
# calc train mean feature for each class
for class_id in range(len(self.train_mean_embeddings)):
embeddings_tmp = embeddings[labels == class_id]
self.train_mean_embeddings[class_id] = embeddings_tmp.mean(dim=0)
def validation_step(self, batch, batch_nb):
x, y = batch
embeddings = self.forward(x)
loss = self.loss(embeddings, y)
self.metrics["val_loss"].update(loss)
# train feature mean と距離を計算して matrix にする (val_batch x n_class)
dist_matrix = self.dist_func(
F.normalize(embeddings),
self.train_mean_embeddings.to(embeddings.device),
)
# train feature mean と比較して近いものを予測クラスとして出力
preds = dist_matrix.argmin(dim=1)
self.metrics["val_acc_macro"].update(preds=preds, target=y)
return loss
def on_validation_epoch_end(self):
log_tmp = dict(epoch=int(self.current_epoch))
log_metrics = self.metrics.compute()
log_metrics = {k: v.item() for k, v in log_metrics.items()}
log_tmp.update(log_metrics)
self.metrics.reset()
self.log_dict(log_tmp, prog_bar=True, sync_dist=True)
def configure_optimizers(self):
optimizer = AdamW(
self.parameters(), lr=self.init_lr, weight_decay=1e-6, eps=1e-7
)
scheduler = LinearWarmupCosineAnnealingLR(
optimizer,
warmup_epochs=5,
max_epochs=self.max_epochs,
warmup_start_lr=self.init_lr / 10.0,
eta_min=1e-6,
last_epoch=-1,
)
# interval: step or epoch
scheduler = {
"scheduler": scheduler,
"interval": "epoch",
"frequency": 1,
}
return [optimizer], [scheduler]
def main():
output_path = Path("output")
output_path.mkdir(parents=True, exist_ok=True)
L.seed_everything(42)
# CIFAR-10 datset
train_transforms = torchvision.transforms.Compose(
[
torchvision.transforms.RandomCrop(32, padding=4),
torchvision.transforms.RandomHorizontalFlip(),
torchvision.transforms.ToTensor(),
cifar10_normalization(),
]
)
test_transforms = torchvision.transforms.Compose(
[
torchvision.transforms.ToTensor(),
cifar10_normalization(),
]
)
cifar10_dm = CIFAR10DataModule(
data_dir=".",
batch_size=256,
num_workers=6,
train_transforms=train_transforms,
test_transforms=test_transforms,
val_transforms=test_transforms,
)
cifar10_dm.prepare_data()
cifar10_dm.setup()
# LightningModule
max_epochs = 30
model = MyLightningModule(
num_classes=10, # CIFAR-10 なので 10 クラス
max_epochs=max_epochs,
init_lr=1e-3,
arcface_margin=28.6, # チューニング要素
arcface_scale=64, # チューニング要素
)
# W in ArcFaceLoss is target of learning
for name, _ in model.named_parameters(recurse=True):
if "loss" in name:
print("learning target in loss: ", name)
# WandbLogger とか使いたい場合は loggers に append する
loggers = [CSVLogger(save_dir=output_path, name="demo")]
checkpoint_callback = ModelCheckpoint(
dirpath=output_path,
filename="sample",
save_weights_only=True,
monitor=None,
)
trainer = L.Trainer(
logger=loggers,
callbacks=[checkpoint_callback],
default_root_dir=os.getcwd(),
accelerator="gpu",
strategy="ddp",
devices=1,
precision="16-mixed", # 32-true or 16-mixed
max_epochs=max_epochs,
deterministic=False,
)
# Start train
trainer.fit(
model,
train_dataloaders=cifar10_dm.train_dataloader(),
val_dataloaders=cifar10_dm.val_dataloader(),
)
if __name__ == "__main__":
main()
4.4 実装のポイント
- チューニング要素(e.g. ArcFace のハイパーパラメータ等)は外に出しておく
-
timm.create_model
の引数でnum_classes=0
とする- feature extractor として使いたいため
-
resnet18d
はいろんなものに変更できる
- embedding size を逐一調べなくて良いようにする
- timm なら
net.feature_info[-1]["num_chs"]
で取得できる
- timm なら
- metric learning の部分を Loss に任せる
self.loss = ArcFaceLoss(num_classes=n_class, embedding_size=emb_size)
- siamese network とかは input から作成する必要があるけど ArcFace や CosFace なら Loss だけ変えて実装できるので気軽に試せる
- torch metrics を使って結果を集約している
- これで DDP で学習したときとかも安心して使える
- loss だけじゃなくて accuracy も見れるようにした
- train の各クラスの平均特徴量を作成して、その L2 distance が最も近いクラスを予測クラスとした
dist_matrix = self.dist_func(val_embeddings, train_mean_embeddings)
preds = dist_matrix.argmin(dim=1)
- top N 個とか取りたい場合は
argmin
をいい感じに変更する
- train の各クラスの平均特徴量を作成して、その L2 distance が最も近いクラスを予測クラスとした
4.5 実行結果
上記のコードであれば pretrained weight とか CIFAR-10 とか事前にダウンロードしてなくても大丈夫です。
$ python train.py
~~~
learning target in loss: loss.W
~~~
`Trainer.fit` stopped: `max_epochs=30` reached.
Epoch 29: 100%|██| 157/157 [00:11<00:00, 13.94it/s, v_num=0, epoch=29.00
, train_loss=5.400, val_acc_macro=0.874, val_loss=7.150]
結果は下記のように格納されています。
$ ls output/
demo sample-v1.ckpt
$ ls output/demo/version_0/
hparams.yaml metrics.csv
$ head -n 5 output/demo/version_0/metrics.csv
epoch,train_loss,val_acc_macro,val_loss,step
0.0,34.21628189086914,0.09846183657646179,33.21874237060547,156
1.0,27.717126846313477,0.6571612358093262,20.311311721801758,313
2.0,16.87421417236328,0.762157142162323,14.048077583312988,470
3.0,14.297616958618164,0.7699860334396362,13.174238204956055,627
4.6 注意事項: ArcFaceLoss.W
も学習対象
- 今回使用した
pytorch_metric_learning.losses.ArcFaceLoss
内のW
も学習対象であることに注意してください-
embedding
をn_class
次元に線形変換するための重み - timm model で言う classifier から bias を取ったものと大体同じ
-
nn.Module
の定義抜きで簡潔に実装できているのはこれのおかげ
-
- pytorch lightning の場合デフォルトで学習対象に含まれるため追加の設定は不要
- 自動で
self.parameters()
に含まれる- 上記実行結果の
learning target: loss.W
から確認できる
- 上記実行結果の
- 自動で
- 素の pytorch のみで学習させる場合は
optimizer
定義時にloss.parameters()
も渡すようにしてください- もしくは
loss.parameters()
用のoptimizer
を別に用意する- その場合
optimizer
毎にzero_grad()
とstep()
が必要になる
- その場合
- もしくは
optimizer = AdamW(
[{'params': model.parameters()}, {'params': loss.parameters()}],
lr=3e-4,
weight_decay=1e-6,
eps=1e-7
)
5. 補足
5.1 Wandb Logger 使いたい
Step1. wandb にログインする
- コンソールで
$wandb login
できる環境なら予めしておくと楽- その場合下記の対応は不要
- 難しい場合、下記のような
api_key.json
を用意して環境変数に入力する
{
"wandb": "XXXXXXXXXXXXXXXXXXXXXXXXX"
}
with open("./api_keys.json", "r") as f:
key = json.load(f)["wandb"]
os.environ["WANDB_API_KEY"] = key
Step2. loggers を変更
-
offline=True
にすると local 環境だけで実行されるので、offline=debug
とかにするのはおすすめ- wandb web console 上で debug 用の log を消す必要がなくなる
from lightning.pytorch.loggers import CSVLogger, WandbLogger
loggers = [
CSVLogger(save_dir=output_path, name=f"fold_{fold}"),
WandbLogger(
project="Sample-MetricLearning",
group="group1",
name="sample_fold_0",
offline=False,
),
]
5.2 上記のスクリプトで学習したモデルで予測したい
-
load_from_checkpoint
で読み込める
weight_path = "output/sample-v1.ckpt"
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = MyLightningModule.load_from_checkpoint(
weight_path,
num_classes=10,
max_epochs=30,
init_lr=1e-3,
arcface_margin=28.6,
arcface_scale=64,
).net
model = model.to(device)
model = model.eval()
5.3 自分のデータで学習したい
-
image, label_id
を返すような dataset および data_loader を作成する-
label_id
はクラス数n
のときは[0, ..., n-1]
-
- データセットに応じた
num_classes
をMyLightningModule
初期化時に引数として与える
from torch.utils.data import DataLoader, Dataset
class MyDataset(Dataset):
def __init__(self, image_paths, label_ids, transform=None):
self.image_paths = image_paths
self.label_ids = label_ids
self.transform = transform
def __len__(self):
return len(self.image_paths)
def __getitem__(self, idx):
image = cv2.imread(self.img_paths[idx])
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
image = self.transform(image=image)["image"].float()
label_id = torch.tensor(self.label_ids[idx]).long()
return image, label_id
train_ds = MyDataset(...)
valid_ds = MyDataset(...)
train_loader = DataLoader(
train_ds,
batch_size=16,
shuffle=True,
num_workers=4,
drop_last=True
)
valid_loader = DataLoader(
valid_ds,
batch_size=16,
shuffle=False,
num_workers=4,
drop_last=False
)
~~~~~~
model = MyLightningModule(num_classes=n_classes, ...)
~~~~~~
trainer.fit(
model,
train_dataloaders=train_loader,
val_dataloaders=valid_loader,
)