5
4

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

【pytorch】CrossEntropyLoss

Last updated at Posted at 2022-10-11

はじめに

Cross entropy の意味は分かるのですが、これをpytorch の関数 CrossEntropyLoss で計算させるところでつまづきました。
入力のサイズによりエラーが出たりでなかったりで、良く分からなかったので調べました。

内容

CrossEntropyLoss とは

本家の説明はこちら。これでわかるかな。

Cross entropy が何を計算しているかからpytorch での使い方まで、下記の解説がとても分かり易かったです。

K クラスの分類タスクの予測評価

C個のクラスの分類タスクを考えます。入力は、正規化されていないスコアが想定します。
Unbatchの状態では、入力は size C の tensor とありますが、これがどうも 1xCのtensor でしか私は動かせませんでした。
ミニバッチの場合は、入力は、(minibatch, C) あるいは、K次元のデータなら、(minicatch, C, $d_1$, $\cdots$, $d_K$)になるとのことです。

Sonftmax + negative log-likelihood

cross entropy は確率分布の間で計算するものなので、入力は0から1の値をとり和が1になる配列でなければならない。ので、K次元の実数がKクラスの予測のスコアとしてきたら、それを正規化してあげなければならない。それがsoftmax と言われる関数で行える。この関数は今までEMアルゴリズムの重みの計算とかで出てきたけれど、Softmax って呼ばれるんですね。
CrossEntropyLoss が Softmax + NLL(negative log-likelihood) Lossという説明がたくさんありますが、good な説明ですね。

以下、計算してきます。

実行結果

これを使おうとして、入力をK次元のtensor にしていたのですが、どうやら二重にしなくてはいけないみたいです。
1 x K 次元のtensor で入れない方法を探しましたが、見つかりませんでした。ので、かっこ悪いですが下記で計算することにします。

import torch
from torch import nn

loss = nn.CrossEntropyLoss()

K = 3 # number of classes
print(f'number of class: {K}')
y_pred = torch.tensor([[-3, 0.8, 5]]) 
print(f'inputs {type(y_pred)} {y_pred.shape} dtype={y_pred.dtype}')
y_label = torch.tensor([1], dtype=torch.int64) # index must be one of 0, 1, 2
print(f'target {type(y_label)} {y_label.shape} dtype={y_label.dtype}')
output = loss(y_pred, y_label)
#output.backward()
print(f'output={output.item()} dtype={output.dtype} shape={output.shape}')

以下を得ます。

number of class: 3
inputs <class 'torch.Tensor'> torch.Size([1, 3]) dtype=torch.float32
target <class 'torch.Tensor'> torch.Size([1]) dtype=torch.int64
output=4.215214729309082 dtype=torch.float32 shape=torch.Size([])

batch 処理している場合には、Nサンプルまとめて予測が来るので、(N,K) 次元の予測値に対して、(N)個の予測ラベルが対応します。このときのcross entropy は下記で計算できました。多分。

mini-batch 使用時

今度はNサンプルまとめて計算します。ミニバッチ、バッチサイズ3と言うらしいです。用語あっているかな。

N = 3 # batch size
K = 5 # number of classes
print(f'Batch size: {N} number of class: {K}')
y_pred = torch.randn(N, K, requires_grad=True) # predicted score
print(f'inputs {type(y_pred)} {y_pred.shape} dtype={y_pred.dtype}')
y_label = torch.empty(N, dtype=torch.long).random_(K)
print(f'target {type(y_label)} {y_label.shape} dtype={y_label.dtype}')
output = loss(y_pred, y_label)
output.backward()
print(f'output={output.item()} dtype={output.dtype} shape={output.shape}')

下記を得ます。

Batch size: 3 number of class: 5
inputs <class 'torch.Tensor'> torch.Size([3, 5]) dtype=torch.float32
target <class 'torch.Tensor'> torch.Size([3]) dtype=torch.int64
output=1.3194670677185059 dtype=torch.float32 shape=torch.Size([])

ここまで、ターゲットがclass のラベルでしたが、確率にすることもできます。その場合は、各クラスの確率を与えなければなりません。下記でsoftmax をしているのは、確率にするためです。

N = 3 # batch size
K = 5 # number of classes
print(f'Batch size: {N} number of class: {K}')
y_pred = torch.randn(N, K, requires_grad=True) # predicted score
print(f'inputs {type(y_pred)} {y_pred.shape} dtype={y_pred.dtype}')
y_label_prob = torch.randn(N, K).softmax(dim=1)
print(f'target {type(y_label_prob)} {y_label_prob.shape} dtype={y_label_prob.dtype}')
output = loss(y_pred, y_label_prob)
output.backward()
print(f'output={output.item()} dtype={output.dtype} shape={output.shape}')

下記を得ました。

Batch size: 3 number of class: 5
inputs <class 'torch.Tensor'> torch.Size([3, 5]) dtype=torch.float32
target <class 'torch.Tensor'> torch.Size([3, 5]) dtype=torch.float32
output=1.6164592504501343 dtype=torch.float32 shape=torch.Size([])

CrossEntropyLossの reduction

定義では、下記のようにクラスであり、初期化するときの引数でオプションを渡せます。

CLASS torch.nn.CrossEntropyLoss(weight=None, size_average=None, ignore_index=- 100, reduce=None, reduction='mean', label_smoothing=0.0)

  • 重みは、学習データのラベルに偏りがある時に使うものとのことです。学習データが少ないクラスには重みを大きくするのかな。
  • ignore_index は、cross entropy に計算に入れないクラスを指定することができるようです。
  • reduction は mean, sum, none のどれかを文字列で指定します。reduce は廃止予定。

reduction の意味は、動かしてみたら分かった。batch の中のデータについて、Nサンプルをsum up するか mean にするか、各要素を出すか、ですね。

loss_sum = nn.CrossEntropyLoss(reduction = 'sum')
loss_mean = nn.CrossEntropyLoss(reduction = 'mean')
loss_none = nn.CrossEntropyLoss(reduction = 'none')
N = 3 # batch size
K = 5 # number of classes
print(f'Batch size: {N} number of class: {K}')
for idx in range(10):
    y_pred = torch.randn(N, K, requires_grad=True) # predicted score
    y_label = torch.empty(N, dtype=torch.long).random_(K)
    output0 = loss_sum(y_pred, y_label)
    output1 = loss_mean(y_pred, y_label)
    output2 = loss_none(y_pred, y_label)
    print(f'loss sum={output0} mean={output1} none={output2}')

以下を得ました。

Batch size: 3 number of class: 5
inputs <class 'torch.Tensor'> torch.Size([3, 5]) dtype=torch.float32
target <class 'torch.Tensor'> torch.Size([3, 5]) dtype=torch.float32
output=1.7039464712142944 dtype=torch.float32 shape=torch.Size([])
Batch size: 3 number of class: 5
loss sum=7.966089248657227 mean=2.655363082885742 none=tensor([2.5478, 2.9924, 2.4258], grad_fn=<NllLossBackward0>)
loss sum=3.4308128356933594 mean=1.1436042785644531 none=tensor([0.6526, 1.4271, 1.3511], grad_fn=<NllLossBackward0>)
loss sum=8.609332084655762 mean=2.8697774410247803 none=tensor([2.2606, 3.8158, 2.5330], grad_fn=<NllLossBackward0>)
loss sum=6.807758331298828 mean=2.2692527770996094 none=tensor([1.5690, 3.1023, 2.1365], grad_fn=<NllLossBackward0>)
loss sum=4.279692649841309 mean=1.4265642166137695 none=tensor([1.9494, 1.7896, 0.5407], grad_fn=<NllLossBackward0>)
loss sum=5.896562576293945 mean=1.9655208587646484 none=tensor([2.9182, 1.9205, 1.0579], grad_fn=<NllLossBackward0>)
loss sum=5.488359451293945 mean=1.8294531106948853 none=tensor([2.7735, 0.8015, 1.9134], grad_fn=<NllLossBackward0>)
loss sum=4.6195173263549805 mean=1.5398391485214233 none=tensor([1.5718, 1.8436, 1.2041], grad_fn=<NllLossBackward0>)
loss sum=6.097378253936768 mean=2.0324594974517822 none=tensor([0.4906, 2.2456, 3.3612], grad_fn=<NllLossBackward0>)
loss sum=4.015186786651611 mean=1.338395595550537 none=tensor([1.3958, 1.5027, 1.1166], grad_fn=<NllLossBackward0>)

まとめ

とりあえず、分かった気がしました。いよいよ分類タスクのAI学習ができるようになるかな。。。

5
4
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
5
4

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?