0
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

CrossEntropyLossの挙動を確認する

Last updated at Posted at 2023-03-21

推定結果(x)が[0.6,0.4]でラベル(y)が[1,0]のクロスエントロピーを求めたい。

期待する数式は以下のようになり、その答えは0.510826となるはずである。

loss = -1*\log(0.6)-0*\log(0.4)=0.510826

これをCrossEntropyLossで実装してみる。

x = torch.Tensor([0.6,0.4])
y = torch.Tensor([1,0])
result = torch.nn.CrossEntropyLoss()(x,y)
print(result)

この結果は0.5981になる。
結果が異なった原因は公式ドキュメントを読むと解決できる。

This criterion computes the cross entropy loss between input logits and target.

この関数は、input logitsとtargetの間のクロスエントロピーを求めるとある。targetはラベルのことだと思うが、input logitsとはなんぞや?
調べてみるとlogitsは, softmax関数に通す前のニューラルネットワークの出力らしい。
さらにこのような説明もある。

The input is expected to contain the unnormalized logits for each class (which do not need to be positive or sum to 1, in general).

正である必要や和が1になっている必要はないとある。いろいろ調べてみると、どうやらCrossEntropyLossは, 第1引数(推定結果側)に入ってきた値のsoftmaxを計算しているようである。
つまり、CrossEntropyLossでは、以下の式を計算していることとなる。

    loss = -1*\log\biggl(\frac{e^{0.6}}{e^{0.6}+e^{0.4}}\biggr)-0*\log\biggl(\frac{e^{0.4}}{e^{0.6}+e^{0.4}}\biggr)

上式を計算すると0.5981となる。pythonでの実行結果と一致した...!!
それなら、入力である0.6や0.4のlogを取ってあげれば、

\begin{align}
loss &= -1*\log\biggl(\frac{e^{\log(0.6)}}{e^{\log(0.6)}+e^{\log(0.4)}}\biggr)-0*\log\biggl(\frac{e^{\log(0.4)}}{e^{\log(0.6)}+e^{\log(0.4)}}\biggr)\\
 &=-1*\log\biggl(\frac{0.6}{0.6+0.4}\biggr)-0*\log\biggl(\frac{0.4}{0.6+0.4}\biggr)\\
 &=-1*\log(0.6)-0*\log(0.4)\\
 &=0.5108
\end{align}

となり、求めたい値と一致するはずである。よって、pythonでの実装にlogをとる処理を追加してみる。

x = torch.Tensor([0.6,0.4])
y = torch.Tensor([1,0])
result = torch.nn.CrossEntropyLoss()(torch.log(x),y)
print(result)

この結果は0.5108となる。期待する値と一致した...!!

weightってなんぞや?

CrossEntropyLossにはweightという引数がある。これの挙動を確認する。
推定結果がx=[0.6,0.4]、ラベルがy=[0.5,0.5]、weight=[1,2]とする。

x = torch.Tensor([0.6,0.4])
y = torch.Tensor([0.5,0.5])
weight = torch.Tensor([1,2])
result = torch.nn.CrossEntropyLoss(weight=weight)(x,y)
print(result)

これの実行結果は, 1.0972となる。
公式ドキュメントを元にこの値と一致する数式を考えると、以下の式のように計算されているようである。

    loss = -1*0.5*\log\biggl(\frac{e^{0.6}}{e^{0.6}+e^{0.4}}\biggr)-2*0.5*\log\biggl(\frac{e^{0.4}}{e^{0.6}+e^{0.4}}\biggr)

よって、それぞれのクラスごとに何倍するかを表しているものがweightらしい。

ignore_indexってなんぞや?

分からん。
特定のクラスを無視する的なこと書いてあるけど、指定するとエラーがでてくるので今回はskip!!

reductionってなんぞや?

"mean"とか"sum"とか書いてあって、クロスエントロピーにmean??みたいなことになったので、挙動を確認していく。
xとyをそれぞれ、バッチサイズが3のデータであるとする。

x = torch.Tensor([[0.6,0.4],
                  [0.6,0.4],
                  [0.6,0.4]])
y = torch.Tensor([[1,0],
                  [1,0],
                  [1,0]])

result = torch.nn.CrossEntropyLoss(reduction="mean")(x,y)
print(result)

この結果は、0.5981

x = torch.Tensor([[0.6,0.4],
                  [0.6,0.4],
                  [0.6,0.4]])
y = torch.Tensor([[1,0],
                  [1,0],
                  [1,0]])

result = torch.nn.CrossEntropyLoss(reduction="sum")(x,y)
print(result)

この結果は1.7944
初めの実験で、x=[0.6,0.4],y=[1,0]とバッチサイズを1で入力したときは、結果は0.5981となっていた。これは、"mean"で行なった場合の結果と一致している。次に"sum"で行なった場合は、0.5981*3=1.7944と3倍になっていることが分かる。これより、データが複数ある場合にそれぞれのクロスエントロピーの和を取るか平均を取るかがreductionであることが分かった。

0
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?