LoginSignup
58
44

PyTorchのCrossEntropyLossの解説

Last updated at Posted at 2021-07-16

いつも混乱するのでメモ。

Cross Entropy = 交差エントロピーの定義

確率密度関数$p(x)$および$q(x)$に対して、Cross Entropyは次のように定義される。1

H(p, q) = -\sum_{x} p(x) \log(q(x))

これは情報量$\log(q(x))$の確率密度関数$p(x)$による期待値である。ここで、$p$ の $q$ に対するカルバック・ライブラー情報量は次のように与えられる。

\begin{aligned}
D_{\mathrm{KL}}(p|q) &= \sum_{x} p(x) \log(p(x)) - \sum_{x} p(x) \log(q(x)) \\
&= \sum_{x} p(x) \log(p(x)) + H(p, q)
\end{aligned}

いま、$p(x)$は既知であるとし、$q(x)$を求める問題に取り組んでいるとする。第一項は計算できて、$q(x)$に依存しない。

\begin{aligned}
D_{\mathrm{KL}}(p|q) &= \mathrm{const} + H(p, q)
\end{aligned}

カルバック・ライブラー情報量は分布の"近さ"を測る量で、常に$D_{\mathrm{KL}}(p|q) \geq 0$を満たし、等号が成立するのは$p(x)=q(x)$のときのみである。したがって、$q(x)$を求める問題ではCross Entropyが最小になるような$q(x)$を求めれば良い。

機械学習での使われ方

Cross Entropy を経験分布、簡単に言うと、得られたデータで計算したものは次のように書ける。

H(p, q) = -\sum_{n=1}^N\sum_{i=1}^C p_{n,i} \log(q_{n,i})

添字の$n$はデータのインデックスを表し、$i$はデータのクラスを表す。$p_{n,i}$はデータに対する正しい確率である。たとえば、1番目のデータがクラス3に属するときは、次のように書ける。

\begin{align}
p_{1,1} &= 0 \\
p_{1,2} &= 0 \\
p_{1,3} &= 1 \\
p_{1,4} &= 0 \\
... \\
p_{1,C} &= 0
\end{align}

$q_{n,i}$は機械学習モデル出力した$n$番目のデータがクラス$i$に属する確率である。機械学習モデルはCross Entropy が最小になるように、つまり$p_{n,i}$に一致するように、出力$q_{n,i}$を学習していく。

また、機械学習モデルは確率$q_{n,i}$を出力する必要があるので、$C$次元の出力$x_{n,i}$2は規格化されていなければならない。

q_{n,i} = \sigma_i (\vec{x}_{n}) = \frac{\exp(x_{n,i})}{\sum_{j=1}^C \exp(x_{n,j})}

ここで、$\vec{x}_{n}$は添字$i$を成分のインデックスとみなしたベクトルである。$\vec{\sigma}$をソフトマックス関数と呼ぶ。

PyTorchのCrossEntropyLoss

リファレンスには次のように書かれている。

\mathrm{loss}(x, \mathrm{class}) = -\log\left(\frac{\exp(x[\mathrm{class}])}{\sum_j \exp(x[j])}\right) = -x[\mathrm{class}] + \log\left(\sum_j \exp(x[j])\right)

ここで、リファレンスの式の$x$はニューラルネットワークの出力(関数torch.nn.CrossEntropyLoss()の入力)で、$\mathrm{class}$は正解クラスのインデックスである。

今までの表記に合わせると、データ$n$に対して次のように書ける。

\mathrm{loss}(\vec{x}_n, i) = -\log(q_{n,i}) = -\log\left(\frac{\exp(x_{n,i})}{\sum_{j=1}^C \exp(x_{n,j})}\right) = -x_{n,i} + \log\left(\sum_{j=1}^C \exp(x_{n,j})\right)

クロスエントロピーの式も、データ$n$の正解ラベルが$i$だとすると、該当部分のみ取り出して、

-\sum_{k=1}^C p_{n,k} \log(q_{n,k}) = -\sum_{k=1}^C \delta_{i,k} \log(q_{n,k}) = - \log(q_{n,i}),

となり、確かに一致する。

つまり、PyTorchの関数torch.nn.CrossEntropyLoss()は、損失関数内でソフトマックス関数の処理をしたことになっているので、ロスを計算する際はニューラルネットワークの最後にソフトマックス関数を適用する必要はない。モデルの構造を汎用的にするため、モデル自体はFC層のLinearで終わるようにするためであろう。むしろ推論時にはモデルの出力にソフトマックス関数を適用することを忘れないようにしなければならない。

おまけ

torch.nn.BCEWithLogitsLoss

リファレンスには次のように書かれている。

l_n = - w_n \left[ y_n \cdot \log \sigma(x_n)+ (1 - y_n) \cdot \log (1 - \sigma(x_n)) \right]

こちらも損失関数内でシグモイド計算がされているので、ニューラルネットワークの最後にシグモイド関数を適用する必要はない。

torch.nn.BCELOSS

リファレンスには次のように書かれている。

l_n = - w_n \left[ y_n \cdot \log x_n + (1 - y_n) \cdot \log (1 - x_n) \right]

こちらは損失関数内でシグモイド計算がされていないので、ニューラルネットワークの最後にシグモイド関数を適用する必要がある。予め数式を展開して計算するtorch.nn.BCEWithLogitsLossの方が数値的に安定する。

torch.nn.NLLLoss

リファレンスには次のように書かれている。

l_n = - w_{y_n} x_{n,y_n}

ニューラルネットワークの出力をマイナスするだけである(重みも加えられる)。ニューラルネットワークの出力が対数尤度、すなわち確率への規格化+対数計算のときに使う。torch.nn.LogSoftmaxを適用した後にtorch.nn.NLLLossを適用すればtorch.nn.CrossEntropyLossと同じになる。予め数式を展開して計算するtorch.nn.CrossEntropyLossの方が数値的に安定する。

Negative log likelihood は、log likelihoodをnegativeにするという意味っぽい。わかりにくい。

  1. $\sum_{x}$は$x$が離散変数ならば和を、連続変数ならば積分をとることを抽象的にまとめて書いた。

  2. $x$が異なる意味で複数回使われていてややこしい。ここではニューラルネットワークの最後のFC層などの出力を意味する。

58
44
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
58
44