はじめに
TorchMetricsとはPyTorchやPyTorch Lightningでサポートされているメトリクス算出用の抽象クラスである。
たとえばPytorchだとこんな感じにとても簡単にかける。
import torch
# import our library
import torchmetrics
# initialize metric
metric = torchmetrics.Accuracy()
n_batches = 10
for i in range(n_batches):
# simulate a classification problem
preds = torch.randn(10, 5).softmax(dim=-1)
target = torch.randint(5, (10,))
# metric on current batch
acc = metric(preds, target)
print(f"Accuracy on batch {i}: {acc}")
# metric on all batches using custom accumulation
acc = metric.compute()
print(f"Accuracy on all data: {acc}")
# Reseting internal state such that metric ready for new data
metric.reset()
カスタムメトリクスの作成
torchmetricsモジュールに基本的なライブラリは登録されているがどうしても自分で作りたい時がある。
その際はMetricクラスを継承すると同じようにかける。
メトリクスの算出に必要な変数はコンストラクタ内でadd_state
メソッドを使って定義する。
メトリクスの変数の更新はupdate
で実施する。updateはmetric(pred,target)
で実行されているメソッドである。つまり、引数にはバッチサイズ分のモデル出力値と教師データが入る。
メトリクスの算出はcompute
メソッドで実施する。これは毎エポック実行し、ログを出力するために使用される。
class CustomMetric(Metric):
def __init__(self, dist_sync_on_step = False):
#親コンストラクタの呼び出しとフィールドの定義
super().__init__(dist_sync_on_step=dist_sync_on_step)
self.add_state("correct", default=torch.tensor(0), dist_reduce_fx="sum")
self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")
def update(self, preds: torch.Tensor, target: torch.Tensor):
# metricの更新(metric(preds, target)で呼び出される。
preds, target = self._input_format(preds, target)
assert preds.shape == target.shape
self.correct += torch.sum(preds == target)
self.total += target.numel()
def compute(self):
return self.correct.float() / self.total
def _calc_score(self,col:np.ndarray):
s1 = col[0].tolist()
s2 = col[1].tolist()
score = textdistance.levenshtein.distance(s1,s2)/max(len(s1),len(s2))
return score
実例
研究で使ったレーベンシュタイン距離を算出するクラス
from torchmetrics import Metric
import torch
import numpy as np
import textdistance
class Leivensitein(Metric):
"""
Leivensitein
"""
def __init__(self, dist_sync_on_step = False):
super().__init__(dist_sync_on_step=dist_sync_on_step)
self.add_state("correct", default=torch.tensor(0), dist_reduce_fx="sum")
self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")
def update(self, preds: torch.Tensor, target: torch.Tensor):
assert preds.shape == target.shape
catted = torch.stack([preds,target]).numpy().astype(int)
stack = 0
for i in range(catted.shape[2]):# batch size
stack+=( self._calc_score(catted[:,:,i]))
self.correct += torch.tensor(stack, dtype=torch.int)
self.total += preds.shape[1]
def compute(self):
return self.correct.float() / self.total
def _calc_score(self,col:np.ndarray):
s1 = col[0].tolist()
s2 = col[1].tolist()
score = textdistance.levenshtein.distance(s1,s2)/max(len(s1),len(s2))
return score
終わりに
便利な割に日本語の解説記事が少なかったので書きましたー
参考になれば嬉しいです!