LoginSignup
1
2

Softmax+多クラス分類の感覚をつかむために基礎からちゃんと調べてみた(交差エントロピー損失)

Posted at

ディープラーニングでソフトマックの多クラス分類って結局何を最小化してるんだろ?と思って調べてみたら結構なボリュームになったので記事にしました。
ちゃんと理解しようとするとなかなか難易度の高い印象です。(専門家ではないので厳密性に欠ける点はご容赦ください)

TL;DR

・Softmax+交差エントロピーの最小化は最終結果がすごくシンプル(なので計算が早い)
・実際の内容はKLダイバージェンスの最小化

ソフトマックスによる多クラス分類について

いわゆるディープラーニングで多クラスに分類する場合を想定しています。
例えばある画像が与えられた場合、[犬:10%, 猿:10%, キジ: 80%] みたいに予測する場合になります。

分類方法としては、Softmax関数で出力を値(0~1)にした後に交差エントロピー(Cross Entropy)を最小化することで学習します。
これがどういう意味を持つかを直感的に把握したかったのでので調べてみました。

まずは前提となる概念から説明していきます。

情報量

情報量とはある事象が(確率的に)どのぐらい起こりにくいかを表した尺度です。
ある事象が起こる確率が $p$ の場合の情報量 $I$ は以下で定義されます。

$$I=\log_2{\frac{1}{p}}=-\log_2{p}$$

logの底ですが一般的には2が使われます。(情報数学では2ですが、機械学習では$e$の方が都合がいい場合が多いです)
例として10%と80%で起こる事象を考えます。
10%の情報量が約3.3($-\log_2(0.1)$)、80%が約0.3($-\log_2(0.8)$)となり、起こりにくい事象ほど情報量が多いことが分かります。
(確率0%で∞になります。)

エントロピー(平均情報量)

一般的にエントロピーというとこの平均情報量の事を指します。
エントロピーを簡単に言うと「予測の難しさ」で予測が難しいほど値が高くなります。(他には「乱雑さ」「不規則さ」「無秩序さ」などと表現したり)

数式は以下で、「自分自身が起こる確率×自分自身の情報量」の期待値となります。

$$H(P) = - \sum_{i=1}^n P(x_i) \log_2{P(x_i)}$$

$P$はある複数の事象が起こる確率を表しています。$P=P(x_1),P(x_2),P(x_3)...$
また$P$は全事象を表しており、合計すると1になります。 $P(x_1)+P(x_2)+P(x_3)...=1$

例として普通のコイン(表:50%、裏:50%)と歪んだコイン(表:10%、裏90%)を考えます。
前者のコインは予測が難しいのでエントロピーが高く、後者は予測しやすいのでエントロピーが低くなります。

\begin{aligned}
H(普通のコイン) &= -(0.5 \times \log_2{0.5} + 0.5 \times \log_2{0.5}) = 1\\
H(歪んだコイン) &= -(0.1 \times \log_2{0.1} + 0.9 \times \log_2{0.9}) \fallingdotseq 0.47\\
\end{aligned}

歪んだコインの方がエントロピーが低いのが数字でもわかります。

KLダイバージェンス

これは2つの確率分布の違いを計る尺度です。1
事象PとQのKLダイバージェンスは以下です。

$$
D_{KL}(P||Q) = \sum_i P(x_i) \log{\frac{P(x_i)}{Q(x_i)}}
$$

これは事象PとQが(確率分布的に)どのぐらい離れているかを示しています。
PとQが同じ場合0になり、PとQの差が大きいほど値が増えていきます。

上記の例で普通のコイン同士と「普通のコイン、歪んだコイン」を見てみます。

\begin{aligned}
D_{KL}(普通のコイン||普通のコイン) &= 0.5 \log{\frac{0.5}{0.5}} + 0.5 \log{\frac{0.5}{0.5}} \\
&= 0\\
D_{KL}(普通のコイン||歪んだコイン) &= 0.5 \log{\frac{0.5}{0.1}} + 0.5 \log{\frac{0.5}{0.9}} \\
&\fallingdotseq 0.51\\
\end{aligned}

PとQが違うほど値が大きくなります。
機械学習的に見れば、この値を0近づけるように学習すれば、PとQが同じ確率分布になるように学習できます。

交差エントロピー

交差エントロピーは以下で、「自分自身が起こる確率×相手側の情報量」の期待値となります。

$$H(P,Q) = - \sum_{i} P(x_i) \log{Q(x_i)}$$

この解釈がなかなか難しくて…
いきなりですが、KLダイバージェンスを以下の通り式変形します。

\begin{aligned}
D_{KL}(P||Q) &= \sum_i P(x_i)(\log{P(x_i) - \log{Q(x_i)}} )\\
&= \sum_i P(x_i)\log{P(x_i) - P(x_i) \log{Q(x_i)}}\\
&= -H(P) + H(P,Q)\\
&= 交差エントロピー - エントロピー\\
\end{aligned}

これより「交差エントロピー = KLダイバージェンス + エントロピー」となります。
ちなみに交差エントロピーの最小値は(KLダイバージェンスが0なので)エントロピーと同じです。

・参考
情報量(Wikipedia)
交差エントロピー(Wikipedia)
カルバック・ライブラー情報量(Wikipedia)
平均情報量:https://www.mnc.toho-u.ac.jp/v-lab/yobology/entropy/entropy.htm
エントロピー、交差エントロピー、KLダイバージェンス:https://kleinblog.net/math-entropy
交差エントロピーの導出: https://nihaoshijie.hatenadiary.jp/entry/2017/04/26/062304
交差エントロピーを理解してみる:https://yaju3d.hatenablog.jp/entry/2018/11/30/225841

交差エントロピーの最小化

交差エントロピー(P,Q)は「KLダイバージェンス(P,Q)」+「Pのエントロピー」です。
一般的にはPが正解データでQが学習する確率分布と考えるので以下となります。

  • Pの確率分布にQの確率分布を近づける
  • (Pのエントロピーは固定)

エントロピーは固定なので実質的にはKLダイバージェンスの最小化(0にする)と同義です。

参考までに、逆の場合だとエントロピーも最小化するので意味が変わってきます。(次の意味になります)

  • Qの確率分布にPの確率分布を近づける
  • Pのエントロピーが小さくする(Pの結果を予測しやすくする)

さて、ここでPのエントロピーは固定なのでいらないのでは?(KLダイバージェンスだけでいいのでは)という疑問が私にわきました。
どうやらSoftmax+交差エントロピーがすごく相性がいいようです。

P,QがそれぞれSoftmax後の確率分布である事を考えます。

$$Softmax(P_i) = \frac{\exp{P(x_i)}}{\sum_j \exp{P(x_j)} }$$

これを交差エントロピーに代入します。

$$H(P,Q) = - \sum_{i} Softmax(P_i) \log{ Softmax(Q_i)}$$

どうやらこの微分がとても数学的に簡単になるらしく「出力データ - 教師データ」になるそうです。
計算は弱いので実際の計算は以下の参考サイトをどうぞ。

・参考
多クラス交差エントロピー誤差関数とソフトマックス関数,その美しき微分: https://qiita.com/klis/items/4ad3032d02ff815e09e6
めちゃめちゃ丁寧にソフトマックス回帰での交差エントロピー誤差を微分する: https://qiita.com/chersky/items/dc85bdf18609eaeb5244
Softmax-with-Lossレイヤの逆伝播の導出【ゼロつく1のノート(数学)】: https://www.anarchive-beta.com/entry/2020/08/06/180000#google_vignette

要するにまとめると以下です。
・Softmax+交差エントロピーは結果の計算がすごくシンプルになる(相性が良い)
・実際の内容はKLダイバージェンスの最小化と同じ

比較

CrossEntropy VS MSE

前から思っていた内容で、Softmaxの学習は交差エントロピーの最小化ではなくMSE(平均二乗誤差)でもいいのでは?と思っていたので比較してみました。

交差エントロピーは確率分布自体を近づけるのでなめらかに学習し、MSEは値自体を学習するので乱数でデータがぶれますが最終的には期待値として確率を学習できるので問題ないのでは?と予測します。

データセットは0~9までの乱数を入力し、整数部分を正解データとしました。(例えばx=1.23の場合、y=1)

コード(Torchです)
import time

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from matplotlib import pyplot as plt
from torch.utils.data import DataLoader, TensorDataset

CLASS_NUM = 10


def create_dataset(size: int):
    x = np.random.uniform(0, CLASS_NUM - 1, size)
    y = np.round(x).astype(np.int8)
    y = np.identity(CLASS_NUM)[y]  # onehot
    return x[..., np.newaxis], y


x, y = create_dataset(3)
print(x)
print(y)


class MyModel(nn.Module):
    def __init__(self, in_size: int, out_size: int, add_softmax: bool):
        super().__init__()
        self.h_layers = nn.ModuleList(
            [
                nn.Linear(in_size, 128),
                nn.ReLU(inplace=True),
                nn.Linear(128, 128),
                nn.ReLU(inplace=True),
                nn.Linear(128, out_size),
            ]
        )
        if add_softmax:
            self.h_layers.append(nn.Softmax(-1))

    def forward(self, x):
        for h in self.h_layers:
            x = h(x)
        return x


def _run(train_x, train_y, val_x, val_y, criterion, add_softmax):

    # --- dataset
    train_loader = DataLoader(
        dataset=TensorDataset(torch.Tensor(train_x), torch.Tensor(train_y)),
        batch_size=32,
        shuffle=True,
    )
    val_loader = DataLoader(
        dataset=TensorDataset(torch.Tensor(val_x), torch.Tensor(val_y)),
        batch_size=32,
        shuffle=False,
    )

    # --- model
    model = MyModel(1, CLASS_NUM, add_softmax)
    optimizer = optim.Adam(model.parameters(), lr=0.0004)

    print(model)

    # --- train
    t0 = time.time()
    epochs = 200
    history = []
    for epoch in range(epochs):
        model.train()
        running_loss = 0.0
        for inputs, labels in train_loader:
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        running_loss /= len(train_loader)

        # --- acc metric
        model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for inputs, labels in val_loader:
                outputs = model(inputs)
                if not add_softmax:
                    outputs = torch.softmax(outputs, -1)
                _, predicted = torch.max(outputs, 1)
                _, labels = torch.max(labels, 1)
                correct += (predicted == labels).sum().item()
                total += labels.size(0)
        accuracy = 100 * correct / total
        history.append(accuracy)

        print(f"Epoch {epoch + 1}, Loss: {running_loss}, Validation Accuracy: {accuracy}%")
    print("Finished Training", time.time() - t0)
    return history


def main():

    train_x, train_y = create_dataset(1000)
    val_x, val_y = create_dataset(10000)

    history = _run(
        train_x,
        train_y,
        val_x,
        val_y,
        nn.CrossEntropyLoss(),
        add_softmax=False,
    )
    plt.plot(history, label="CrossEntropyLoss")

    history = _run(
        train_x,
        train_y,
        val_x,
        val_y,
        nn.MSELoss(),
        add_softmax=True,
    )
    plt.plot(history, label="MSELoss")

    history = _run(
        train_x,
        train_y,
        val_x,
        val_y,
        nn.L1Loss(),
        add_softmax=True,
    )
    plt.plot(history, label="MAELoss")

    plt.xlabel("epochs")
    plt.ylabel("Validation Accuracy(%)")
    plt.grid()
    plt.legend()
    plt.show()


main()

Figure_1.png

Time
CrossEntropyLoss: 33.56197547912598s
MSELoss         : 32.66430497169494s
MAELoss         : 32.95515370368957s

やはり確率分布自体を学習する交差エントロピーのほうが優秀ですね。
ただ多分MSEは期待値を学習しているので、時間をかければ学習できそうな気がします(?)。
MAE(平均絶対誤差)はおまけで入れましたが、途中から学習できていませんね…。挙動はMSEと大きく変わらなそうな気がしますがこれだけ違うのは何ででしょうかね。

CrossEntropy VS BinaryCrossEntropy

同じく、2値分類においてSoftmaxによる交差エントロピーか、sigmoidによる交差エントロピー(BinaryCrossEntropy)はどちらにしようか迷う場合があると思いますので見てみました。
sigmoidは出力が0~1の範囲であることを利用し、そのまま確率を出力する方法です。
(前者は正解ラベルが[0,1]or[1,0]の2次元に対し、後者は[0]or[1]と1次元になります)

多分結果は変わらないはず…。

データセットは0~1までの乱数を入力し、0.5以下なら0、そうでないなら1を正解データとしました。

コード(Torchです)
import time

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from matplotlib import pyplot as plt
from torch.utils.data import DataLoader, TensorDataset


def create_dataset(size: int):
    x = np.random.uniform(0, 1, size)
    y = np.where(x <= 0.5, 0, 1).astype(np.int8)
    y_onehot = np.identity(2)[y]
    return x[..., np.newaxis], y[..., np.newaxis], y_onehot


x, y, y_onehot = create_dataset(3)
print(x)
print(y)
print(y_onehot)


class MyModel(nn.Module):
    def __init__(self, in_size: int, out_size: int):
        super().__init__()
        self.h_layers = nn.ModuleList(
            [
                nn.Linear(in_size, 128),
                nn.ReLU(inplace=True),
                nn.Linear(128, 128),
                nn.ReLU(inplace=True),
                nn.Linear(128, out_size),
            ]
        )

    def forward(self, x):
        for h in self.h_layers:
            x = h(x)
        return x


def _run(train_x, train_y, val_x, val_y, criterion, out_size):

    # --- dataset
    train_loader = DataLoader(
        dataset=TensorDataset(torch.Tensor(train_x), torch.Tensor(train_y)),
        batch_size=32,
        shuffle=True,
    )
    val_loader = DataLoader(
        dataset=TensorDataset(torch.Tensor(val_x), torch.Tensor(val_y)),
        batch_size=32,
        shuffle=False,
    )

    # --- model
    model = MyModel(in_size=1, out_size=out_size)
    optimizer = optim.Adam(model.parameters(), lr=0.00001)

    # --- train
    t0 = time.time()
    epochs = 200
    history = []
    for epoch in range(epochs):
        model.train()
        running_loss = 0.0
        for images, labels in train_loader:
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        running_loss /= len(train_loader)

        # --- acc metric
        model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for images, labels in val_loader:
                outputs = model(images)
                if out_size == 2:
                    # CrossEntropy
                    outputs = torch.softmax(outputs, -1)
                    _, predicted = torch.max(outputs, 1)  # 最大値とする
                    _, labels = torch.max(labels, 1)
                    correct += (predicted == labels).sum().item()
                else:
                    # BinaryCrossEntropy
                    outputs = torch.sigmoid(outputs)
                    predicted = (outputs > 0.5).float()
                    correct += (predicted == labels.float().view_as(predicted)).sum().item()
                total += labels.size(0)
        accuracy = 100 * correct / total
        history.append(accuracy)

        print(f"Epoch {epoch + 1}, Loss: {running_loss}, Validation Accuracy: {accuracy}%")
    print("Finished Training", time.time() - t0)
    return history


def main():
    train_x, train_y, train_y_onehot = create_dataset(1000)
    val_x, val_y, val_y_onehot = create_dataset(10000)

    history = _run(train_x, train_y_onehot, val_x, val_y_onehot, nn.CrossEntropyLoss(), 2)
    plt.plot(history, label="CrossEntropyLoss")

    history = _run(train_x, train_y, val_x, val_y, nn.BCEWithLogitsLoss(), 1)
    plt.plot(history, label="BCEWithLogitsLoss")

    plt.xlabel("epochs")
    plt.ylabel("Validation Accuracy(%)")
    plt.grid()
    plt.legend()
    plt.show()


main()

Figure_2.png

Time
CrossEntropyLoss : 31.034955978393555s
BCEWithLogitsLoss: 23.256561279296875s

精度はほぼ変わらないとみていい気はします。(CrossEntropyの方が早く見えますが、学習率をかなり下げて見ています)
ただ速度ですが、BCEWithLogitsLossの方がかなり早く感じました。(8秒ほど早いですね)

多分 softmax+交差エントロピー より sigmoid+バイナリクロスエントロピー の方がもっと簡単な計算方法があるんでしょうね(調べるの力尽きました)

最後に

やはり適切な損失関数を選ぶのは重要ですね。
実際の計算はフレームワークが良しなにやってくれるのはほんと助かります。

  1. イメージとしては距離ですが、数学的には距離の公理を満たさないらしく距離ではないらしい

1
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
2