7
5

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

TorchMetricsでのカスタムメトリクスの実装

Posted at

はじめに

TorchMetricsとはPyTorchやPyTorch Lightningでサポートされているメトリクス算出用の抽象クラスである。

たとえばPytorchだとこんな感じにとても簡単にかける。

import torch
# import our library
import torchmetrics

# initialize metric
metric = torchmetrics.Accuracy()

n_batches = 10
for i in range(n_batches):
    # simulate a classification problem
    preds = torch.randn(10, 5).softmax(dim=-1)
    target = torch.randint(5, (10,))
    # metric on current batch
    acc = metric(preds, target)
    print(f"Accuracy on batch {i}: {acc}")

# metric on all batches using custom accumulation
acc = metric.compute()
print(f"Accuracy on all data: {acc}")

# Reseting internal state such that metric ready for new data
metric.reset()

カスタムメトリクスの作成

torchmetricsモジュールに基本的なライブラリは登録されているがどうしても自分で作りたい時がある。
その際はMetricクラスを継承すると同じようにかける。

メトリクスの算出に必要な変数はコンストラクタ内でadd_stateメソッドを使って定義する。

メトリクスの変数の更新はupdateで実施する。updateはmetric(pred,target)で実行されているメソッドである。つまり、引数にはバッチサイズ分のモデル出力値と教師データが入る。

メトリクスの算出はcomputeメソッドで実施する。これは毎エポック実行し、ログを出力するために使用される。

class CustomMetric(Metric):
    def __init__(self, dist_sync_on_step = False):
        #親コンストラクタの呼び出しとフィールドの定義
        super().__init__(dist_sync_on_step=dist_sync_on_step) 
        self.add_state("correct", default=torch.tensor(0), dist_reduce_fx="sum")
        self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")
        
    def update(self, preds: torch.Tensor, target: torch.Tensor):
     # metricの更新(metric(preds, target)で呼び出される。
        preds, target = self._input_format(preds, target)
        assert preds.shape == target.shape
        self.correct += torch.sum(preds == target)
        self.total += target.numel()
           
    def compute(self):
        return self.correct.float() / self.total
   
    def _calc_score(self,col:np.ndarray):
        s1 = col[0].tolist()
        s2 = col[1].tolist()
        score = textdistance.levenshtein.distance(s1,s2)/max(len(s1),len(s2))
        return score

実例

研究で使ったレーベンシュタイン距離を算出するクラス

from torchmetrics import Metric
import torch
import numpy as np
import textdistance

class Leivensitein(Metric):
    """
    Leivensitein
    """
    def __init__(self, dist_sync_on_step = False):
        super().__init__(dist_sync_on_step=dist_sync_on_step)
        self.add_state("correct", default=torch.tensor(0), dist_reduce_fx="sum")
        self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")
        
    def update(self, preds: torch.Tensor, target: torch.Tensor):
        assert preds.shape == target.shape
        catted = torch.stack([preds,target]).numpy().astype(int)
        stack = 0
        for i in range(catted.shape[2]):# batch size
            stack+=( self._calc_score(catted[:,:,i]))         
        self.correct += torch.tensor(stack, dtype=torch.int)
        self.total += preds.shape[1]
           
    def compute(self):
        return self.correct.float() / self.total
   
    def _calc_score(self,col:np.ndarray):
        s1 = col[0].tolist()
        s2 = col[1].tolist()
        score = textdistance.levenshtein.distance(s1,s2)/max(len(s1),len(s2))
        return score

終わりに

便利な割に日本語の解説記事が少なかったので書きましたー
参考になれば嬉しいです!

7
5
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
7
5

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?