0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

重み付き平均2乗誤差をつくる (PyTorch)

Last updated at Posted at 2024-04-29

PyTorch には MSELoss があって下図左のような平均2乗誤差を取ることができますが、下図右のように行ごとに重みをつけたいことがあると思います。MSELoss にはそのようなオプションはなさそうなので自分でつくります。

image.png

以下は実装した例です。kenzan_* という関数は numpy でも検算してみているだけです。

重みを付けない場合 (torch.nn.MSELoss そのまま)

先に MSELoss が全ての方向に平均を取っていることを確認しただけです。
自作損失で MSELoss を使用するわけでもないので確認した意味はあまりありません。

import torch
import numpy as np

def kenzan_MSE(pred, true):
    batch_size, _, _ = pred.size()
    li_mse = []
    for i in range(batch_size):
        pred_ = pred[i].numpy()
        true_ = true[i].numpy()
        se = (pred_ - true_) * (pred_ - true_)
        mse = se.mean()
        print(f'バッチ内の{i}サンプル目の2乗誤差')
        print(se)
        print('MSE', mse)
        li_mse.append(mse)
    return print('バッチ内の全サンプルのMSEの平均', np.mean(li_mse))

criterion = torch.nn.MSELoss()
pred = torch.tensor([
    [[ 0.,  1.,  2.], [ 0.,  1.,  2.]],
    [[ 0.,  1.,  2.], [ 0.,  1.,  2.]],
    [[ 0.,  1.,  2.], [ 0.,  1.,  2.]],
    [[ 0.,  1.,  2.], [ 0.,  1.,  2.]],
])
true = torch.tensor([
    [[ 0.,  1.,  4.], [ 0.,  1.,  4.]],
    [[ 0.,  1.,  4.], [ 0.,  1.,  5.]],
    [[ 0.,  1.,  2.], [ 0.,  1.,  6.]],
    [[ 0.,  1.,  3.], [ 0.,  1.,  7.]],
])
print(criterion(pred, true))
kenzan_MSE(pred, true)
tensor(2.6250)

バッチ内の0サンプル目の2乗誤差
[[0. 0. 4.]
 [0. 0. 4.]]
MSE 1.3333334
バッチ内の1サンプル目の2乗誤差
[[0. 0. 4.]
 [0. 0. 9.]]
MSE 2.1666667
バッチ内の2サンプル目の2乗誤差
[[ 0.  0.  0.]
 [ 0.  0. 16.]]
MSE 2.6666667
バッチ内の3サンプル目の2乗誤差
[[ 0.  0.  1.]
 [ 0.  0. 25.]]
MSE 4.3333335
バッチ内の全サンプルのMSEの平均 2.625

重みを付ける場合 (カスタム)

行ごとに重みを適用したいようなとき torch.einsum() が使用できるので使用します。
想定通りの計算結果になります。

class MSE_decay(torch.nn.Module):
    def __init__(self, d):
        super(MSE_decay, self).__init__()
        lambda_ = 1.0 / d
        self.decay = torch.exp(-1.0 * lambda_ * torch.arange(1, d + 1, dtype=torch.float))
    def forward(self, pred, true):
        return torch.mean(
            torch.einsum('j,ijk->ijk', (self.decay, (pred - true) * (pred - true))))

def kenzan_MSE_decay(pred, true):
    batch_size, d, _ = pred.size()
    li_mse = []
    for i in range(batch_size):
        for j in range(d):
            coef = np.exp(- (j + 1.) / d)
            pred_ = pred[i, j, :].numpy()
            true_ = true[i, j, :].numpy()
            se = (pred_ - true_) * (pred_ - true_)
            mse = se.mean()
            print(f'バッチ内の{i}サンプル目の{j}行目の2乗誤差')
            print(se, mse, '-->', coef * mse)
            li_mse.append(coef * mse)
    return print('バッチ内の全サンプルのMSEの平均', np.mean(li_mse))

criterion = MSE_decay(d=2)
print(criterion(pred, true))
kenzan_MSE_decay(pred, true)
tensor(1.0552)

バッチ内の0サンプル目の0行目の2乗誤差
[0. 0. 4.] 1.3333334 --> 0.8087075703848743
バッチ内の0サンプル目の1行目の2乗誤差
[0. 0. 4.] 1.3333334 --> 0.4905059361801387
バッチ内の1サンプル目の0行目の2乗誤差
[0. 0. 4.] 1.3333334 --> 0.8087075703848743
バッチ内の1サンプル目の1行目の2乗誤差
[0. 0. 9.] 3.0 --> 1.103638323514327
バッチ内の2サンプル目の0行目の2乗誤差
[0. 0. 0.] 0.0 --> 0.0
バッチ内の2サンプル目の1行目の2乗誤差
[ 0.  0. 16.] 5.3333335 --> 1.9620237447205549
バッチ内の3サンプル目の0行目の2乗誤差
[0. 0. 1.] 0.33333334 --> 0.20217689259621857
バッチ内の3サンプル目の1行目の2乗誤差
[ 0.  0. 25.] 8.333333 --> 3.0656618928162946
バッチ内の全サンプルのMSEの平均 1.0551777413246604
0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?