LoginSignup
9
4

More than 1 year has passed since last update.

PyTorchのCrossEntropyLossクラスについて

Last updated at Posted at 2021-05-05

はじめに

シングルラベルタスクにおける損失関数としてよく用いられる交差エントロピーですが、PyTorchの実装クラスの挙動がややわかりにくいものだったため苦労しました。備忘録を兼ねて記事を書きます。

交差エントロピーとは

主にシングルクラスラベルで用いられる損失関数で、以下の数式で定義されます(PyTorchの公式リファレンス)。

loss(x, class) = -log(\frac{exp(x[class])}{\sum_j exp(x[j])}) = -x[class] + log(\sum_j exp(x[j]))

交差エントロピー自体の説明については、この記事がわかりやすいです。GLUEタスクに代表されるようなシングルラベルタスクにおいては、この交差エントロピー関数を用いると損失が上手に収束して高い精度が達成されます。

PyTorchCrossEntropyLossクラスについて

PyTorchには、nnモジュールの中に交差エントロピーの損失関数が用意されています。PyTorchの公式リファレンスによると、使い方は以下の通りです。

PyTorchの公式リファレンスより
>>> from torch import nn
>>> import torch
>>>
>>> loss = nn.CrossEntropyLoss()
>>> input = torch.randn(3, 5, requires_grad=True)
>>> target = torch.empty(3, dtype=torch.long).random_(5)
>>> output = loss(input, target)
>>> output.backward()

ターゲット変数の型について

上述のコードにおいて、入力変数であるinputの型はLogitsをシグモイド関数に通したものであるため、float型と考えて問題ないかと思います。他方、ターゲット変数であるtargetの型はint型で、変数のインデックスが1.0になっているone-hotベクトルがターゲット変数として設定されるようです。つまり、torch.tensor(0)torch.tensor([1., 0.])と扱われるようです。

>>> from torch import nn
>>> import torch
>>>
>>> loss = nn.CrossEntropyLoss()
>>> input = torch.tensor([0.8, 0.1, 0.1]).unsqueeze(0)
>>> target = torch.tensor(0).unsqueeze(0)
>>> output = loss(input, target)
>>> output
tensor(0.6897)

CrossEntropyLossクラスの内部で、整数インデックスからone-hotベクトルが生成されているとは思いもよらなかったため、Huggingface inc.のコードを解読している時にハマってしまいました。
他方、マルチラベルタスクで用いられるBCEWithLogitsLossというクラスでは、ターゲット変数にfloat型のmulti-hot型ベクトルを入力する必要があるようです。

参考文献

9
4
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
9
4