こんにちは!いつもマルチクラス分類のときの評価指標の計算方法を忘れてしまうのでメモ書きです。
sklearn.metricsを使ってもできますが、深層学習でログを取りたいときは都度cpuに戻さず計算できるほうが楽なので今回はtorchmetricsを使います。
torchmetricsのポイント
- 最初に
metrics
を宣言する - バッチごとに
update
で部分的に足し算 - epochの最後に
compute
で最終値を出す
torchmetricsの使い方
1. 最初に metrics
を宣言する
最初に計算したいmetricsの関数を呼び出して変数を用意しておきます。
二値分類のときは、task="binary"に変わるだけです。
acc = Accuracy(task="multiclass", num_classes=num_classes, average='micro').to(device)
f1 = F1Score(task="multiclass", num_classes=num_classes, average='macro').to(device)
prec = Precision(task="multiclass", num_classes=num_classes, average='macro').to(device)
rec = Recall(task="multiclass", num_classes=num_classes, average='macro').to(device)
auroc = AUROC(task="multiclass", num_classes=num_classes).to(device)
cm = ConfusionMatrix(task='multiclass', num_classes=num_classes).to(device)
また、lossの計算も行いたいので最初に定義しておきます。
合計のサンプル数は上記のmetricsの計算でも使うのでこちらも定義します。
val_loss_sum = 0.0
total_val_samples = 0
2. バッチごとに update
で部分的に足し算
dataloaderがfor文でまわるごとに指標を更新していきます。以下はvalidの例です。
for imgs, labels in val_loader:
imgs, labels = imgs.to(device), labels.to(device)
logits = model(imgs) # [B, num_classes]で各行が1サンプル、各列がそのクラスに対するスコア(ロジット)
loss = criterion(logits, labels)
probs = torch.softmax(logits, dim=1) # 行ごとに確率に正規化
preds = torch.argmax(logits, dim=1) # 行ごとに一番大きい列のインデックスを返す
# state更新
acc.update(preds, labels)
f1.update(preds, labels)
prec.update(preds, labels)
rec.update(preds, labels)
auroc.update(probs, labels) # AUROCは確率を入力
cm.update(preds, labels)
# バッチごとのサンプル数
batch_size = labels.size(0) # labels: [B]という一次元テンソル
# バッチの平均損失とサンプル数をかけて合計損失を計算
val_loss_sum += loss.item() * batch_size # lossは0次元テンソル→item()でPythonのfloatへ
# サンプル数を合計
total_val_samples += batch_size
補足
- モデルの出力
logits
(バッチサイズ B=3, クラス数=4 の場合)
logits =
tensor([[ 2.1, -1.3, 0.5, 3.0],
[-0.8, 0.2, 1.2, -2.5],
[ 0.1, 2.2, -1.0, 0.4]])
# 形: [3, 4]
-
torch.softmax(logits, dim=1)
softmax は行ごとに適用する。このとき、[i, j] は「i番目のサンプルがクラスjである確率」。
probs =
tensor([[0.241, 0.010, 0.059, 0.690],
[0.088, 0.239, 0.649, 0.024],
[0.080, 0.689, 0.030, 0.201]])
-
torch.argmax(logits, dim=1)
argmaxで行ごとに一番大きい列のインデックスを返す → 最も確からしいクラスの番号を取ってくる。
preds =
tensor([3, 2, 1])
3. epochの最後に compute
で最終値を出す
今まで累積してきた値から、最終的な値を出します。
compute()
は0次元テンソル(スカラー)を返すので、.item()
でPythonのfloatに変換します。
混同行列はスカラー値ではなく、その名のとおり「行列」であり最終的に可視化することが多いのでNumpy配列にしておきます。ここでそのまま.numpy()
を呼ぶとRuntimeError
になるので.to('cpu')
を挟んでCPUに移してからnumpy形式に変換します。
val_loss = val_loss_sum / max(total_val_samples, 1)
val_acc = acc.compute().item()
val_auc = auroc.compute().item()
val_rc = rec.compute().item()
val_prec = prec.compute().item()
val_f1 = f1.compute().item()
confmat = cm.compute().to('cpu').numpy()
全体の流れ
個人的にtrainはlossとaccぐらいしか計算しないので省略して、検証用の関数の全体像を示します。
def evaluate(model, num_classes, val_loader, criterion, device):
model.eval()
acc = Accuracy(task="multiclass", num_classes=num_classes, average='micro').to(device)
f1 = F1Score(task="multiclass", num_classes=num_classes, average='macro').to(device)
prec = Precision(task="multiclass", num_classes=num_classes, average='macro').to(device)
rec = Recall(task="multiclass", num_classes=num_classes, average='macro').to(device)
auroc = AUROC(task="multiclass", num_classes=num_classes).to(device)
cm = ConfusionMatrix(task='multiclass', num_classes=num_classes).to(device)
val_loss_sum = 0.0
total_val_samples = 0
with torch.no_grad():
for imgs, labels in val_loader:
imgs, labels = imgs.to(device), labels.to(device)
logits = model(imgs)
loss = criterion(logits, labels)
probs = torch.softmax(logits, dim=1)
preds = torch.argmax(logits, dim=1)
# state更新
acc.update(preds, labels)
f1.update(preds, labels)
prec.update(preds, labels)
rec.update(preds, labels)
auroc.update(probs, labels)
cm.update(preds, labels)
batch_size = labels.size(0)
val_loss_sum += loss.item() * batch_size
total_val_samples += batch_size
val_loss = val_loss_sum / max(total_val_samples, 1)
val_acc = acc.compute().item()
val_auc = auroc.compute().item()
val_rc = rec.compute().item()
val_prec = prec.compute().item()
val_f1 = f1.compute().item()
confmat = cm.compute().to('cpu').numpy()
return val_loss, val_acc, val_auc, val_rc, val_prec, val_f1, confmat
torchmetricsをconfigで管理する
今までは、関数内に評価指標を書いていましたが、可読性が悪いので慣れたらconfigで管理してしまいます。
使いやすい関数として、以下のレポジトリから拝借させていただきました。
https://github.com/bakqui/ST-MEM/blob/main/util/perf_metrics.py
一発でconfigからmetricsをとってこれる関数
def build_metric_fn(config: dict) -> Tuple[torchmetrics.Metric, Dict[str, float]]:
common_metric_fn_kwargs = {"task": config["task"],
"compute_on_cpu": config["compute_on_cpu"], # compute() 時に CPU へ移してから最終計算
"sync_on_compute": config["sync_on_compute"] # 分散学習用
}
# タスクに応じてクラス数/ラベル数を指定(binaryの場合は設定不要)
# 多クラス分類のときは出力[N, C]の確率スコア(softmax後)を渡すと内部でargmaxを取ってくれる
if config["task"] == "multiclass":
assert "num_classes" in config, "num_classes must be provided for multiclass task"
# クラス数を明示
common_metric_fn_kwargs["num_classes"] = config["num_classes"]
# 複数ラベルのときは出力[N, L]の確率スコア(sigmoid後)を渡すとthresholdで各ラベルごとに0/1に変換
elif config["task"] == "multilabel":
# ラベル数を明示
assert "num_labels" in config, "num_labels must be provided for multilabel task"
common_metric_fn_kwargs["num_labels"] = config["num_labels"]
metric_list = []
for metric_class_name in config["target_metrics"]:
# 辞書
if isinstance(metric_class_name, dict):
# e.g., {"AUROC": {"average": macro}}
assert len(metric_class_name) == 1, f"Invalid metric name: {metric_class_name}"
metric_class_name, metric_fn_kwargs = list(metric_class_name.items())[0]
metric_fn_kwargs.update(common_metric_fn_kwargs)
# 文字列の場合はデフォルト設定
else:
metric_fn_kwargs = common_metric_fn_kwargs
assert isinstance(metric_class_name, str), f"metric name must be a string: {metric_class_name}"
assert hasattr(torchmetrics, metric_class_name), f"Invalid metric name: {metric_class_name}"
# クラスを取得してインスタンス化
metric_class = getattr(torchmetrics, metric_class_name)
metric_fn = metric_class(**metric_fn_kwargs)
metric_list.append(metric_fn)
# MetricCollectionを作成
metric_fn = torchmetrics.MetricCollection(metric_list)
# best_metricsの初期化
best_metrics = {
k: -float("inf") if v.higher_is_better else float("inf") # 大きいほど良いなら初期値は-inf
for k, v in metric_fn.items()
}
return metric_fn, best_metrics
ベスト保存ロジック
def is_best_metric(metric_class: torchmetrics.Metric,
prev_metric: float,
curr_metric: float) -> bool:
# torchmetricsの最適化方向を利用
higher_is_better = metric_class.higher_is_better
if higher_is_better:
return curr_metric > prev_metric
else:
return curr_metric < prev_metric
configの書き方
metric:
task: multiclass
compute_on_cpu: true
sync_on_compute: false # 分散学習をするかどうか
num_classes: 3
target_metrics:
- Accuracy:
average: micro
- F1Score:
average: macro
- AUROC:
average: macro
- Recall:
average: macro
- Precision:
average: macro
- ConfusionMatrix
この関数を使うと先ほどの3ステップは以下のように省略できます。
1. 最初に metrics
を宣言する
先ほど作った関数を呼び出し、metric_fn
をGPU側へ送ります。
なお、関数を呼び出すたびにbest_metrics
が初期化されてしまう(±infの辞書)ので、更新用にbest
を作成します。
metric_fn, best_metrics = build_metric_fn(metric_cfg)
metric_fn = metric_fn.to(device)
best = best_metrics.copy() # 更新用にbestを作る
2. バッチごとに update
で部分的に足し算
まとめて update()
します。
metric_fn.update(probs, labels)
3. epochの最後に compute
で最終値を出す
まとめて compute()
します。
computed = metric_fn.compute()
4. (補足)ログのとり方
# 辞書にキーワードと値を格納する
out = {}
for k, v in computed.items():
out[k] = v.item() if v.ndim == 0 else v.to('cpu').numpy() # ログ用に型変換
out["loss"] = total_loss / max(total_n, 1)
# best更新
for k, v in computed.items():
if v.ndim != 0:
# ConfusionMatrix は除外
continue
curr = v.item()
metric_obj = metric_fn[k] # MetricCollection から該当メトリクスを取得
prev = best[k] # これまでのベスト
if is_best_metric(metric_obj, prev, curr):
best[k] = curr
# loss のベスト更新
if "best_loss" not in best:
best["best_loss"] = float("inf")
if out["loss"] < best["best_loss"]:
best["best_loss"] = out["loss"]
metric_fn.reset() # 再利用するなら必要
以上です。読んでいただきありがとうございました。