3
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

記事投稿キャンペーン 「2024年!初アウトプットをしよう」

バイナリクロスエントロピーロスの逆伝搬

Last updated at Posted at 2024-01-20

バイナリクロスエントロピーロス(BCE Loss)とは

2値分類に用いられるロス関数です.
例えば,メールの文章からスパムかどうかを判定するタスクなどで使用できます.
このロスを使用することで,スパムなら1を,そうでないなら0を出力するようにモデルを学習させることができます.

BCE Lossの詳細

以下に,定義式を載せます.

L(y, \hat{y}) = -[\ y \log(\hat{y}) + (1 - y) \log(1 - \hat{y})\ ]

ここで,$y$は正解ラベル(0 or 1), $\hat{y}$は予測ラベル([0,1])です.

予測ラベルの出力

ここでは,シンプルな線形回帰モデルを使用します.

5次元の特徴量$x$があるとします.
この特徴量$x$を以下の式で1次元に変換します.

o = Wx + b

ここで,$W$は1×5次元の重み行列,$b$は1次元のバイアスです.

この出力$o$をシグモイド関数で$[0,1]$に変換します.

\hat{y} = sigmoid(o) = \frac{1}{1 + e^{-o}} 

BCE Lossの逆伝搬

本題に入ります.

ここでは,$\frac{\partial{L}}{\partial{o}}$を求めます.
本題から逸れるため,$\frac{\partial{L}}{\partial{o}}$以降の逆伝搬$\bigl(\frac{\partial{L}}{\partial{W}}$と$\frac{\partial{L}}{\partial{b}}\bigr)$は求めません.

連鎖律より

\frac{\partial{L}}{\partial{o}} = \frac{\partial{L}}{\partial{\hat{y}}} \cdot \frac{\partial{\hat{y}}}{\partial{o}}

が導かれます.

$\frac{\partial{L}}{\partial{\hat{y}}}$と$\frac{\partial{\hat{y}}}{\partial{o}}$を順に求めます.

L(y, \hat{y}) = -[\ y \log(\hat{y}) + (1 - y) \log(1 - \hat{y})\ ]

であったので,

\begin{align}
\frac{\partial{L}}{\partial{\hat{y}}} &= - \frac{y}{\hat{y}} + \frac{1-y}{1-\hat{y}} \\
&= \frac{\hat{y}-y}{\hat{y}(1-\hat{y})}
\end{align}

となります.

\hat{y} = sigmoid(o) = \frac{1}{1 + e^{-o}}

であったので,

\begin{align}
\frac{\partial{\hat{y}}}{\partial{o}} &= - (1+e^{-o})^{-2} \cdot -e^{-o} \\
&=  \frac{e^{-o}}{(1+e^{-o})^{2}}
\end{align}

ここでちょっとしたテクニックを使用します.

\begin{align}
\frac{e^{-o}}{(1+e^{-o})^{2}} &= \frac{e^{-o}}{1+e^{-o}} \cdot \frac{1}{1+e^{-o}} \\
&= \frac{(1+e^{-o})-1}{1+e^{-o}} \cdot \frac{1}{1+e^{-o}} \\
&= \Bigl(1-\frac{1}{1+e^{-o}}\Bigr)\cdot \frac{1}{1+e^{-o}} \\
&= (1-\hat{y})\cdot\hat{y}
\end{align}

つまり,

\frac{\partial{\hat{y}}}{\partial{o}} = (1-\hat{y})\cdot\hat{y}

となります.
これは,シグモイド関数の重要な性質で,$\frac{\partial{\hat{y}}}{\partial{o}}$は$o$を用いずに$\hat{y}$のみで計算が可能であることがわかります.

以上をまとめると,

\frac{\partial{L}}{\partial{o}} = \frac{\partial{L}}{\partial{\hat{y}}} \cdot\frac{\partial{\hat{y}}}{\partial{o}}
\frac{\partial{L}}{\partial{\hat{y}}} = \frac{\hat{y}-y}{\hat{y}(1-\hat{y})}
\frac{\partial{\hat{y}}}{\partial{o}} = (1-\hat{y})\cdot\hat{y}

これより,

\begin{align}
\frac{\partial{L}}{\partial{o}} &=  \frac{\hat{y}-y}{\hat{y}(1-\hat{y})}\cdot  (1-\hat{y})\cdot\hat{y} \\
&= \hat{y}-y
\end{align}

とキレイな結果になります.

つまり, シグモイド関数とバイナリクロスエントロピーロスを使用することで,予測ラベルと正解ラベルの差を逆伝搬できるということです.

おまけ1:PyTorchで試す

import torch

torch.manual_seed(42)

# 入力特徴量
x_1 = torch.tensor([0.1, 0.5, -0.2, 1.2, 0.3])
x_2 = torch.tensor([0.7, -0.2, -0.4, 0.8, -1.0])

# 正解ラベル
y_1 = 1 
y_2 = 0

# 重み行列Wを初期化
W = torch.randn(1, 5, requires_grad=True) # 1x5次元

epoch = 30
learning_rate = 0.1

for i in range(epoch):
    
    print(f"epoch: {i+1}")

    ### 順伝播 ###
    o_1 = W @ x_1
    o_2 = W @ x_2
    
    y_hat_1 = torch.sigmoid(o_1)
    y_hat_2 = torch.sigmoid(o_2)

    print('予測ラベル(y_hat_1):', f'{y_hat_1.item():0.3f}', '正解ラベル(y_1):', y_1)
    print('予測ラベル(y_hat_2):', f'{y_hat_2.item():0.3f}', '正解ラベル(y_2):', y_2)
        

    ### 逆伝播 ###

    # バイナリクロスエントロピーロス
    loss_1 = -(y_1 * torch.log(y_hat_1) + (1 - y_1) * torch.log(1 - y_hat_1))
    loss_2 = - (y_2 * torch.log(y_hat_2) + (1 - y_2) * torch.log(1 - y_hat_2))
    
    loss = loss_1 + loss_2
    loss.backward() # W.gradに勾配が入る
    
    # 勾配を使ってパラメータを更新
    with torch.no_grad():
        W -= learning_rate * W.grad

    # 勾配の初期化
    W.grad.zero_()
    
    print('\n')
出力(予測ラベルが正解ラベルに近づいていることが確認できます.)

epoch: 1
予測ラベル(y_hat_1): 0.498 正解ラベル(y_1): 1
予測ラベル(y_hat_2): 0.806 正解ラベル(y_2): 0

epoch: 2
予測ラベル(y_hat_1): 0.506 正解ラベル(y_1): 1
予測ラベル(y_hat_2): 0.781 正解ラベル(y_2): 0

epoch: 3
予測ラベル(y_hat_1): 0.515 正解ラベル(y_1): 1
予測ラベル(y_hat_2): 0.755 正解ラベル(y_2): 0

epoch: 4
予測ラベル(y_hat_1): 0.524 正解ラベル(y_1): 1
予測ラベル(y_hat_2): 0.728 正解ラベル(y_2): 0

epoch: 5
予測ラベル(y_hat_1): 0.533 正解ラベル(y_1): 1
予測ラベル(y_hat_2): 0.700 正解ラベル(y_2): 0

epoch: 6
予測ラベル(y_hat_1): 0.542 正解ラベル(y_1): 1
予測ラベル(y_hat_2): 0.672 正解ラベル(y_2): 0

epoch: 7
予測ラベル(y_hat_1): 0.551 正解ラベル(y_1): 1
予測ラベル(y_hat_2): 0.644 正解ラベル(y_2): 0

epoch: 8
予測ラベル(y_hat_1): 0.560 正解ラベル(y_1): 1
予測ラベル(y_hat_2): 0.617 正解ラベル(y_2): 0

epoch: 9
予測ラベル(y_hat_1): 0.569 正解ラベル(y_1): 1
予測ラベル(y_hat_2): 0.590 正解ラベル(y_2): 0

epoch: 10
予測ラベル(y_hat_1): 0.578 正解ラベル(y_1): 1
予測ラベル(y_hat_2): 0.564 正解ラベル(y_2): 0

epoch: 11
予測ラベル(y_hat_1): 0.587 正解ラベル(y_1): 1
予測ラベル(y_hat_2): 0.539 正解ラベル(y_2): 0

epoch: 12
予測ラベル(y_hat_1): 0.596 正解ラベル(y_1): 1
予測ラベル(y_hat_2): 0.515 正解ラベル(y_2): 0

epoch: 13
予測ラベル(y_hat_1): 0.605 正解ラベル(y_1): 1
予測ラベル(y_hat_2): 0.492 正解ラベル(y_2): 0

epoch: 14
予測ラベル(y_hat_1): 0.614 正解ラベル(y_1): 1
予測ラベル(y_hat_2): 0.470 正解ラベル(y_2): 0

epoch: 15
予測ラベル(y_hat_1): 0.622 正解ラベル(y_1): 1
予測ラベル(y_hat_2): 0.450 正解ラベル(y_2): 0

epoch: 16
予測ラベル(y_hat_1): 0.631 正解ラベル(y_1): 1
予測ラベル(y_hat_2): 0.431 正解ラベル(y_2): 0

epoch: 17
予測ラベル(y_hat_1): 0.640 正解ラベル(y_1): 1
予測ラベル(y_hat_2): 0.412 正解ラベル(y_2): 0

epoch: 18
予測ラベル(y_hat_1): 0.648 正解ラベル(y_1): 1
予測ラベル(y_hat_2): 0.396 正解ラベル(y_2): 0

epoch: 19
予測ラベル(y_hat_1): 0.656 正解ラベル(y_1): 1
予測ラベル(y_hat_2): 0.380 正解ラベル(y_2): 0

epoch: 20
予測ラベル(y_hat_1): 0.664 正解ラベル(y_1): 1
予測ラベル(y_hat_2): 0.365 正解ラベル(y_2): 0

epoch: 21
予測ラベル(y_hat_1): 0.672 正解ラベル(y_1): 1
予測ラベル(y_hat_2): 0.351 正解ラベル(y_2): 0

epoch: 22
予測ラベル(y_hat_1): 0.680 正解ラベル(y_1): 1
予測ラベル(y_hat_2): 0.337 正解ラベル(y_2): 0

epoch: 23
予測ラベル(y_hat_1): 0.687 正解ラベル(y_1): 1
予測ラベル(y_hat_2): 0.325 正解ラベル(y_2): 0

epoch: 24
予測ラベル(y_hat_1): 0.695 正解ラベル(y_1): 1
予測ラベル(y_hat_2): 0.313 正解ラベル(y_2): 0

epoch: 25
予測ラベル(y_hat_1): 0.702 正解ラベル(y_1): 1
予測ラベル(y_hat_2): 0.302 正解ラベル(y_2): 0

epoch: 26
予測ラベル(y_hat_1): 0.709 正解ラベル(y_1): 1
予測ラベル(y_hat_2): 0.292 正解ラベル(y_2): 0

epoch: 27
予測ラベル(y_hat_1): 0.715 正解ラベル(y_1): 1
予測ラベル(y_hat_2): 0.282 正解ラベル(y_2): 0

epoch: 28
予測ラベル(y_hat_1): 0.722 正解ラベル(y_1): 1
予測ラベル(y_hat_2): 0.273 正解ラベル(y_2): 0

epoch: 29
予測ラベル(y_hat_1): 0.728 正解ラベル(y_1): 1
予測ラベル(y_hat_2): 0.265 正解ラベル(y_2): 0

epoch: 30
予測ラベル(y_hat_1): 0.734 正解ラベル(y_1): 1
予測ラベル(y_hat_2): 0.256 正解ラベル(y_2): 0

おまけ2:PyTorchの便利な関数

BCELoss...シグモイド関数で[0,1]に変換した後に,適用するものです.
BCEWithLogitsLoss...シグモイド関数とBCELossを合体させたものです.シグモイド関数 + BCE Lossよりも数値的に安定するらしいので,こちらを使いましょう.

参考

シグモイド関数の微分

ゼロから作るDeep Learning―Pythonで学ぶディープラーニングの理論と実装

3
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
3
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?