0
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

【torchmetrics】マルチクラス分類

Last updated at Posted at 2025-08-26

こんにちは!いつもマルチクラス分類のときの評価指標の計算方法を忘れてしまうのでメモ書きです。
sklearn.metricsを使ってもできますが、深層学習でログを取りたいときは都度cpuに戻さず計算できるほうが楽なので今回はtorchmetricsを使います。

torchmetricsのポイント

  1. 最初に metrics を宣言する
  2. バッチごとに update で部分的に足し算
  3. epochの最後に compute で最終値を出す

torchmetricsの使い方

1. 最初に metrics を宣言する

最初に計算したいmetricsの関数を呼び出して変数を用意しておきます。
二値分類のときは、task="binary"に変わるだけです。

acc   = Accuracy(task="multiclass", num_classes=num_classes, average='micro').to(device)
f1    = F1Score(task="multiclass", num_classes=num_classes, average='macro').to(device)
prec  = Precision(task="multiclass", num_classes=num_classes, average='macro').to(device)
rec   = Recall(task="multiclass", num_classes=num_classes, average='macro').to(device)
auroc = AUROC(task="multiclass", num_classes=num_classes).to(device)
cm    = ConfusionMatrix(task='multiclass', num_classes=num_classes).to(device)

また、lossの計算も行いたいので最初に定義しておきます。
合計のサンプル数は上記のmetricsの計算でも使うのでこちらも定義します。

val_loss_sum  = 0.0
total_val_samples = 0

2. バッチごとに update で部分的に足し算

dataloaderがfor文でまわるごとに指標を更新していきます。以下はvalidの例です。

for imgs, labels in val_loader:
    imgs, labels = imgs.to(device), labels.to(device)
    logits = model(imgs) # [B, num_classes]で各行が1サンプル、各列がそのクラスに対するスコア(ロジット)
    loss = criterion(logits, labels)

    probs = torch.softmax(logits, dim=1) # 行ごとに確率に正規化
    preds = torch.argmax(logits, dim=1) # 行ごとに一番大きい列のインデックスを返す
            
    # state更新
    acc.update(preds, labels)
    f1.update(preds, labels)
    prec.update(preds, labels)
    rec.update(preds, labels)
    auroc.update(probs, labels) # AUROCは確率を入力
    cm.update(preds, labels)

    # バッチごとのサンプル数
    batch_size = labels.size(0) # labels: [B]という一次元テンソル
    # バッチの平均損失とサンプル数をかけて合計損失を計算
    val_loss_sum += loss.item() * batch_size # lossは0次元テンソル→item()でPythonのfloatへ
    # サンプル数を合計
    total_val_samples += batch_size

補足

  • モデルの出力 logits(バッチサイズ B=3, クラス数=4 の場合)
logits =
tensor([[ 2.1, -1.3,  0.5,  3.0],
        [-0.8,  0.2,  1.2, -2.5],
        [ 0.1,  2.2, -1.0,  0.4]])
# 形: [3, 4]
  • torch.softmax(logits, dim=1)
    softmax は行ごとに適用する。このとき、[i, j] は「i番目のサンプルがクラスjである確率」。
probs =
tensor([[0.241, 0.010, 0.059, 0.690],
        [0.088, 0.239, 0.649, 0.024],
        [0.080, 0.689, 0.030, 0.201]])
  • torch.argmax(logits, dim=1)
    argmaxで行ごとに一番大きい列のインデックスを返す → 最も確からしいクラスの番号を取ってくる。
preds =
tensor([3, 2, 1])

3. epochの最後に compute で最終値を出す

今まで累積してきた値から、最終的な値を出します。
compute()は0次元テンソル(スカラー)を返すので、.item()でPythonのfloatに変換します。
混同行列はスカラー値ではなく、その名のとおり「行列」であり最終的に可視化することが多いのでNumpy配列にしておきます。ここでそのまま.numpy()を呼ぶとRuntimeErrorになるので.to('cpu')を挟んでCPUに移してからnumpy形式に変換します。

val_loss = val_loss_sum / max(total_val_samples, 1)
    
val_acc = acc.compute().item()
val_auc = auroc.compute().item()
val_rc = rec.compute().item()
val_prec = prec.compute().item()
val_f1 = f1.compute().item()
confmat = cm.compute().to('cpu').numpy()

全体の流れ

個人的にtrainはlossとaccぐらいしか計算しないので省略して、検証用の関数の全体像を示します。

def evaluate(model, num_classes, val_loader, criterion, device):
    model.eval()
    
    acc   = Accuracy(task="multiclass", num_classes=num_classes, average='micro').to(device)
    f1    = F1Score(task="multiclass", num_classes=num_classes, average='macro').to(device)
    prec  = Precision(task="multiclass", num_classes=num_classes, average='macro').to(device)
    rec   = Recall(task="multiclass", num_classes=num_classes, average='macro').to(device)
    auroc = AUROC(task="multiclass", num_classes=num_classes).to(device)
    cm    = ConfusionMatrix(task='multiclass', num_classes=num_classes).to(device)
    
    val_loss_sum  = 0.0
    total_val_samples = 0

    with torch.no_grad():
        for imgs, labels in val_loader:
            imgs, labels = imgs.to(device), labels.to(device)
            logits = model(imgs)
            loss = criterion(logits, labels)

            probs = torch.softmax(logits, dim=1)
            preds = torch.argmax(logits, dim=1)
            
            # state更新
            acc.update(preds, labels)
            f1.update(preds, labels)
            prec.update(preds, labels)
            rec.update(preds, labels)
            auroc.update(probs, labels)
            cm.update(preds, labels)

            batch_size = labels.size(0)
            val_loss_sum += loss.item() * batch_size
            total_val_samples += batch_size
        
    val_loss = val_loss_sum / max(total_val_samples, 1)
    
    val_acc = acc.compute().item()
    val_auc = auroc.compute().item()
    val_rc = rec.compute().item()
    val_prec = prec.compute().item()
    val_f1 = f1.compute().item()
    confmat = cm.compute().to('cpu').numpy()

    return val_loss, val_acc, val_auc, val_rc, val_prec, val_f1, confmat

torchmetricsをconfigで管理する

今までは、関数内に評価指標を書いていましたが、可読性が悪いので慣れたらconfigで管理してしまいます。

使いやすい関数として、以下のレポジトリから拝借させていただきました。
https://github.com/bakqui/ST-MEM/blob/main/util/perf_metrics.py

一発でconfigからmetricsをとってこれる関数

def build_metric_fn(config: dict) -> Tuple[torchmetrics.Metric, Dict[str, float]]:
    
    common_metric_fn_kwargs = {"task": config["task"],
                            "compute_on_cpu": config["compute_on_cpu"], # compute() 時に CPU へ移してから最終計算
                            "sync_on_compute": config["sync_on_compute"] # 分散学習用
                            }
    # タスクに応じてクラス数/ラベル数を指定(binaryの場合は設定不要)
    # 多クラス分類のときは出力[N, C]の確率スコア(softmax後)を渡すと内部でargmaxを取ってくれる
    if config["task"] == "multiclass": 
        assert "num_classes" in config, "num_classes must be provided for multiclass task"
        # クラス数を明示
        common_metric_fn_kwargs["num_classes"] = config["num_classes"]
    # 複数ラベルのときは出力[N, L]の確率スコア(sigmoid後)を渡すとthresholdで各ラベルごとに0/1に変換
    elif config["task"] == "multilabel":
        # ラベル数を明示
        assert "num_labels" in config, "num_labels must be provided for multilabel task"
        common_metric_fn_kwargs["num_labels"] = config["num_labels"]

    metric_list = []
    for metric_class_name in config["target_metrics"]:
        # 辞書
        if isinstance(metric_class_name, dict):
            # e.g., {"AUROC": {"average": macro}}
            assert len(metric_class_name) == 1, f"Invalid metric name: {metric_class_name}"
            metric_class_name, metric_fn_kwargs = list(metric_class_name.items())[0]
            metric_fn_kwargs.update(common_metric_fn_kwargs)
        # 文字列の場合はデフォルト設定
        else:
            metric_fn_kwargs = common_metric_fn_kwargs
        assert isinstance(metric_class_name, str), f"metric name must be a string: {metric_class_name}"
        assert hasattr(torchmetrics, metric_class_name), f"Invalid metric name: {metric_class_name}"
        # クラスを取得してインスタンス化
        metric_class = getattr(torchmetrics, metric_class_name)
        metric_fn = metric_class(**metric_fn_kwargs)
        metric_list.append(metric_fn)
    
    # MetricCollectionを作成
    metric_fn = torchmetrics.MetricCollection(metric_list)
    
    # best_metricsの初期化
    best_metrics = {
        k: -float("inf") if v.higher_is_better else float("inf") # 大きいほど良いなら初期値は-inf
        for k, v in metric_fn.items()
    }

    return metric_fn, best_metrics

ベスト保存ロジック

def is_best_metric(metric_class: torchmetrics.Metric,
                   prev_metric: float,
                   curr_metric: float) -> bool:
    # torchmetricsの最適化方向を利用
    higher_is_better = metric_class.higher_is_better
    if higher_is_better:
        return curr_metric > prev_metric
    else:
        return curr_metric < prev_metric

configの書き方

metric:
  task: multiclass
  compute_on_cpu: true
  sync_on_compute: false # 分散学習をするかどうか
  num_classes: 3
  target_metrics:
  - Accuracy:
      average: micro
  - F1Score:
      average: macro
  - AUROC:
      average: macro
  - Recall:
      average: macro
  - Precision:
      average: macro
  - ConfusionMatrix

この関数を使うと先ほどの3ステップは以下のように省略できます。

1. 最初に metrics を宣言する

先ほど作った関数を呼び出し、metric_fnをGPU側へ送ります。
なお、関数を呼び出すたびにbest_metricsが初期化されてしまう(±infの辞書)ので、更新用にbestを作成します。

metric_fn, best_metrics = build_metric_fn(metric_cfg)
metric_fn = metric_fn.to(device)
best = best_metrics.copy() # 更新用にbestを作る

2. バッチごとに update で部分的に足し算

まとめて update() します。

metric_fn.update(probs, labels) 

3. epochの最後に compute で最終値を出す

まとめて compute() します。

computed = metric_fn.compute()

4. (補足)ログのとり方

# 辞書にキーワードと値を格納する
out = {}
for k, v in computed.items():
    out[k] = v.item() if v.ndim == 0 else v.to('cpu').numpy() # ログ用に型変換
out["loss"] = total_loss / max(total_n, 1)

# best更新
for k, v in computed.items():
    if v.ndim != 0:
        # ConfusionMatrix は除外
        continue
    curr = v.item()
    metric_obj = metric_fn[k] # MetricCollection から該当メトリクスを取得
    prev = best[k] # これまでのベスト
    if is_best_metric(metric_obj, prev, curr):
        best[k] = curr

# loss のベスト更新
if "best_loss" not in best:
    best["best_loss"] = float("inf")
if out["loss"] < best["best_loss"]:
    best["best_loss"] = out["loss"]

metric_fn.reset()  # 再利用するなら必要

以上です。読んでいただきありがとうございました。

0
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?