この記事で説明すること
PyTorchのチュートリアルなどで,torch.nn.NLLLoss
を交差エントロピーを計算するために使っている場面を見かけます.
私は初めて見た時,なぜ torch.nn.CrossEntropyLoss
を使っていないのか疑問に感じました(こっちの方が関数名で何をするか想像しやすいし...).
この記事では,この NLLLoss
がどういう計算をしているのか,CrossEntropyLoss
とどう違うのかについて説明します.
LogSoftmax
+ NLLLoss
と CrossEntropyLoss
は同様の計算を行う
結論を述べると,torch.nn.LogSoftmax
に続いて torch.nn.NLLLoss
を適用するのと,torch.nn.CrossEntropyLoss
を適用するのは同じことです.
以下のコードで実際に確認してみます.
import torch
import torch.nn as nn
torch.manual_seed(7777)
ce_loss = nn.CrossEntropyLoss()
logsoftmax = nn.LogSoftmax(dim=1)
nll_loss = nn.NLLLoss()
# モデルの最終的な全結合層の出力(という想定)
x = torch.randn(2, 5)
# 正解ラベルのリスト(という想定)
target = torch.tensor([2, 1])
# CrossEntropyLoss単体
print(ce_loss(x, target))
# LogSoftmax + NLLLoss
z = logsoftmax(x)
print(nll_loss(z, target))
# 出力
# tensor(1.7886)
# tensor(1.7886)
LogSoftmax
+ NLLLoss
の動作説明
LogSoftmax
torch.nn.LogSoftmax
は,入力されたテンソル $x$ に対してsoftmaxの計算をしてからlogを取るだけです.
以下,torch.nn.LogSoftmax
の公式ドキュメントから式を引用します.
\text{LogSoftmax}(x_i) = \log\left( \frac{\exp(x_i)}{\sum_j \exp(x_j)} \right)
ここで注意すべきは,**softmaxの式中の総和の部分をどの軸について取るか?**を指定する必要があることです.
これは引数 dim
で指定します.同じテンソルでも,どの軸で総和を取るかによって,結果が異なってしまいます.その例を以下に示します.
以下を見ると,引数 dim
が違うだけで計算結果が異なることがわかります.
x = torch.randn(2, 5)
z0 = nn.LogSoftmax(dim=0)(x)
z1 = nn.LogSoftmax(dim=1)(x)
print(z0)
print(z1)
# 出力
# tensor([[-0.0889, -1.0106, -0.4296, -0.7947, -0.8084],
# [-2.4647, -0.4526, -1.0521, -0.6009, -0.5898]])
# tensor([[-0.2800, -2.7624, -2.4444, -2.7047, -3.5961],
# [-1.5842, -1.1328, -1.9953, -1.4392, -2.3059]])
- 上の例では,
x
は (2,5) という形状のテンソルです. -
z0
は0番目の軸について総和を取るため,ソフトマックスの式中の総和は2つの和となります.- 例えば,
z0[0][0]
を計算する際,総和の部分 =exp(x[0][0])+exp(x[1][0])
となります.
- 例えば,
-
z1
は1番目の軸について総和を取るため,ソフトマックスの式中の総和は5つの和となります.- 例えば,
z1[0][0]
を計算する際,総和の部分 =exp(x[0][0])+exp(x[0][1])+...+exp(x[0][4])
となります.
- 例えば,
実際に分類問題などで用いる場合
ニューラルネットワークを $C$ クラスの分類問題などで用いる際,ミニバッチサイズ $N$ として,この入力のテンソル $x$ が $(N,C)$ という形状の場合は,dim=1
とする必要があります.
ミニバッチに含まれる各サンプルに対する計算結果(長さ $C$ のリスト)について,LogSoftmaxの総和の部分を計算する必要があるためです.
NLLLoss
torch.nn.NLLLoss
の公式ドキュメントに基づいて説明します.
まず,NLLLoss
は Negative Log-Likelihood Loss を表すそうです.
しかし,実態を見ると,Log-Likelihood(対数尤度)の計算は特に担っておらず,基本的に 'Negative' の部分しか担っていないことがわかりました.
reduction='mean'
時の動作
オプションについて:
NLLLoss
の引数 reduction
によって計算する内容が異なります.
デフォルトは reduction='mean'
なので,この時の動作について説明します.
また,引数 weight
を設定することで,ロス計算時に各クラスに対する重みを設定できます(不均衡データなどの場合に有用らしい)が,デフォルトでは weight=None
となっているのでこの点についても省略します.
入力:
最初のコードで示したように,NLLLoss
には2つのテンソル $x,y$ を入力します.
- $x$ : 形状 $(N, C)$ の2次元のテンソル
- $y$ : 形状 $(N,)$ の1次元のテンソル
$x$ の方は,ミニバッチの各サンプルについて,各クラスに所属する確率のlogをとったものが格納された形状 $(N,C)$ のテンソルが入力されることが想定されています.
確率に変換し,logを取るという操作は LogSoftmax
によって行えます.よって,この関数はLogSoftmax
の後に適用することが想定されています.
一方で,$y$ の方は,ミニバッチの各サンプルについて,正解ラベルが $[0, C-1]$ の整数で格納された形状 $(N,)$ のテンソルが入力されることが想定されています.
出力:
そして,以下の値 $\ell (x,y)$ を計算して返します.
\ell (x,y) = \frac{1}{N} \sum_{n=1}^N (-x_{n, y_n})
上式で使われている文字の説明は以下の通りです:
- $N$ : ミニバッチのサイズ
- $y_n$ : ミニバッチの $n$ 番目のサンプルの正解ラベル($[0, C-1]$ の整数)
- $x_{n, y_n}$ : ミニバッチの $n$ 番目のサンプルに対する,クラス $y_n$(=正解のクラス)の予測値
つまり,ミニバッチの各サンプルについて,正解のラベルへの予測値にマイナスつけたものを平均しているだけです.
これが,基本的に 'Negative' の部分しか担っていない,という意味でした.
reduction
が 'sum'
や'none'
の場合の動作については,公式ドキュメントを見てください.
しかし,この 'mean'
の場合の動作が大体理解できれば他の場合も理解しやすいと思います.
計算例
以下に NLLLoss
の計算例を示します.
ミニバッチサイズ $N=2$ ,クラス数 $C=5$ の場合です.
$\frac{1}{2} (-x_{0,4}-x_{1,1}) = \frac{1}{2}(-0.5-0.1) = -0.3$ と計算されています.
nll_loss = nn.NLLLoss()
xx = torch.tensor([ [0.1, 0.1, 0.1, 0.2, 0.5],
[0.1, 0.1, 0.1, 0.2, 0.5]
])
yy = torch.tensor([4, 1])
nll_loss(xx, yy)
# 出力
# tensor(-0.3000)
なぜ LogSoftmax
と NLLLoss
で CrossEntropyLoss
と同じになるのか
torch.nn.CrossEntropyLoss
の公式ドキュメントを見ると,以下の計算式によって計算した値を返す旨が書かれています.
この式は交差エントロピーの定義を少し変形すればわかります.
\text{loss}(x, class) = -\log{\left( \frac{\exp({x[class]})}{\sum_j \exp(x[j])} \right)}
この式の右辺は,LogSoftmax
で計算する式と似ています.違うのは,
- 先頭にマイナスがついている.
- $x$ の全ての要素ではなく,正解ラベル$class$に対応する要素だけを取り出している.
上記の2つの操作はちょうど NLLLoss
を使うことで実現できます.
そのため,LogSoftmax
と NLLLoss
を使うと CrossEntropyLoss
と同じになります.
実際の利用時
以上のことから,
- 訓練時に
NLLLoss
を利用する場合は,モデルの最後の全結合層の後に,LogSoftmax
の層を入れておく必要があります. - 訓練時に
CrossEntropyLoss
を利用する場合は,そのような追加の層を入れる必要はありません.
参考
[1] torch.nn.LogSoftmax
の公式ドキュメント
[2] torch.nn.NLLLoss
の公式ドキュメント
[3] torch.nn.CrossEntropyLoss
の公式ドキュメント
[4] 交差エントロピー - Wikipedia
[5] pytorch の NLLLoss の挙動 - メモ