はじめに
シングルラベルタスクにおける損失関数としてよく用いられる交差エントロピーですが、PyTorch
の実装クラスの挙動がややわかりにくいものだったため苦労しました。備忘録を兼ねて記事を書きます。
交差エントロピーとは
主にシングルクラスラベルで用いられる損失関数で、以下の数式で定義されます(PyTorchの公式リファレンス)。
loss(x, class) = -log(\frac{exp(x[class])}{\sum_j exp(x[j])}) = -x[class] + log(\sum_j exp(x[j]))
交差エントロピー自体の説明については、この記事がわかりやすいです。GLUEタスクに代表されるようなシングルラベルタスクにおいては、この交差エントロピー関数を用いると損失が上手に収束して高い精度が達成されます。
PyTorch
のCrossEntropyLoss
クラスについて
PyTorch
には、nn
モジュールの中に交差エントロピーの損失関数が用意されています。PyTorchの公式リファレンスによると、使い方は以下の通りです。
>>> from torch import nn
>>> import torch
>>>
>>> loss = nn.CrossEntropyLoss()
>>> input = torch.randn(3, 5, requires_grad=True)
>>> target = torch.empty(3, dtype=torch.long).random_(5)
>>> output = loss(input, target)
>>> output.backward()
ターゲット変数の型について
上述のコードにおいて、入力変数であるinput
の型はLogits
をシグモイド関数に通したものであるため、float
型と考えて問題ないかと思います。他方、ターゲット変数であるtarget
の型はint
型で、変数のインデックスが1.0
になっているone-hotベクトルがターゲット変数として設定されるようです。つまり、torch.tensor(0)
はtorch.tensor([1., 0.])
と扱われるようです。
>>> from torch import nn
>>> import torch
>>>
>>> loss = nn.CrossEntropyLoss()
>>> input = torch.tensor([0.8, 0.1, 0.1]).unsqueeze(0)
>>> target = torch.tensor(0).unsqueeze(0)
>>> output = loss(input, target)
>>> output
tensor(0.6897)
CrossEntropyLoss
クラスの内部で、整数インデックスからone-hotベクトルが生成されているとは思いもよらなかったため、Huggingface inc.のコードを解読している時にハマってしまいました。
他方、マルチラベルタスクで用いられるBCEWithLogitsLoss
というクラスでは、ターゲット変数にfloat
型のmulti-hot型ベクトルを入力する必要があるようです。