PyTorch には MSELoss があって下図左のような平均2乗誤差を取ることができますが、下図右のように行ごとに重みをつけたいことがあると思います。MSELoss にはそのようなオプションはなさそうなので自分でつくります。
以下は実装した例です。kenzan_*
という関数は numpy でも検算してみているだけです。
重みを付けない場合 (torch.nn.MSELoss そのまま)
先に MSELoss が全ての方向に平均を取っていることを確認しただけです。
自作損失で MSELoss を使用するわけでもないので確認した意味はあまりありません。
import torch
import numpy as np
def kenzan_MSE(pred, true):
batch_size, _, _ = pred.size()
li_mse = []
for i in range(batch_size):
pred_ = pred[i].numpy()
true_ = true[i].numpy()
se = (pred_ - true_) * (pred_ - true_)
mse = se.mean()
print(f'バッチ内の{i}サンプル目の2乗誤差')
print(se)
print('MSE', mse)
li_mse.append(mse)
return print('バッチ内の全サンプルのMSEの平均', np.mean(li_mse))
criterion = torch.nn.MSELoss()
pred = torch.tensor([
[[ 0., 1., 2.], [ 0., 1., 2.]],
[[ 0., 1., 2.], [ 0., 1., 2.]],
[[ 0., 1., 2.], [ 0., 1., 2.]],
[[ 0., 1., 2.], [ 0., 1., 2.]],
])
true = torch.tensor([
[[ 0., 1., 4.], [ 0., 1., 4.]],
[[ 0., 1., 4.], [ 0., 1., 5.]],
[[ 0., 1., 2.], [ 0., 1., 6.]],
[[ 0., 1., 3.], [ 0., 1., 7.]],
])
print(criterion(pred, true))
kenzan_MSE(pred, true)
tensor(2.6250)
バッチ内の0サンプル目の2乗誤差
[[0. 0. 4.]
[0. 0. 4.]]
MSE 1.3333334
バッチ内の1サンプル目の2乗誤差
[[0. 0. 4.]
[0. 0. 9.]]
MSE 2.1666667
バッチ内の2サンプル目の2乗誤差
[[ 0. 0. 0.]
[ 0. 0. 16.]]
MSE 2.6666667
バッチ内の3サンプル目の2乗誤差
[[ 0. 0. 1.]
[ 0. 0. 25.]]
MSE 4.3333335
バッチ内の全サンプルのMSEの平均 2.625
重みを付ける場合 (カスタム)
行ごとに重みを適用したいようなとき torch.einsum()
が使用できるので使用します。
想定通りの計算結果になります。
class MSE_decay(torch.nn.Module):
def __init__(self, d):
super(MSE_decay, self).__init__()
lambda_ = 1.0 / d
self.decay = torch.exp(-1.0 * lambda_ * torch.arange(1, d + 1, dtype=torch.float))
def forward(self, pred, true):
return torch.mean(
torch.einsum('j,ijk->ijk', (self.decay, (pred - true) * (pred - true))))
def kenzan_MSE_decay(pred, true):
batch_size, d, _ = pred.size()
li_mse = []
for i in range(batch_size):
for j in range(d):
coef = np.exp(- (j + 1.) / d)
pred_ = pred[i, j, :].numpy()
true_ = true[i, j, :].numpy()
se = (pred_ - true_) * (pred_ - true_)
mse = se.mean()
print(f'バッチ内の{i}サンプル目の{j}行目の2乗誤差')
print(se, mse, '-->', coef * mse)
li_mse.append(coef * mse)
return print('バッチ内の全サンプルのMSEの平均', np.mean(li_mse))
criterion = MSE_decay(d=2)
print(criterion(pred, true))
kenzan_MSE_decay(pred, true)
tensor(1.0552)
バッチ内の0サンプル目の0行目の2乗誤差
[0. 0. 4.] 1.3333334 --> 0.8087075703848743
バッチ内の0サンプル目の1行目の2乗誤差
[0. 0. 4.] 1.3333334 --> 0.4905059361801387
バッチ内の1サンプル目の0行目の2乗誤差
[0. 0. 4.] 1.3333334 --> 0.8087075703848743
バッチ内の1サンプル目の1行目の2乗誤差
[0. 0. 9.] 3.0 --> 1.103638323514327
バッチ内の2サンプル目の0行目の2乗誤差
[0. 0. 0.] 0.0 --> 0.0
バッチ内の2サンプル目の1行目の2乗誤差
[ 0. 0. 16.] 5.3333335 --> 1.9620237447205549
バッチ内の3サンプル目の0行目の2乗誤差
[0. 0. 1.] 0.33333334 --> 0.20217689259621857
バッチ内の3サンプル目の1行目の2乗誤差
[ 0. 0. 25.] 8.333333 --> 3.0656618928162946
バッチ内の全サンプルのMSEの平均 1.0551777413246604