43
35

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

最近のMetric Learningの始め方(コンペを見据えて)

Last updated at Posted at 2023-12-02

Kaggle Advent Calendar 3日目の記事です。

今回はKaggleなどのコンペで Metric Learning を試すときにとりあえず最初に実装するコードをまとめました。

UMAPを使ったembeddingの可視化とか faiss を使った検索とかはこの記事で扱ってないです。

1. Metric Learning って何?

予測値じゃなくて特徴量間の距離に注目して学習する方法

  • 同じクラス内ではなるべく近い距離になるように
  • 違うクラス間ではなるべく遠い距離になるように

もっと詳しくしたい人は Qiita 内でもいい記事たくさんあるのでどうぞ。

2. どんなときに採用するの?

例えば下記みたいなケースで採用の検討を加速します。

  • Other クラスが存在するとき
    • 異常検知みたいなことが必要なとき
  • 類似したサンプルを複数個取得したいとき
    • 画像検索とか

Kaggle で代表的なのは Happywhale - Whale and Dolphin Identification | Kaggle などがあります。

最近の Kaggle では単純な Classification モデルのみで取り組めるコンペが気持ち少なくなっており、 UBC-OCEAN のように Other クラスを用意するなど、 クソ舐めた 挑戦的なコンペが生まれています。

2.1 Classification モデルじゃダメなの?

特徴量を取得するだけなら Classification モデルを学習するだけでも実現できます。
この特徴量と他の特徴量を組み合わせて後段の予測をするという処理も見かけます。

一方で下記の例のような特徴量の距離に応じた処理を考える場合、Classification モデルだとうまくいかないことがあります。

  • 距離が近いものをN個取得したい
  • 距離が一定以上離れている場合 Other(とかNew)クラスとしたい

というのも Classification の場合予測値に対する Loss が最小になるように学習しているので特徴量間の距離は考慮していないからです(それはそう)。実際 Classification で学習させて距離を計算すると極端な値(0 か max)になることが多く、しきい値を設定して Other クラスを検出するには向いていないと個人的に感じました。

3. Kaggler側の要求(要件)

自分がコンペでとりあえず試すとなったときは以下を満たしていると嬉しかったりします。

  • 気軽に試せてェ〜
    • → 一応下記のコードなら image, label_id を返すような dataloader 作れば動く(はず)
    • → 今回は ArcFace(Loss) を採用したのでほぼ Classification の延長線として扱える
  • 色々変更できてェ〜
    • → 一応下記のコードなら色々変更できる
  • できれば精度も良くてェ〜
    • → それは頑張れ
  • いい感じのボードで結果も見れてェ〜
    • → 必要なら WandbLogger 呼び出して適用できる

4. 最近のMetric Learningの始め方(本題)

というわけで上記の要件を満たせるよう、今回は下記のライブラリを使用して実装します。

MetricLearning はいろんな手法がありますが、今回は ArcFace(Loss) を採用します。理由としては Classification モデルと大体同じ使い方で学習できるためです。

今回の実装例は CIFAR-10 を使って学習しますがスクリプト実行時にデータがダウンロードされるので事前ダウンロードは不要です。

4.1 各ライブラリ紹介

  • pytorch
  • pytorch lightning(lightning)
  • lightning bolts
    • pytorch lightning の拡張機能を使えるライブラリ
    • 今回は lr scheduler の呼び出しや CIFAR-10 datasetの作成に使う
  • torchmetrics
    • 名前は違うけど pytorch lightning プロジェクトの一つ
    • log 用 metrics の集計に使う
      • 今回は関係ないが DDP 学習時の結果集約などで役に立つ
    • 今回は train embeddings の集計にも使う
  • timm
    • 言わずと知れた Classification モデルライブラリ
    • 今回は feature extractor として使う
  • PyTorch Metric learning
    • Metric Learning周りのツールがまとまってる
    • 今回は ArcFaceLoss や距離計算用の関数を呼び出すために使ってる
  • 今回使用しないもの

4.1.1 timm の使い方

  • 今回は feature extractor として使います
    • timm.create_model() 時に num_classes=0 とすると (last) feature extractor となります
    • timm.create_model(~~~, num_classes=0, global_pool='') とすると pool 前の last feature が出力されます
    • 詳しい使い方は下記を参照してください
  • feature size が ArcFaceLoss の __init__() 実行時に必要ですので feature_info から取得します
    • 適当に input 作って出力から size を確認することもできます
In [1]: import torch
In [2]: import timm
In [3]: timm.__version__
Out[3]: '0.9.7'

In [4]: m = timm.create_model('resnet18d', pretrained=False, num_classes=0)
In [5]: m.feature_info[-1]["num_chs"]
Out[5]: 512

In [6]: m(torch.rand(1, 3, 256, 256)).size()
Out[6]: torch.Size([1, 512])

4.1.2 torchmetrics の使い方

  • Accuracy — PyTorch-Metrics 1.1.0 documentation
  • 結果のappendとmetricsの計算をまとめたもの
  • 今回は batch 毎に update() で結果を格納していき、最後に compute() で結果を集約するようにしています

In [1]: import torch
In [2]: from torch import tensor
In [3]: from torchmetrics import Accuracy, MeanMetric
In [4]: n_class = 4
In [5]: bs = 4
In [6]: acc = Accuracy(task="multiclass", num_classes=n_class)

# 何もデータを格納していないときに compute() すると warning が出る
In [7]: acc.compute()
/usr/local/lib/python3.10/site-packages/torchmetrics/utilities/prints.py:43: UserWarning: The ``compute`` method of metric MulticlassAccuracy was called before the ``update`` method which may lead to errors, as metric states have not yet been updated.
  warnings.warn(*args, **kwargs)  # noqa: B028
Out[7]: tensor(0.)

# 一気に計算するパターン
In [8]: target = tensor([0, 1, 2, 3])
In [9]: preds = tensor([0, 2, 1, 3])
In [10]: acc(preds, target)
Out[10]: tensor(0.5000)

# 一度計算すると acc 内部に target と preds が蓄積される
In [11]: acc.compute()
Out[11]: tensor(0.5000)

# 内部データをリセット
In [12]: acc.reset()
In [13]: acc.compute()
/usr/local/lib/python3.10/site-packages/torchmetrics/utilities/prints.py:43: UserWarning: The ``compute`` method of metric MulticlassAccuracy was called before the ``update`` method which may lead to errors, as metric states have not yet been updated.
  warnings.warn(*args, **kwargs)  # noqa: B028
Out[13]: tensor(0.)

# preds は (bs, n_class) でも渡せる
In [14]: preds = torch.rand(bs, n_class).softmax(dim=-1)
In [15]: preds
Out[15]:
tensor([[0.2458, 0.2113, 0.3080, 0.2348],
        [0.2733, 0.3262, 0.1925, 0.2079],
        [0.1881, 0.4060, 0.2464, 0.1594],
        [0.3179, 0.1644, 0.2027, 0.3151]])

In [16]: acc(preds, target)
Out[16]: tensor(0.2500)

In [17]: acc.reset()

# preds は logits でも良い
In [18]: preds = torch.rand(bs, n_class) - 0.5
In [19]: preds
Out[19]:
tensor([[ 0.0214, -0.0699, -0.1490, -0.0610],
        [ 0.0845, -0.3965,  0.0996,  0.0343],
        [-0.3506, -0.0357,  0.0570, -0.0718],
        [ 0.1613, -0.4284,  0.0166,  0.3271]])

In [20]: acc(preds, target)
Out[20]: tensor(0.7500)

In [21]: acc.reset()

# update() で結果を蓄積して最後に compute() で結果を算出することも可能
In [22]: for _ in range(10):
    ...:     acc.update(preds=preds, target=target)
    ...:

In [23]: acc.compute()
Out[23]: tensor(0.7500)

In [24]: acc.reset()

4.2 実行環境

各ライブラリのバージョンは以下の通りです。

torch==2.0.0+cu118
torchvision==0.15.1+cu118
pytorch-metric-learning==2.3.0
lightning==2.0.8
lightning-bolts==0.7.0
torchmetrics==1.2.0
timm==0.9.7

4.3 実装全体

key value
Dataset CIFAR-10
Model Resnet18d (pretrain=True)
Loss ArcFaceLoss
Optimizer AdamW
Scheduler CosineAnnealing + warmup
Data augmentation RandomCrop + RandomHflip
How to pred class L2 distance from train mean embedding
train.py
import os
from pathlib import Path

import lightning as L
import timm
import torch
import torch.nn.functional as F
import torchvision
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers import CSVLogger
from pl_bolts.datamodules import CIFAR10DataModule
from pl_bolts.optimizers.lr_scheduler import LinearWarmupCosineAnnealingLR
from pl_bolts.transforms.dataset_normalizations import cifar10_normalization
from pytorch_metric_learning.distances import LpDistance
from pytorch_metric_learning.losses import ArcFaceLoss
from torch.optim import AdamW
from torchmetrics import Accuracy, MeanMetric, MetricCollection
from torchmetrics.aggregation import CatMetric


class MyLightningModule(L.LightningModule):
    def __init__(
        self,
        num_classes: int = 10,
        max_epochs: int = 30,
        init_lr: float = 3e-4,
        arcface_margin: float = 28.6,
        arcface_scale: int = 64,
    ):
        super().__init__()
        self.max_epochs = max_epochs
        self.init_lr = init_lr

        # model と loss
        # Create model(num_classes=0) にすると feature extractor となる
        self.net = timm.create_model("resnet18d", pretrained=True, num_classes=0)
        emb_size = self.net.feature_info[-1]["num_chs"]

        # margin: The paper uses 0.5 radians, which is 28.6 degrees.
        self.loss = ArcFaceLoss(
            num_classes=num_classes,
            embedding_size=emb_size,
            margin=arcface_margin,
            scale=arcface_scale,
        )

        # log に残す metrics 用
        self.metrics = MetricCollection(
            dict(
                train_loss=MeanMetric(),
                val_loss=MeanMetric(),
                val_acc_macro=Accuracy(
                    task="multiclass", num_classes=num_classes, average="macro"
                ),
            )
        )

        # validation data のクラス予測時に使う用
        self.train_data = MetricCollection(
            dict(embeddings=CatMetric(), labels=CatMetric())
        )
        self.train_mean_embeddings = torch.rand(num_classes, emb_size)
        self.dist_func = LpDistance(p=2)  # L2 distance

    def forward(self, x):
        embeddings = self.net(x)
        return embeddings

    def training_step(self, batch, batch_nb):
        x, y = batch
        embeddings = self.forward(x)
        loss = self.loss(embeddings, y)
        self.metrics["train_loss"].update(loss.detach())
        self.train_data["embeddings"].update(F.normalize(embeddings.detach()))
        self.train_data["labels"].update(y.detach())
        return loss

    def on_train_epoch_end(self):
        result = self.train_data.compute()
        self.train_data.reset()
        labels = result["labels"]
        embeddings = result["embeddings"]

        # calc train mean feature for each class
        for class_id in range(len(self.train_mean_embeddings)):
            embeddings_tmp = embeddings[labels == class_id]
            self.train_mean_embeddings[class_id] = embeddings_tmp.mean(dim=0)

    def validation_step(self, batch, batch_nb):
        x, y = batch
        embeddings = self.forward(x)
        loss = self.loss(embeddings, y)
        self.metrics["val_loss"].update(loss)

        # train feature mean と距離を計算して matrix にする (val_batch x n_class)
        dist_matrix = self.dist_func(
            F.normalize(embeddings),
            self.train_mean_embeddings.to(embeddings.device),
        )

        # train feature mean と比較して近いものを予測クラスとして出力
        preds = dist_matrix.argmin(dim=1)
        self.metrics["val_acc_macro"].update(preds=preds, target=y)
        return loss

    def on_validation_epoch_end(self):
        log_tmp = dict(epoch=int(self.current_epoch))
        log_metrics = self.metrics.compute()
        log_metrics = {k: v.item() for k, v in log_metrics.items()}
        log_tmp.update(log_metrics)
        self.metrics.reset()
        self.log_dict(log_tmp, prog_bar=True, sync_dist=True)

    def configure_optimizers(self):
        optimizer = AdamW(
            self.parameters(), lr=self.init_lr, weight_decay=1e-6, eps=1e-7
        )
        scheduler = LinearWarmupCosineAnnealingLR(
            optimizer,
            warmup_epochs=5,
            max_epochs=self.max_epochs,
            warmup_start_lr=self.init_lr / 10.0,
            eta_min=1e-6,
            last_epoch=-1,
        )
        # interval: step or epoch
        scheduler = {
            "scheduler": scheduler,
            "interval": "epoch",
            "frequency": 1,
        }
        return [optimizer], [scheduler]


def main():
    output_path = Path("output")
    output_path.mkdir(parents=True, exist_ok=True)

    L.seed_everything(42)

    # CIFAR-10 datset
    train_transforms = torchvision.transforms.Compose(
        [
            torchvision.transforms.RandomCrop(32, padding=4),
            torchvision.transforms.RandomHorizontalFlip(),
            torchvision.transforms.ToTensor(),
            cifar10_normalization(),
        ]
    )
    test_transforms = torchvision.transforms.Compose(
        [
            torchvision.transforms.ToTensor(),
            cifar10_normalization(),
        ]
    )
    cifar10_dm = CIFAR10DataModule(
        data_dir=".",
        batch_size=256,
        num_workers=6,
        train_transforms=train_transforms,
        test_transforms=test_transforms,
        val_transforms=test_transforms,
    )
    cifar10_dm.prepare_data()
    cifar10_dm.setup()

    # LightningModule
    max_epochs = 30
    model = MyLightningModule(
        num_classes=10,  # CIFAR-10 なので 10 クラス
        max_epochs=max_epochs,
        init_lr=1e-3,
        arcface_margin=28.6,  # チューニング要素
        arcface_scale=64,  # チューニング要素
    )

    # W in ArcFaceLoss is target of learning
    for name, _ in model.named_parameters(recurse=True):
        if "loss" in name:
            print("learning target in loss: ", name)

    # WandbLogger とか使いたい場合は loggers に append する
    loggers = [CSVLogger(save_dir=output_path, name="demo")]
    checkpoint_callback = ModelCheckpoint(
        dirpath=output_path,
        filename="sample",
        save_weights_only=True,
        monitor=None,
    )
    trainer = L.Trainer(
        logger=loggers,
        callbacks=[checkpoint_callback],
        default_root_dir=os.getcwd(),
        accelerator="gpu",
        strategy="ddp",
        devices=1,
        precision="16-mixed",  # 32-true or 16-mixed
        max_epochs=max_epochs,
        deterministic=False,
    )

    # Start train
    trainer.fit(
        model,
        train_dataloaders=cifar10_dm.train_dataloader(),
        val_dataloaders=cifar10_dm.val_dataloader(),
    )


if __name__ == "__main__":
    main()

4.4 実装のポイント

  • チューニング要素(e.g. ArcFace のハイパーパラメータ等)は外に出しておく
  • timm.create_model の引数で num_classes=0 とする
    • feature extractor として使いたいため
    • resnet18d はいろんなものに変更できる
  • embedding size を逐一調べなくて良いようにする
    • timm なら net.feature_info[-1]["num_chs"] で取得できる
  • metric learning の部分を Loss に任せる
    • self.loss = ArcFaceLoss(num_classes=n_class, embedding_size=emb_size)
    • siamese network とかは input から作成する必要があるけど ArcFace や CosFace なら Loss だけ変えて実装できるので気軽に試せる
  • torch metrics を使って結果を集約している
    • これで DDP で学習したときとかも安心して使える
  • loss だけじゃなくて accuracy も見れるようにした
    • train の各クラスの平均特徴量を作成して、その L2 distance が最も近いクラスを予測クラスとした
      • dist_matrix = self.dist_func(val_embeddings, train_mean_embeddings)
      • preds = dist_matrix.argmin(dim=1)
    • top N 個とか取りたい場合は argmin をいい感じに変更する

4.5 実行結果

上記のコードであれば pretrained weight とか CIFAR-10 とか事前にダウンロードしてなくても大丈夫です。

$ python train.py
~~~
learning target in loss:  loss.W
~~~
`Trainer.fit` stopped: `max_epochs=30` reached.

Epoch 29: 100%|██| 157/157 [00:11<00:00, 13.94it/s, v_num=0, epoch=29.00
, train_loss=5.400, val_acc_macro=0.874, val_loss=7.150]

結果は下記のように格納されています。

$ ls output/
demo  sample-v1.ckpt

$ ls output/demo/version_0/
hparams.yaml  metrics.csv

$ head -n 5 output/demo/version_0/metrics.csv
epoch,train_loss,val_acc_macro,val_loss,step
0.0,34.21628189086914,0.09846183657646179,33.21874237060547,156
1.0,27.717126846313477,0.6571612358093262,20.311311721801758,313
2.0,16.87421417236328,0.762157142162323,14.048077583312988,470
3.0,14.297616958618164,0.7699860334396362,13.174238204956055,627

4.6 注意事項: ArcFaceLoss.W も学習対象

  • 今回使用した pytorch_metric_learning.losses.ArcFaceLoss 内の W も学習対象であることに注意してください
    • embeddingn_class 次元に線形変換するための重み
    • timm model で言う classifier から bias を取ったものと大体同じ
    • nn.Module の定義抜きで簡潔に実装できているのはこれのおかげ
  • pytorch lightning の場合デフォルトで学習対象に含まれるため追加の設定は不要
    • 自動で self.parameters() に含まれる
      • 上記実行結果の learning target: loss.W から確認できる
  • 素の pytorch のみで学習させる場合は optimizer 定義時に loss.parameters() も渡すようにしてください
    • もしくは loss.parameters() 用の optimizer を別に用意する
      • その場合 optimizer 毎に zero_grad()step() が必要になる
optimizerにlossのparamsも渡す方法
optimizer = AdamW(
    [{'params': model.parameters()}, {'params': loss.parameters()}], 
    lr=3e-4,
    weight_decay=1e-6,
    eps=1e-7
)

5. 補足

5.1 Wandb Logger 使いたい

Step1. wandb にログインする

  • コンソールで $wandb login できる環境なら予めしておくと楽
    • その場合下記の対応は不要
  • 難しい場合、下記のような api_key.json を用意して環境変数に入力する
api_key.json
{
	"wandb": "XXXXXXXXXXXXXXXXXXXXXXXXX"
}
環境変数に入力
with open("./api_keys.json", "r") as f:
    key = json.load(f)["wandb"]
    os.environ["WANDB_API_KEY"] = key

Step2. loggers を変更

  • offline=True にすると local 環境だけで実行されるので、offline=debug とかにするのはおすすめ
    • wandb web console 上で debug 用の log を消す必要がなくなる
from lightning.pytorch.loggers import CSVLogger, WandbLogger

loggers = [
    CSVLogger(save_dir=output_path, name=f"fold_{fold}"),
    WandbLogger(
        project="Sample-MetricLearning",
        group="group1",
        name="sample_fold_0",
        offline=False,
    ),
]

5.2 上記のスクリプトで学習したモデルで予測したい

  • load_from_checkpoint で読み込める
weight_path = "output/sample-v1.ckpt"
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

model = MyLightningModule.load_from_checkpoint(
    weight_path, 
    num_classes=10,
    max_epochs=30,
    init_lr=1e-3,
    arcface_margin=28.6,
    arcface_scale=64,
).net
model = model.to(device)
model = model.eval()

5.3 自分のデータで学習したい

  • image, label_id を返すような dataset および data_loader を作成する
    • label_id はクラス数 n のときは [0, ..., n-1]
  • データセットに応じた num_classesMyLightningModule 初期化時に引数として与える
dataset と dataloader の例
from torch.utils.data import DataLoader, Dataset

class MyDataset(Dataset):
    def __init__(self, image_paths, label_ids, transform=None):
        self.image_paths = image_paths
        self.label_ids = label_ids
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image = cv2.imread(self.img_paths[idx])
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = self.transform(image=image)["image"].float()

        label_id = torch.tensor(self.label_ids[idx]).long()

        return image, label_id

train_ds = MyDataset(...)
valid_ds = MyDataset(...)
train_loader = DataLoader(
    train_ds,
    batch_size=16,
    shuffle=True,
    num_workers=4,
    drop_last=True
)
valid_loader = DataLoader(
    valid_ds,
    batch_size=16,
    shuffle=False,
    num_workers=4,
    drop_last=False
)

~~~~~~

model = MyLightningModule(num_classes=n_classes, ...)

~~~~~~

trainer.fit(
    model,
    train_dataloaders=train_loader,
    val_dataloaders=valid_loader,
)
43
35
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
43
35

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?