3
8

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

めちゃめちゃ丁寧にソフトマックス回帰での交差エントロピー誤差を微分する

Last updated at Posted at 2023-07-07

はじめに

深層学習をイチから勉強しなおそうと思って,ソフトマックス回帰について勉強・実装したので誰かの参考になると幸いです.

対象:
教師あり学習・分類はわかるよ〜な人.
ロジスティック回帰ならわかるよ〜な人.
高校レベルの微分ならわかるよ〜な人.
行列の掛け算ならできるよ〜な人.

ソフトマックス関数とは(1)

ソフトマックス関数は,多クラス分類に良く使われる関数です.

深層学習の出力層なんかでも登場しますね.

こんな形

f(x_1, x_2, \cdots, x_i, \cdots, x_K)_i = \frac{exp(x_i)}{\sum_{k=1}^{K}exp(x_k)} 
i = 1, 2, \cdots, K

ひ〜〜〜
数学アレルギーだった僕にとって見るだけでも非常に辛かったのを覚えてます。。。

そもそも.
どうしてこのような変換をするのか.

それはクラスラベルの確率を知りたいからです.

???

例えば画像データを{犬,猫,鳥}の3クラスに分類するモデルを考えます.

この時,モデルにある画像を入力して次のような出力を得たとします.

\{犬, 猫, 鳥\} = \{0.6, 0.3, 0.1\}

この数値を見ると「あ、この画像は犬の画像なんだ〜」と分かりますよね??

なので,各クラスラベルに対応する確率を得ることが出来れば分類のタスクに使えると.

そしてそれを行うのが先のソフトマックス関数です.

ソフトマックス関数とは(2)

ソフトマックス関数が確率を表せるから,分類に使えるという話をしました.

じゃあ,かなーーーーーり雑に確率ってどんなものかっていうと,足したら「1」になる数値たちのことです.

これも確率です.

\{犬, 猫, 鳥\} = \{0.6, 0.3, 0.1\}

で,ソフトマックス関数は

\sum_{i=1}^{K}f(x_1, x_2, \cdots, x_i, \cdots, x_K)_i = 1

の性質を満たします.

だから,

f(x_1, x_2, \cdots, x_i, \cdots, x_K)_i = \frac{exp(x_i)}{\sum_{k=1}^{K}exp(x_k)} 
i = 1, 2, \cdots, K

はクラスiに属するの確率を表しています.

でもそれなら,指数関数なんぞ使わずに

g(x_1, x_2, \cdots, x_i, \cdots, x_K)_i = \frac{x_i}{\sum_{k=1}^{K}x_k} 
i = 1, 2, \cdots, K

でもいい気がしませんか??(最初僕もそう思いました、だって簡単な方が嬉しいやーん.)

実際

\sum_{i=1}^{K}g(x_1, x_2, \cdots, x_i, \cdots, x_K)_i = 1

ですしね.

でも,今回のソフトマックス関数を使うと非常に良いということがわかります.

下の図をみてください.
softmax.png

青が

g(x_1, x_2, \cdots, x_i, \cdots, x_K)_i = \frac{x_i}{\sum_{k=1}^{K}x_k} 
i = 1, 2, \cdots, K

オレンジが

f(x_1, x_2, \cdots, x_i, \cdots, x_K)_i = \frac{exp(x_i)}{\sum_{k=1}^{K}exp(x_k)} 
i = 1, 2, \cdots, K

です.

見比べてどうでしょう.

大きいところと小さいところの差がより顕著に表れていませんか??

つまり,これまた雑に言うと,画像などのわずかな差に敏感に反応してくれるということです.

てことで,ソフトマックス関数が多クラス分類に対して有用だということを解説しました.

ソフトマックス回帰のアーキテクチャ

クラスiに対する重み付き線形和→ソフトマックス関数というだけです.

softmax_1.png

こんな感じ.

こんな $z_i$ がたくさんあって,それをソフトマックス関数に入力しているというイメージです.

softmax_2.png

見辛い図ですんません。笑
(てか,図デカくね。。。)

これを数式で書き起こすと,

重み付き線形和

z_i = w_{1i}x_1+w_{2i}x_2+\cdots+w_{Ni}x_N+b

ソフトマックス関数

f(z_1, z_2, \cdots, z_K)_i = \frac{exp(z_i)}{\sum_{k=1}^{K}exp(z_k)}

ですね.そしてこれがクラスラベルiの予測値なわけです.

y_i = f(z_1, z_2, \cdots, z_K)_i

とでも書いておきましょうか.

交差エントロピー誤差

交差エントロピーは二つの確率分布間の類似度を測る指標みたいなものです.
この値が大きいと$y$と$t$は似ているね,という意味です.ざっくり言うと.

以下の式.($t$は正解ラベル=教師データ)

E(y_1, y_2, \cdots, y_k) = \sum_{k=1}^{K}t_klog(y_k)

いや・・・こんなの突然出されても。。。

わかりますよ.お気持ち.

そう言う時は実験してみたらいいんです.

じゃあここで問題です.

ある分類問題において,下図のように真の値がクラスラベル4であるとします.
(つまり,ラベル4は100%,他は0%)

softmax_3.png

そして,今モデルから次のような3パターンの出力を得ました.
softmax_4.png

t = np.array([0,0,0,1,0])
y_1 = np.array([0.1, 0.3, 0.1, 0.75, 0.25])
y_2 = np.array([0.2, 0.05, 0.6, 0.05, 0.1])
y_3 = np.array([0.1, 0.1, 0.3, 0.2, 0.3])

どの出力が良さそうですか?

さっと考えてみてください.

いいですか?

いきますよ?

はい.

おそらくpred value 1が良い予測値なのではないでしょうか?

なぜなら,クラスラベル4の確率が一番高いからです.

今は,視覚的に求めましたがこれを定量的に表したものが,交差エントロピーです.

交差エントロピーの値比較

じゃあ,そう言うなら計算してみましょう.

print("true value vs pred value 1: {:.3f}".format((t*np.log(y_1)).sum()))
print("true value vs pred value 2: {:.3f}".format((t*np.log(y_2)).sum()))
print("true value vs pred value 3: {:.3f}".format((t*np.log(y_3)).sum()))
>>
true value vs pred value 1: -0.288
true value vs pred value 2: -2.996
true value vs pred value 3: -1.609

ほらね(ドヤ)

pred value 1が一番良い予測値であることがわかりました.

負の値だから一見わかりづらいですし,誤差っていうくらいだから小さい方が嬉しいので

交差エントロピーが大きいときに似ている,というより

負の交差エントロピーが小さいときに似ていると言う方が嬉しいので,一般に交差エントロピー誤差は次のように書きます.(マイナスつけただけ)

E(y_1, y_2, \cdots, y_K) = -\sum_{k=1}^{K}t_klog(y_k)

各クラスラベルに対する予測値を求めて,上記の式に代入すると誤差が得られます.

コスト関数(交差エントロピーの微分)へ

これで,割と準備が整いました.

教師あり学習の大きな流れを確認すると,

予測値を計算→誤差を計算→微分を計算→パラメータの更新(学習)

でしたね.

今までの話で,予測値を計算(ソフトマックス関数)→誤差を計算(交差エントロピー)をやりました.

それではいよいよ,交差エントロピーのパラメータについての微分をしていきましょう.

交差エントロピーのパラメータについての微分

求めたいもの: 交差エントロピーの全パラメータでの微分(値)

とりあえず,いっぺんに求めるのは大変なので
(行列単位で微分できる方はその限りではないかもですが。笑)

下図のような$w_{ni}$による微分を求めましょうか.

\frac{\partial E(y_1, y_2, \cdots, y_K)}{\partial w_{ni}}

softmax_5.png

ここで,簡単に連鎖律の公式を導入します(証明などは割愛)

連鎖率

まず,誤解を恐れずに言うと連鎖律とはネストされてるパラメータによる微分を楽に行えるものです.

こんな感じ.

f(u(x), v(x), w(x))

のように,fはu,v,wの関数で,u,v,wはxの関数であるとします.
この時,fのxによる微分は

\frac{\partial f}{\partial x} = \frac{\partial f}{\partial u}\frac{\partial u}{\partial x}+\frac{\partial f}{\partial v}\frac{\partial v}{\partial x}+\frac{\partial f}{\partial w}\frac{\partial w}{\partial x}

となります.一般に

f(y_1(x), y_2(x), \cdots, y_n(x))

である時,

\frac{\partial f}{\partial x} = \sum_{i=1}^{n}\frac{\partial f}{\partial y_i}\frac{\partial y_i}{\partial x}

です.

なぜこれを導入したかというと,ネストを解いて計算してもいいんですが結構大変なんですよね.

おそらく.

今回の微分したい関数をまじめに書くと次のようになってます.

E(y_1, y_2, \cdots, y_K) = E(f(z_1, z_2, \cdots, z_K)_1, f(z_1, z_2, \cdots, z_K)_2, \cdots, f(z_1, z_2, \cdots, z_K)_K)
E(f(z_1, z_2, \cdots, z_K)_1, f(z_1, z_2, \cdots, z_K)_2, \cdots, f(z_1, z_2, \cdots, z_K)_K)
=E(f(z_1(w_{11}, w_{21}, \cdots, w_{N1}), z_2(w_{12}, w_{22}, \cdots, w_{N2}), \cdots, z_K(w_{12}, w_{22}, \cdots, w_{N2}))_1\cdots 

書きたくないので諦めました.笑

というわけで連鎖律を用いて書き換えてみましょうか.

連鎖率を用いて微分を書いてみる.

\frac{\partial E}{\partial w_{ni}}=\sum_{l=1}^{K}\frac{\partial E}{\partial y_l}\frac{\partial y_l}{\partial w_{ni}}
y_l = f(z_1, z_2, \cdots, z_K)_l

なので,

\frac{\partial y_l}{\partial w_{ni}}=\sum_{m=1}^{K}\frac{\partial y_l}{\partial z_{m}}\frac{\partial z_m}{\partial w_{ni}}

よって,

\frac{\partial E}{\partial w_{ni}}=\sum_{l=1}^{K}\frac{\partial E}{\partial y_l}(\sum_{m=1}^{K}\frac{\partial y_l}{\partial z_{m}}\frac{\partial z_m}{\partial w_{ni}})

やや雑な表記かもですが,まぁこんな感じです.

地道に微分してきます〜

まずは,

\sum_{m=1}^{K}\frac{\partial y_l}{\partial z_{m}}\frac{\partial z_m}{\partial w_{ni}}

から.

\sum_{m=1}^{K}\frac{\partial y_l}{\partial z_{m}}\frac{\partial z_m}{\partial w_{ni}}
=\frac{\partial y_l}{\partial z_{1}}\frac{\partial z_1}{\partial w_{ni}}+\frac{\partial y_l}{\partial z_{2}}\frac{\partial z_2}{\partial w_{ni}}+\cdots+\frac{\partial y_l}{\partial z_{i}}\frac{\partial z_i}{\partial w_{ni}}+\cdots+\frac{\partial y_l}{\partial z_{K}}\frac{\partial z_K}{\partial w_{ni}}

ここで注目してほしいのが

\frac{\partial y_l}{\partial z_{i}}\frac{\partial z_i}{\partial w_{ni}}

の項.

これは,

\frac{\partial y_l}{\partial z_{i}}\frac{\partial z_i}{\partial w_{ni}}=\frac{\partial y_l}{\partial z_{i}}\frac{\partial}{\partial w_{ni}}(w_{1i}x_1+w_{2i}x_2+\cdots+w_{ni}x_n+\cdots+w_{Ni}x_N)

なので,$w_{ni}$以外の項は微分したら0だから

\frac{\partial y_l}{\partial z_{i}}\frac{\partial z_i}{\partial w_{ni}}=\frac{\partial y_l}{\partial z_{i}}x_n

ですね.
他の項もみてみると,

\frac{\partial y_l}{\partial z_{1}}\frac{\partial z_1}{\partial w_{ni}}=\frac{\partial y_l}{\partial z_{1}}\frac{\partial}{\partial w_{ni}}(w_{11}x_1+w_{21}x_2+\cdots+w_{n1}x_n+\cdots+w_{N1}x_N)

なので,全部0になります.
つまり,

\sum_{m=1}^{K}\frac{\partial y_l}{\partial z_{m}}\frac{\partial z_m}{\partial w_{ni}}
=\frac{\partial y_l}{\partial z_{1}}\frac{\partial z_1}{\partial w_{ni}}+\frac{\partial y_l}{\partial z_{2}}\frac{\partial z_2}{\partial w_{ni}}+\cdots+\frac{\partial y_l}{\partial z_{i}}\frac{\partial z_i}{\partial w_{ni}}+\cdots+\frac{\partial y_l}{\partial z_{K}}\frac{\partial z_K}{\partial w_{ni}}
=0+0+\cdots+\frac{\partial y_l}{\partial z_{i}}x_n+\cdots+0=\frac{\partial y_l}{\partial z_{i}}x_n

です.

元の式を思い出すと,

\frac{\partial E}{\partial w_{ni}}=\sum_{l=1}^{K}\frac{\partial E}{\partial y_l}(\sum_{m=1}^{K}\frac{\partial y_l}{\partial z_{m}}\frac{\partial z_m}{\partial w_{ni}})
=\sum_{l=1}^{K}\frac{\partial E}{\partial y_l}(\frac{\partial y_l}{\partial z_{i}}x_n)

じゃああとは,$x_n$は一旦放っておいて,

\sum_{l=1}^{K}\frac{\partial E}{\partial y_l}\frac{\partial y_l}{\partial z_{i}}

を求めればいいですね.

\sum_{l=1}^{K}\frac{\partial E}{\partial y_l}\frac{\partial y_l}{\partial z_{i}}=\frac{\partial E}{\partial y_1}\frac{\partial y_1}{\partial z_{i}}+\frac{\partial E}{\partial y_1}\frac{\partial y_1}{\partial z_{i}}+\cdots+\frac{\partial E}{\partial y_i}\frac{\partial y_i}{\partial z_{i}}+\cdots+\frac{\partial E}{\partial y_K}\frac{\partial y_K}{\partial z_{i}}

であり,

\frac{\partial E}{\partial y_l} = -\frac{\partial}{\partial y_l}(\sum_{k=1}^{K}t_klog(y_k))
=-\frac{\partial}{\partial y_l}(t_1log(y_1)+t_2log(y_2)+\cdots+t_llog(y_l)+\cdots+t_Klog(y_K))

logの微分より,

-\frac{\partial}{\partial y_i}(t_1log(y_1)+t_2log(y_2)+\cdots+t_llog(y_l)+\cdots+t_Klog(y_K))
=-(0+0+\cdots+\frac{t_l}{y_l}+\cdots+0)=-\frac{t_l}{y_l}

なので,

\sum_{l=1}^{K}\frac{\partial E}{\partial y_l}\frac{\partial y_l}{\partial z_{i}}=-\frac{t_1}{y_1}\frac{\partial y_1}{\partial z_{i}}-\frac{t_2}{y_2}\frac{\partial y_1}{\partial z_{i}}-\cdots-\frac{t_i}{y_i}\frac{\partial y_i}{\partial z_{i}}-\cdots-\frac{t_K}{y_K}\frac{\partial y_K}{\partial z_{i}}

さて,最後です.

\frac{\partial y_l}{\partial z_{i}}

を微分しましょう.

ここで注意したいのが,

\frac{\partial y_i}{\partial z_{i}} = \frac{\partial}{\partial z_{i}} \frac{exp(z_i)}{\sum_{k=1}^{K}exp(z_k)}

\frac{\partial y_{その他}}{\partial z_{i}} = \frac{\partial}{\partial z_{i}} \frac{exp(z_{その他})}{\sum_{k=1}^{K}exp(z_k)}

で分けなきゃだめだと言うこと.
簡単のため,

S = \sum_{k=1}^{K}exp(z_k)
S^\prime = \frac{\partial}{\partial z_{i}}\sum_{k=1}^{K}exp(z_k)=exp(z_i)

としておきますね.
まずこっち.

\frac{\partial y_i}{\partial z_{i}} = \frac{\partial}{\partial z_{i}} \frac{exp(z_i)}{\sum_{k=1}^{K}exp(z_k)}

商の微分より,

\frac{\partial y_i}{\partial z_{i}} = \frac{\partial}{\partial z_{i}} \frac{exp(z_i)}{\sum_{k=1}^{K}exp(z_k)}= \frac{exp(z_i)(S-exp(z_i))}{S^2}
y_i = \frac{exp(z_i)}{\sum_{k=1}^{K}exp(z_k)}=\frac{exp(z_i)}{S}

より,

\frac{\partial y_i}{\partial z_{i}} = \frac{exp(z_i)(S-exp(z_i))}{S^2} = \frac{exp(z_i)}{S}(1-\frac{exp(z_i)}{S}) = y_i(1-y_i)

次にこっち.

\frac{\partial y_{その他}}{\partial z_{i}} = \frac{\partial}{\partial z_{i}} \frac{exp(z_{その他})}{\sum_{k=1}^{K}exp(z_k)}
= \frac{0-exp(z_{その他})exp(z_{i})}{S^2}=\frac{-exp(z_{その他})exp(z_{i})}{S^2}=-y_{その他}y_i

よって,

\sum_{l=1}^{K}\frac{\partial E}{\partial y_l}\frac{\partial y_l}{\partial z_{i}}=-\frac{t_1}{y_1}\frac{\partial y_1}{\partial z_{i}}-\frac{t_2}{y_2}\frac{\partial y_1}{\partial z_{i}}-\cdots-\frac{t_i}{y_i}\frac{\partial y_i}{\partial z_{i}}-\cdots-\frac{t_K}{y_K}\frac{\partial y_K}{\partial z_{i}}
=-\frac{t_1}{y_1}(-y_1y_i)-\frac{t_2}{y_2}(-y_2y_i)-\cdots-\frac{t_i}{y_i}(y_i(1-y_i))-\cdots-\frac{t_K}{y_K}(-y_Ky_i)
=t_1y_i+t_2y_i+\cdots-t_i(1-y_i)+\cdots+t_Ky_i
=t_1y_i+t_2y_i-\cdots-t_i+t_iy_i+\cdots+t_Ky_i
=y_i(\sum_{l=1}^{K}t_l)-t_i
\sum_{l=1}^{K}t_l = 1

より,

\sum_{l=1}^{K}\frac{\partial E}{\partial y_l}\frac{\partial y_l}{\partial z_{i}}=y_i-t_i

お疲れ様です!!!大変でしたね!!!

あとは,最後に

\frac{\partial E}{\partial w_{ni}}=\sum_{l=1}^{K}\frac{\partial E}{\partial y_l}(\sum_{m=1}^{K}\frac{\partial y_l}{\partial z_{m}}\frac{\partial z_m}{\partial w_{ni}})
=\sum_{l=1}^{K}\frac{\partial E}{\partial y_l}(\frac{\partial y_l}{\partial z_{i}}x_n)

だったので,

\frac{\partial E}{\partial w_{ni}}=(y_i-t_i)x_n

となります.

実装は疲れたのでまた今度やります。。。

何か間違いなどあればご指摘ください!

実装編

ソフトマックス回帰を実装する

3
8
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
3
8

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?