25
14

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

[PyTorch] NLLLoss と CrossEntropyLoss の違い

Last updated at Posted at 2021-10-20

この記事で説明すること

PyTorchのチュートリアルなどで,torch.nn.NLLLoss を交差エントロピーを計算するために使っている場面を見かけます.
私は初めて見た時,なぜ torch.nn.CrossEntropyLoss を使っていないのか疑問に感じました(こっちの方が関数名で何をするか想像しやすいし...).
この記事では,この NLLLoss がどういう計算をしているのか,CrossEntropyLoss とどう違うのかについて説明します.

LogSoftmax + NLLLossCrossEntropyLoss は同様の計算を行う

結論を述べると,torch.nn.LogSoftmax に続いて torch.nn.NLLLoss を適用するのと,torch.nn.CrossEntropyLoss を適用するのは同じことです.
以下のコードで実際に確認してみます.

LogSoftmax+NLLLossとCrossEntropyLossの違い
import torch
import torch.nn as nn
torch.manual_seed(7777)

ce_loss = nn.CrossEntropyLoss()
logsoftmax = nn.LogSoftmax(dim=1)
nll_loss = nn.NLLLoss()

# モデルの最終的な全結合層の出力(という想定)
x = torch.randn(2, 5)
# 正解ラベルのリスト(という想定)
target = torch.tensor([2, 1])

# CrossEntropyLoss単体
print(ce_loss(x, target))
# LogSoftmax + NLLLoss
z =  logsoftmax(x)
print(nll_loss(z, target))

# 出力
# tensor(1.7886)
# tensor(1.7886)

LogSoftmax + NLLLoss の動作説明

LogSoftmax

torch.nn.LogSoftmax は,入力されたテンソル $x$ に対してsoftmaxの計算をしてからlogを取るだけです.
以下,torch.nn.LogSoftmax の公式ドキュメントから式を引用します.

\text{LogSoftmax}(x_i) = \log\left( \frac{\exp(x_i)}{\sum_j \exp(x_j)} \right)

ここで注意すべきは,**softmaxの式中の総和の部分をどの軸について取るか?**を指定する必要があることです.
これは引数 dim で指定します.同じテンソルでも,どの軸で総和を取るかによって,結果が異なってしまいます.その例を以下に示します.
以下を見ると,引数 dim が違うだけで計算結果が異なることがわかります.

LogSoftmaxはdimに注意
x = torch.randn(2, 5)
z0 = nn.LogSoftmax(dim=0)(x)
z1 = nn.LogSoftmax(dim=1)(x)
print(z0)
print(z1)

# 出力
# tensor([[-0.0889, -1.0106, -0.4296, -0.7947, -0.8084],
#         [-2.4647, -0.4526, -1.0521, -0.6009, -0.5898]])
# tensor([[-0.2800, -2.7624, -2.4444, -2.7047, -3.5961],
#         [-1.5842, -1.1328, -1.9953, -1.4392, -2.3059]])

  • 上の例では,x は (2,5) という形状のテンソルです.
  • z0 は0番目の軸について総和を取るため,ソフトマックスの式中の総和は2つの和となります.
    • 例えば,z0[0][0] を計算する際,総和の部分 = exp(x[0][0])+exp(x[1][0]) となります.
  • z1 は1番目の軸について総和を取るため,ソフトマックスの式中の総和は5つの和となります.
    • 例えば,z1[0][0] を計算する際,総和の部分 = exp(x[0][0])+exp(x[0][1])+...+exp(x[0][4]) となります.

実際に分類問題などで用いる場合

ニューラルネットワークを $C$ クラスの分類問題などで用いる際,ミニバッチサイズ $N$ として,この入力のテンソル $x$ が $(N,C)$ という形状の場合は,dim=1 とする必要があります.
ミニバッチに含まれる各サンプルに対する計算結果(長さ $C$ のリスト)について,LogSoftmaxの総和の部分を計算する必要があるためです.

NLLLoss

torch.nn.NLLLossの公式ドキュメントに基づいて説明します.
まず,NLLLoss は Negative Log-Likelihood Loss を表すそうです.
しかし,実態を見ると,Log-Likelihood(対数尤度)の計算は特に担っておらず,基本的に 'Negative' の部分しか担っていないことがわかりました.

reduction='mean' 時の動作

オプションについて:

NLLLoss の引数 reduction によって計算する内容が異なります.
デフォルトは reduction='mean' なので,この時の動作について説明します.
また,引数 weight を設定することで,ロス計算時に各クラスに対する重みを設定できます(不均衡データなどの場合に有用らしい)が,デフォルトでは weight=None となっているのでこの点についても省略します.

入力:
最初のコードで示したように,NLLLoss には2つのテンソル $x,y$ を入力します.

  • $x$ : 形状 $(N, C)$ の2次元のテンソル
  • $y$ : 形状 $(N,)$ の1次元のテンソル

$x$ の方は,ミニバッチの各サンプルについて,各クラスに所属する確率のlogをとったものが格納された形状 $(N,C)$ のテンソルが入力されることが想定されています.
確率に変換し,logを取るという操作は LogSoftmax によって行えます.よって,この関数はLogSoftmax の後に適用することが想定されています.
一方で,$y$ の方は,ミニバッチの各サンプルについて,正解ラベルが $[0, C-1]$ の整数で格納された形状 $(N,)$ のテンソルが入力されることが想定されています.

出力:
そして,以下の値 $\ell (x,y)$ を計算して返します.

\ell (x,y) = \frac{1}{N} \sum_{n=1}^N (-x_{n, y_n})

上式で使われている文字の説明は以下の通りです:

  • $N$ : ミニバッチのサイズ
  • $y_n$ : ミニバッチの $n$ 番目のサンプルの正解ラベル($[0, C-1]$ の整数)
  • $x_{n, y_n}$ : ミニバッチの $n$ 番目のサンプルに対する,クラス $y_n$(=正解のクラス)の予測値

つまり,ミニバッチの各サンプルについて,正解のラベルへの予測値にマイナスつけたものを平均しているだけです.
これが,基本的に 'Negative' の部分しか担っていない,という意味でした.

reduction'sum''none' の場合の動作については,公式ドキュメントを見てください.
しかし,この 'mean' の場合の動作が大体理解できれば他の場合も理解しやすいと思います.

計算例
以下に NLLLoss の計算例を示します.
ミニバッチサイズ $N=2$ ,クラス数 $C=5$ の場合です.
$\frac{1}{2} (-x_{0,4}-x_{1,1}) = \frac{1}{2}(-0.5-0.1) = -0.3$ と計算されています.

NLLossの計算例
nll_loss = nn.NLLLoss()
xx = torch.tensor([ [0.1, 0.1, 0.1, 0.2, 0.5],
                    [0.1, 0.1, 0.1, 0.2, 0.5]
                  ])
yy = torch.tensor([4, 1])
nll_loss(xx, yy)

# 出力
# tensor(-0.3000)

なぜ LogSoftmaxNLLLossCrossEntropyLoss と同じになるのか

torch.nn.CrossEntropyLoss の公式ドキュメントを見ると,以下の計算式によって計算した値を返す旨が書かれています.
この式は交差エントロピーの定義を少し変形すればわかります.

\text{loss}(x, class) = -\log{\left( \frac{\exp({x[class]})}{\sum_j \exp(x[j])} \right)}

この式の右辺は,LogSoftmax で計算する式と似ています.違うのは,

  • 先頭にマイナスがついている.
  • $x$ の全ての要素ではなく,正解ラベル$class$に対応する要素だけを取り出している.

上記の2つの操作はちょうど NLLLoss を使うことで実現できます.
そのため,LogSoftmaxNLLLoss を使うと CrossEntropyLoss と同じになります.

実際の利用時

以上のことから,

  • 訓練時に NLLLoss を利用する場合は,モデルの最後の全結合層の後に,LogSoftmax の層を入れておく必要があります.
  • 訓練時に CrossEntropyLoss を利用する場合は,そのような追加の層を入れる必要はありません.

参考

[1] torch.nn.LogSoftmax の公式ドキュメント
[2] torch.nn.NLLLossの公式ドキュメント
[3] torch.nn.CrossEntropyLoss の公式ドキュメント
[4] 交差エントロピー - Wikipedia
[5] pytorch の NLLLoss の挙動 - メモ

25
14
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
25
14

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?