LoginSignup
16
11

More than 3 years have passed since last update.

[PyTorch]CrossEntropyLossを数式入りでちょっと理解する

Last updated at Posted at 2020-02-27

はじめに

Pytorchの損失関数の基準によくcriterion=torch.nn.CrossEntropyLoss()を使用しているため,
詳細を理解するためにアウトプットしてます.
間違ってたらそっと教えてください.

CrossEntropyLoss

Pytorchのサンプル(1)を参考にして,

torch.manual_seed(42) #再現性を保つためseed固定
loss = nn.CrossEntropyLoss()
input_num = torch.randn(1, 5, requires_grad=True)
target = torch.empty(1, dtype=torch.long).random_(5)
print('input_num:',input_num)
print('target:',target)
output = loss(input_num, target)
print('output:',output)
input_num: tensor([[ 0.3367,  0.1288,  0.2345,  0.2303, -1.1229]], requires_grad=True)
target: tensor([0])
output: tensor(1.3472, grad_fn=<NllLossBackward>)

正解クラスを $class$,クラス数を $n$ とするとCrossEntropyLossの誤差 $loss$ は以下の式で表すことができます.

loss=-\log(\frac{\exp(x[class])}{\sum_{j=0}^{n} \exp(x[j])}) \\
=-\{\log(\exp(x[class])- \log(\sum_{j=0}^{n} \exp(x[j])\} \\
=-\log(\exp(x[class]))+\log(\sum_{j=0}^{n} \exp(x[j])) \\
=-x[class]+\log(\sum_{j=0}^{n} \exp(x[j])) \\

ソースコードのサンプルより正解クラスは $class=0$,クラス数は $n=5$ なので,確かめてみると

loss=-x[0]+\log(\sum_{j=0}^{5} \exp(x[j]))\\
=-x[0]+\log(\exp(x[0])+\exp(x[1])+\exp(x[2])+\exp(x[3])+\exp(x[4])) \\
= -0.3367 + \log(\exp(0.3367)+\exp(0.1288)+\exp(0.2345)+\exp(0.2303)+\exp(-1.1229)) \\
= 1.34717 \cdots \\
\fallingdotseq 1.3472

プログラムの結果と無事合いましたね!
ちなみに計算は以下のコードでしてます(手計算の算数は無理・・・)

from math import exp, log
x_sum = exp(0.3367)+exp( 0.1288)+exp(0.2345)+exp(0.2303)+exp(-1.1229)
x = 0.3367
ans = -x + log(x_sum)
print(ans) # 1.3471717976017477

ごり押しです.

さいごに

丸め込み誤差(小数点の二進数表示による循環小数の発生)は今はあまり考えなくてもよさそう
普段はrandom.seed(42)なのですが,Pytorchだとtorch.manual_seed(42)なのではえ~って感じです.

参考文献

(1)TORCH.NN

16
11
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
16
11