機械学習
MachineLearning
DeepLearning
深層学習

多クラス交差エントロピー誤差関数とソフトマックス関数,その美しき微分

はじめに

多クラス交差エントロピーはcategorical cross entropyとも呼ばれます.
実際に,深層学習フレームワークのkerasではcategorical_crossentropyという名前が使われています.
分類問題などに取り組む際,入力をソフトマックス関数に通して多クラス交差エントロピーをロス関数にすることは多いのではないでしょうか.
今回はこのソフトマックス関数+多クラス交差エントロピー誤差関数をソフトマックス関数の入力で微分します.
本稿はDeep Learning本の式10.18の行間埋めです.
このため,時刻$t$という添字が入っています.
一般的に考えると,$t$はデータのインデックスに対応するはずです.

定義

$t$番目の学習データのラベルの1-of-K表現を$\mathbf{y^{(t)}}$とします.
また,$t$番目のモデルの出力を$\mathbf{\hat{y}}^{(t)}$とします.
これは何らかのベクトル (LSTMの出力とか) にソフトマックス関数を適用してできたベクトルです.
$$
\begin{eqnarray}
\hat{\mathbf{y}}^{(t)} = softmax(\mathbf{o}^{(t)})
\end{eqnarray}
$$

ソフトマックス関数とは,指数関数と正規化を組み合わせた関数で,出力されたベクトルの要素の総和が1になります.
cを各クラスを表す変数として,ソフトマックス関数のi番目の要素は以下のように表されます.
$$
\begin{eqnarray}
{softmax(\mathbf{o}^{(t)})}_i = \frac{\exp(o^{(t)}_i)}{\sum_{c} \exp(o^{(t)}_{c})}
\end{eqnarray}
$$

多クラス交差エントロピー誤差関数は$\mathbf{y}^{(t)}$と$\hat{\mathbf{y}}^{(t)}$を用いて以下のように定義されます.
ここで,cはクラスを表す変数です.
$$
\begin{eqnarray}
L = - \sum_{t} \sum_{c} y^{(t)}_c \log \hat{y}^{(t)}_c
\end{eqnarray}
$$

微分

まず,$L$についてみてみます.
$L$は学習データでイテレーションを回しているので,データのインデックスごとに分けられそうです.
$$
\begin{eqnarray}
L = \sum_t L^{(t)} & (where\ L^{(t)} = - \sum_{c} y^{(t)}_c \log \hat{y}^{(t)}_c)
\end{eqnarray}
$$

入力ベクトル$\mathbf{o}^{(t)}$で微分したいので,連鎖律を使います.
各クラスについての偏微分の総和が出てくることに注意してください.
$$
\begin{eqnarray}
\frac{\partial{L}}{\partial{o^{(t)}_i}} = \frac{\partial{L}}{\partial{L^{(t)}}} \cdot
\bigl(\sum_c \frac{L^{(t)}}{\partial{\hat{y}^{(t)}_{c}}} \cdot
\frac{\partial{\hat{y}^{(t)}_c}}{\partial o^{(t)}_i} \bigr)
\end{eqnarray}
$$

まず,$L$を$L^{(t)}$で微分することを考えます.
$$
\begin{eqnarray}
\frac{\partial L}{\partial L^{(t)}} = \frac{\partial{(L^{(1)} + L^{(2)} + \cdots + L^{(t)} + \dots)}}{\partial L^{(t)}} = 1
\end{eqnarray}
$$

よって
$$
\begin{eqnarray}
\frac{\partial{L}}{\partial{o^{(t)}_i}} =
\sum_c \frac{L^{(t)}}{\partial{\hat{y}^{(t)}_{c}}} \cdot \frac{\partial{\hat{y}^{(t)}_c}}{\partial o^{(t)}_i}
\end{eqnarray}
$$

$\frac{\partial{\hat{y}^{(t)}_c}}{o^{(t)}_i}$について考えます.
$\hat{y}^{(t)}_c = {softmax(\mathbf{o}^{(t)})}_c = \frac{\exp(o^{(t)}_c)}{\sum_{k} \exp(o^{(t)}_{k})}$であり,
$c = i$が成り立つときと成り立たないときで微分の結果が変化します.
$$
\begin{eqnarray}
\frac{\partial \exp (o_c) }{\partial \exp (o_i)} =
\begin{cases}
\exp (o_i) & (c = i) \\
0 & (otherwise)
\end{cases}
\end{eqnarray}
$$
このことから,$\sum_c \frac{L^{(t)}}{\partial{\hat{y}^{(t)}_{c}}} \cdot \frac{\partial{\hat{y}^{(t)}_c}}{o^{(t)}_i}$を$c=i$が成立する項としない項で分けます.
$$
\begin{eqnarray}
\sum_c \frac{L^{(t)}}{\partial{\hat{y}^{(t)}_{c}}} \cdot \frac{\partial{\hat{y}^{(t)}_c}}{\partial o^{(t)}_i} =
\frac{L^{(t)}}{\partial{\hat{y}^{(t)}_{i}}} \cdot \frac{\partial{\hat{y}^{(t)}_i}}{\partial o^{(t)}_i} +
\sum_{c \neq i} \frac{L^{(t)}}{\partial{\hat{y}^{(t)}_{c}}} \cdot \frac{\partial{\hat{y}^{(t)}_c}}{\partial o^{(t)}_i}
\end{eqnarray}
$$

${\frac{f}{y}}^\prime = \frac{f^\prime g - f g^\prime}{g^2}$であることを思い出して微分をしていきます.
$$
\begin{eqnarray}
\frac{\partial{\hat{y}^{(t)}_i}}{\partial o^{(t)}_i} &= \frac{\exp (o_i) \sum_{k} \exp(o^{(t)}_{k}) - \exp (o^{(t)}_i) \exp (o^{(t)}_i)}{{\sum_{k} \exp(o^{(t)}_{k})}^2} \\
&= \hat{y}^{(t)}_i ( 1 - \hat{y}^{(t)}_i ) \\
\frac{\partial{\hat{y}^{(t)}_c}}{\partial o^{(t)}_i} &= \frac{- \exp (o^{(t)}_c) \exp (o^{(t)}_i)}{{\sum_{k} \exp(o^{(t)}_{k})}^2} \\
&= - \hat{y}^{(t)}_c \hat{y}^{(t)}_i \\
\end{eqnarray}
$$

$\frac{\partial L^{(t)}}{\partial{\hat{y}^{(t)}_{c}}}$は$\hat{y}^{(t)}_c$がかかるとき以外は0になります.
$$
\begin{eqnarray}
\frac{\partial L^{(t)}}{\partial{\hat{y}^{(t)}_{c}}} &= - \sum_k \frac{\partial y^{(t)}_k \log \hat{y}^{(t)}_k }{\partial{\hat{y}^{(t)}_{c}}} \\
&= - \frac{y^{(t)}_c}{\hat{y}^{(t)}_c}
\end{eqnarray}
$$

以上をまとめると,
$$
\begin{eqnarray}
\sum_c \frac{L^{(t)}}{\partial{\hat{y}^{(t)}_{c}}} \cdot \frac{\partial{\hat{y}^{(t)}_c}}{\partial o^{(t)}_i} &=&
\frac{L^{(t)}}{\partial{\hat{y}^{(t)}_{i}}} \cdot \frac{\partial{\hat{y}^{(t)}_i}}{\partial o^{(t)}_i} +
\sum_{c \neq i} \frac{L^{(t)}}{\partial{\hat{y}^{(t)}_{c}}} \cdot \frac{\partial{\hat{y}^{(t)}_c}}{\partial o^{(t)}_i} \\
&=& - \frac{y^{(t)}_i}{\hat{y}^{(t)}_i} \hat{y}^{(t)}_i ( 1 - \hat{y}^{(t)}_i ) - \sum_{c \neq i} \frac{y^{(t)}_c}{\hat{y}^{(t)}_c} ( - \hat{y}^{(t)}_c \hat{y}^{(t)}_i ) \\
&=& - y^{(t)}_i ( 1 - \hat{y}^{(t)}_i ) + \sum_{c \neq i} y^{(t)}_c \hat{y}^{(t)}_i \\
&=& \sum_{c} y^{(t)}_c \hat{y}^{(t)}_i - y^{(t)}_i \\
&=& \hat{y}^{(t)}_i - y^{(t)}_i (\because \sum_c y^{(t)}_c = 1)
\end{eqnarray}
$$

このように,ソフトマックス関数と多クラス交差エントロピー損失関数のソフトマックス関数の値での微分(?)は正解データと予測値のみで綺麗に書けることがわかります.
Deep Learning本ではindicator functionを使っていますが,本稿では1-of-K符号化と明記したのでindicator functionなしで書いています.
お付き合いいただきありがとうございました.不明な点やミスを見つけた方はどうかご一報ください.

参考

Lecture Note: 試行錯誤中に見つけた.こちらのほうが綺麗にまとまっている...