ニューラルネットワークの出力層でよく使われる交差エントロピー誤差が学習速度の観点から重宝される理由を直感的に理解するために、グラフにして可視化してみました。
関数の形をグラフにすることでよく理解できました。結論のグラフだけ見たい人は、この記事の最後から2番目と3番目のグラフを参照。
この記事では $x_1, x_2$ はソフトマックス関数の入力、 $y_1, y_2$ はソフトマックス関数の出力であり損失関数の入力、$E$は損失関数の出力を表すものとします。
グラフの作成にはJupyter Notebook上でmatplotlibを使いました。
ソフトマックス関数のグラフ
まずは、2クラスのソフトマックス関数をグラフにします。クラス1と2があって、入力を $x_1, x_2$ とし、出力を $y_1, y_2$ としたときのグラフです。
\begin{align}
y_1 =& \frac{e^{x_1}}{e^{x_1} + e^{x_2}} = 1 - y_2 \\
=& \frac{1}{1 + e^{x_2-x_1}} \\
y_2 =& \frac{e^{x_2}}{e^{x_1} + e^{x_2}} = 1 - y_1 \\
=& \frac{1}{1 + e^{x_1-x_2}} \\
\end{align}
横軸は$x_1$、奥方向の軸は$x_2$、縦軸は$y_1$です。
$y_2$は $y_2 = 1 - y_1$ ですので、$y_1$を縦方向に反転した形になります。
正解ラベルが1、つまりソフトマックス関数の出力 $y_1=1, y_2=0$ は右下のほうの $y_1$ が1に近いエリアです。$x_1$が大きく$x_2$が小さいほど正解に近づきます。
少しだけ反時計回りに回転させたグラフも載せておきます。
右方向の軸は$x_1$、左方向の軸は$x_2$、縦軸は$y_1$です。
シグモイド関数を2次元に拡張したようなイメージです。
実際、数式を見れば、ソフトマックス関数のグラフの断面はシグモイド関数と一致することがあきらかです。
シグモイド関数は次のとおりです。
y = \frac{1}{1+e^{-x}}
2乗誤差のグラフ
次に交差エントロピー誤差との比較のために2乗誤差を見てみます。1次元で正解が1のときの2乗誤差は次のようになります。
E = \frac{1}{2}(y - 1)^2
これを2次元に拡張して、正解が $y_1=1, y_2=0$ のときの2乗誤差は次のようになります。
E = \frac{1}{2}((y_1 - 1)^2 + {y_2}^2)
これにソフトマックス関数を組み合わせると次のようになります。
E = {y_2}^2
横軸は$x_1$、奥方向の軸は$x_2$、縦軸は$E$です。
見づらいので少しだけ反時計回りに回転させたグラフも載せておきます。
右方向の軸は$x_1$、左方向の軸は$x_2$、縦軸は$E$です。
右下のほうの正解エリアでは$E$はほぼ0です。
ニューラルネットワークの学習が進むにつれ、右のほうの正解エリアに移動していくことが期待されます。このグラフを見ると真ん中のエリアでは傾きが急なので速く転げ落ちそう(=学習が速く進みそう)ですが、それよりも左のエリアだと傾きが平坦に近く、学習がなかなか進まないです。学習開始時点で偶然大きく左のエリアに入ってしまうと、なかなか抜け出せなくて、ちっとも学習できないのです。2乗誤差の欠点です。
交差エントロピー誤差のグラフ
正解が $y_1=1, y_2=0$ のときの交差エントロピー誤差は$y_1$だけで表せます。
E = -\log{y_1}
横軸は$y_1$、縦軸は誤差$E$です。単に自然対数のグラフを上下反転しただけです。
ソフトマックス関数と交差エントロピー誤差を組み合わせると次のようになります。
\begin{align}
E =& -\log{ \frac{e^{x_1}}{e^{x_1} + e^{x_2}} } \\
=& \log(1 + e^{x_2-x_1})\\
\end{align}
横軸は$x_1$、奥方向の軸は$x_2$、縦軸は$E$です。
右下のほうの正解エリアでは$E$はほぼ0です。逆の左上のほうは正解から遠く、$E$が大きいのですが、2乗誤差と違って、平坦にはならずに傾きが維持されているので、ニューラルネットワークの学習は速いです。ニューラルネットワークの出力層で交差エントロピーがよく使われる理由のひとつはこの点だと思われます。このグラフを見たかったのがこの記事の目的でした。
少しだけ反時計回りに回転させたグラフも載せておきます。
右方向の軸は$x_1$、左方向の軸は$x_2$、縦軸は$E$です。
ここからはおまけです。グラフを見て気が付いたのですが、断面はReLU(rectified linear)に似てますね。実際、数式を見れば、ReLUを滑らかにしたSoftplusという活性化関数と一致することがあきらかです。
Softplus関数は次のとおりです。
y = \log(1 + e^x)
このことから、中間層では学習速度の問題でシグモイド関数の代わりにReLUやSoftplusなどが使われますが、同じ理由だということがわかりますね。
なお、ソフトマックス関数とクロスエントロピー誤差の組み合わせの偏微分は簡単な形になります。
\begin{align}
E =& \log(1 + e^{x_2-x_1})\\
\frac{\partial E}{\partial x_1} =& \frac{-e^{x_2-x_1}}{1 + e^{x_2-x_1}} \\
=& -\frac{e^{x_2}}{e^{x_1}+e^{x_2}} \\
=& -y_2 \\
=& y_1 - 1 \\
\frac{\partial E}{\partial x_2} =& \frac{e^{x_2-x_1}}{1 + e^{x_2-x_1}} \\
=& \frac{e^{x_2}}{e^{x_1}+e^{x_2}} \\
=& y_2 \\
\end{align}
グラフと見比べてみても、確かになんとなくこうなりそうです。
以上。