はじめに
毎回混乱するので備忘録です.
CrossEntropyLoss
pytorch実装では logit とターゲットを入力します.
import torch
import torch.nn as nn
g = torch.Generator().manual_seed(42)
a = 10 * torch.rand(4, 3, generator=g) # (batch, N), logits
print(a)
# tensor([[8.8227, 9.1500, 3.8286],
# [9.5931, 3.9045, 6.0090],
# [2.5657, 7.9364, 9.4077],
# [1.3319, 9.3460, 5.9358]])
cross_entropy_loss = nn.CrossEntropyLoss()
target = torch.tensor([0, 1, 2, 1]) # (batch, )
loss = cross_entropy_loss(a, target)
print(loss)
# tensor([1.7082])
NLLLoss
pytorch実装では確率の対数尤度とターゲットを入力します.
nllloss = nn.NLLLoss()
log_p = torch.log_softmax(a, dim=1)
loss = nllloss(log_p, target)
print(loss)
# tensor([1.7082])
cross_entropy_loss(a, target)
と nllloss(torch.log_softmax(a), target)
が一致します.
なお,確率とターゲットを入力にできる損失関数はpytorchではありません.おとなしくCrossEntropyLoss か NLLLoss を使いましょう.
あるいは確率を計算した後に自分で対数をとりましょう.
p = model(input)
nllloss = nn.NLLLoss()
log_p = torch.log(p)
loss = nllloss(log_p, target)
print(loss)
# tensor([1.7082])