The relationship between CrossEntropyLoss and NLLLoss

Posted at 2025-03-25

Introduction

A memo to myself, since I get confused about this every time.

CrossEntropyLoss

In the PyTorch implementation, the loss takes logits and targets as input.

import torch
import torch.nn as nn

g = torch.Generator().manual_seed(42)
a = 10 * torch.rand(4, 3, generator=g) # (batch, N), logits
print(a)
# tensor([[8.8227, 9.1500, 3.8286],
#         [9.5931, 3.9045, 6.0090],
#         [2.5657, 7.9364, 9.4077],
#         [1.3319, 9.3460, 5.9358]])

cross_entropy_loss = nn.CrossEntropyLoss()
target = torch.tensor([0, 1, 2, 1]) # (batch, )
loss = cross_entropy_loss(a, target)
print(loss)
# tensor(1.7082)
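Under the hood, CrossEntropyLoss applies log_softmax to the logits and then takes the negative log-probability of the target class, averaged over the batch. A minimal sketch of that computation, reusing the tensors above:

```python
import torch
import torch.nn as nn

g = torch.Generator().manual_seed(42)
a = 10 * torch.rand(4, 3, generator=g)  # (batch, N), logits
target = torch.tensor([0, 1, 2, 1])     # (batch, )

# manual cross entropy: -log_softmax at the target index, mean over the batch
log_p = torch.log_softmax(a, dim=1)
manual = -log_p[torch.arange(4), target].mean()

builtin = nn.CrossEntropyLoss()(a, target)
print(torch.allclose(manual, builtin))  # True
```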

NLLLoss

In the PyTorch implementation, the loss takes log-probabilities and targets as input.

nllloss = nn.NLLLoss()
log_p = torch.log_softmax(a, dim=1)
loss = nllloss(log_p, target)
print(loss)
# tensor(1.7082)

So cross_entropy_loss(a, target) and nllloss(torch.log_softmax(a, dim=1), target) produce the same value.
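The equivalence can be checked directly with torch.allclose:

```python
import torch
import torch.nn as nn

g = torch.Generator().manual_seed(42)
a = 10 * torch.rand(4, 3, generator=g)  # (batch, N), logits
target = torch.tensor([0, 1, 2, 1])     # (batch, )

ce = nn.CrossEntropyLoss()(a, target)
nll = nn.NLLLoss()(torch.log_softmax(a, dim=1), target)
print(torch.allclose(ce, nll))  # True
```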

Note that PyTorch has no loss function that takes probabilities and targets as input. Just use CrossEntropyLoss or NLLLoss.

Alternatively, compute the probabilities and take the logarithm yourself:

p = model(input)  # assumes the model outputs probabilities (e.g. ends with softmax)
nllloss = nn.NLLLoss()
log_p = torch.log(p)
loss = nllloss(log_p, target)
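One caveat with taking the log yourself: torch.log(torch.softmax(...)) can underflow for large logits, whereas torch.log_softmax computes the same quantity in a numerically stable way. A small illustration with extreme values:

```python
import torch

x = torch.tensor([[1000.0, 0.0]])  # extreme logits

unstable = torch.log(torch.softmax(x, dim=1))  # softmax underflows to 0, log gives -inf
stable = torch.log_softmax(x, dim=1)           # computed as x - logsumexp(x), stays finite

print(unstable)  # tensor([[0., -inf]])
print(stable)    # tensor([[0., -1000.]])
```

For this reason, prefer feeding logits to CrossEntropyLoss, or log_softmax outputs to NLLLoss, when you have the choice.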