#はじめに
活性化関数と損失関数は学習する問題に合わせて適切な組み合わせを選択する必要があります。
組み合わせが適切でない場合、学習が進まなかったり、学習の速度が遅くなったりします。
今回はその組み合わせにおける微分の値を計算していくつか確認してみました。
学習における重み更新は
W_{new} = W - \eta \frac{\partial L}{\partial W}
と書くことができます。ここで$\eta,$は学習率、$W$は学習するモデルの重みです。
従って損失関数$L$を重み$W$で偏微分した値(勾配)が重み更新(誤差逆伝搬)において重要になります。
今回やったことは損失関数と活性化関数の偏微分をそれぞれ求め、連鎖律から全結合重み$W$の偏微分を計算させているだけになります。
以下$y=y_{pred},,t=y_{true},,z=$全結合後で活性化関数前の出力$,,x=$全結合の入力$,,W=$全結合の重みとします。コードはKerasの損失関数と活性化関数の記述に準拠するとします。
#1.回帰問題
問題:回帰問題
活性化関数:線形関数
損失関数:mean_square_error
#コード例:
y = Dense(1, activation='linear')(x) # activation='linear'は省略してもよい
model.compile(loss='mean_squared_error')
L_{mean\_square\_error} = \frac{1}{2}(t - y)^2\\
y = z\\
z = W\times x
損失関数、活性化関数、直前の全結合計算が上記のように行われたとする。
この場合、各微分式は
\frac{\partial L_{mse}}{\partial y} = -(t - y)\\
\frac{\partial y}{\partial z} = 1\\
\frac{\partial z}{\partial W} = x
からモデル重み$W$の更新量は連鎖律を使って
\begin{align}\frac{\partial L_{mse}}{\partial W} &= \frac{\partial L_{mse}}{\partial y}\frac{\partial y}{\partial z}\frac{\partial z}{\partial W}\\
&=-(t - y)x
\end{align}
となります。
ところで、例えば全結合の$W$のサイズを$(10,512)$の変換行列とした場合、$(t,y)$のサイズ$(10)$、$(x)$のサイズ$(512)$となるため、正確に書くなら
\begin{align}\frac{\partial L_{mse}}{\partial W} &=-(t - y)x^{\mathrm{T}}
\end{align}
と行列の転置記号が必要になりますが、本質的な問題ではないので以降省略します。
#2.二値分類問題 その1
問題:二値分類問題
活性化関数:sigmoid関数
損失関数:binary_crossentropy
#コード例:
y = Dense(1, activation='sigmoid')(x)
model.compile(loss='binary_crossentropy')
\begin{align}
L_{binary\_crossentropy} &= -t \log y-(1-t) \log (1-y)\\
y &= \frac{1}{1+\exp(-z)} \qquad\cdots(sigmoid)\\
1-y &= \frac{\exp(-z)}{1+\exp(-z)}\\
z &= W\times x
\end{align}
損失関数、活性化関数、直前の全結合計算が上記のように行われたとする。
この場合、各微分式は
\frac{\partial L_{bin}}{\partial y} = -\frac{t}{y}+\frac{1-t}{1-y}=\frac{-t(1-y)+(1-t)y}{y(1-y)}=\frac{-(t-y)}{y(1-y)}\\
\frac{\partial y}{\partial z} = \frac{\exp(-z)}{(1+\exp(-z))^2}=y(1-y)\\
\frac{\partial z}{\partial W} = x
からモデル重み$W$の更新量は連鎖律を使って
\begin{align}\frac{\partial L_{bin}}{\partial W} &= \frac{\partial L_{bin}}{\partial y}\frac{\partial y}{\partial z}\frac{\partial z}{\partial W}\\
&=\frac{-(t-y)}{y(1-y)} \dot{} y(1-y) \dot{} x\\
&=-(t - y)x
\end{align}
となり、回帰問題における重み更新量と一致します。
#3.二値分類問題 その2
問題:二値分類問題
活性化関数:softmax関数
損失関数:category_crossentropy
#コード例:
y = Dense(2, activation='softmax')(x)
model.compile(loss='categorical_crossentropy')
二値分類問題は2.の代わりにone-hotベクトルを使ってこのように書いても構いません。
この場合softmaxの定義とone-hotベクトルの特徴より、条件
y[0] + y[1] = 1\\
t[0] + t[1] = 1
を満たす。
\begin{align}L_{category\_crossentropy} &= -t[0] \log y[0] - t[1] \log y[1] \\
&= -t[0] \log y[0]-(1-t[0]) \log (1-y[0])\\
y[0] &= \frac{\exp(z[0])}{\exp(z[0])+\exp(z[1])} \qquad\cdots(softmax)\\
y[1] &= \frac{\exp(z[1])}{\exp(z[0])+\exp(z[1])} = 1-y[0]
\end{align}\\
z = W\times x
損失関数、活性化関数、直前の全結合計算が上記のように行われたとすると、
この場合、各微分式は
\frac{\partial L_{cat}}{\partial y[0]} = -\frac{t[0]}{y[0]}+\frac{1-t[0]}{1-y[0]}=\frac{-t[0](1-y[0])+(1-t[0])y[0]}{y[0](1-y[0])}=\frac{-(t[0]-y[0])}{y[0](1-y[0])}\\
\frac{\partial y[0]}{\partial z[0]} = \frac{\exp(z[0])}{\exp(z[0])+\exp(z[1])} + \frac{- \exp(z[0])*\exp(z[0])}{(\exp(z[0])+\exp(z[1]))^2}=y[0]-y[0]^2=y[0](1-y[0])\\
\frac{\partial y[0]}{\partial z[1]} = \frac{- \exp(z[0])*\exp(z[1])}{(\exp(z[0])+\exp(z[1]))^2}=-y[0]*y[1]=-y[0](1-y[0])\\
\frac{\partial z[0]}{\partial W} = x u_0\\
\frac{\partial z[1]}{\partial W} = x u_1\\
u_0 = \left( \begin{array}{c} 1 \\ 0 \end{array} \right)\\
u_1 = \left( \begin{array}{c} 0 \\ 1 \end{array} \right)
からモデル重み$W$の更新量は連鎖律を使って
\begin{align}\frac{\partial L_{cat}}{\partial W} &= \frac{\partial L_{cat}}{\partial y[0]}\frac{\partial y[0]}{\partial z[0]}\frac{\partial z[0]}{\partial W} + \frac{\partial L_{cat}}{\partial y[0]}\frac{\partial y[0]}{\partial z[1]}\frac{\partial z[1]}{\partial W}\\
&=\frac{-(t[0]-y[0])}{y[0](1-y[0])} \dot{} y[0](1-y[0]) \dot{} (xu_0-xu_1) \\
&=-(t[0] - y[0])(xu_0-xu_1)\\
&=-((t[0] - y[0])(xu_0) + (t[1] - y[1])(xu_1))
\end{align}
となり、回帰問題における重み更新量と一致します。
#4.三値分類問題
問題:三値分類問題
活性化関数:softmax関数
損失関数:category_crossentropy
#コード例:
y = Dense(3, activation='softmax')(x)
model.compile(loss='categorical_crossentropy')
この場合softmaxの定義とone-hotベクトルの特徴より、条件
y[0] + y[1] + y[2] = 1\\
t[0] + t[1] + t[2] = 1
を満たす。
\begin{align}L_{category\_crossentropy} &= -t[0] \log y[0] - t[1] \log y[1] - t[2] \log y[2]\\
y[0] &= \frac{\exp(z[0])}{\exp(z[0])+\exp(z[1])+\exp(z[2])} \qquad\cdots(softmax)\\
1-y[0] &= \frac{\exp(z[1])+\exp(z[2])}{\exp(z[0])+\exp(z[1])+\exp(z[2])}\\
y[1] &= \frac{\exp(z[1])}{\exp(z[0])+\exp(z[1])+\exp(z[2])}\\
y[2] &= \frac{\exp(z[2])}{\exp(z[0])+\exp(z[1])+\exp(z[2])}\\
\end{align}\\
z = W\times x
損失関数、活性化関数、直前の全結合計算が上記のように行われたとすると、
この場合、各微分式は
\frac{\partial L_{cat}}{\partial y[0]} = -\frac{t[0]}{y[0]}\\
\frac{\partial L_{cat}}{\partial y[1]} = -\frac{t[1]}{y[1]}\\
\frac{\partial L_{cat}}{\partial y[2]} = -\frac{t[2]}{y[2]}\\
\frac{\partial y[0]}{\partial z[0]} = \frac{\exp(z[0])}{\exp(z[0])+\exp(z[1])+\exp(z[2])} + \frac{- \exp(z[0])*\exp(z[0])}{(\exp(z[0])+\exp(z[1])+\exp(z[2]))^2}\\
=\frac{\exp(z[0])*(\exp(z[1]) + \exp(z[2]))}{(\exp(z[0])+\exp(z[1])+\exp(z[2]))^2}=y[0](1-y[0])\\
\frac{\partial y[1]}{\partial z[0]} = \frac{- \exp(z[0])*\exp(z[1])}{(\exp(z[0])+\exp(z[1])+\exp(z[2]))^2}=-y[0]*y[1]\\
\frac{\partial y[2]}{\partial z[0]} = \frac{- \exp(z[0])*\exp(z[2])}{(\exp(z[0])+\exp(z[1])+\exp(z[2]))^2}=-y[0]*y[2]\\
\frac{\partial z[0]}{\partial W} = xu_0\\
\frac{\partial z[1]}{\partial W} = xu_1\\
\frac{\partial z[2]}{\partial W} = xu_2
である。
ここで一度$\frac{\partial L_{cat}}{\partial z[0]}$を考えると
\begin{align}
\frac{\partial L_{cat}}{\partial z[0]} &= \frac{\partial L_{cat}}{\partial y[0]}\frac{\partial y[0]}{\partial z[0]} + \frac{\partial L_{cat}}{\partial y[1]}\frac{\partial y[1]}{\partial z[0]} + \frac{\partial L_{cat}}{\partial y[2]}\frac{\partial y[2]}{\partial z[0]}\\
&=-\frac{t[0]}{y[0]}\dot{}y[0](1-y[0]) +\frac{t[1]}{y[1]}\dot{}y[0]*y[1] +\frac{t[2]}{y[2]}\dot{}y[0]*y[2]\\
&=-t[0](1-y[0])+y[0]t[1]+y[0]t[2]\\
&= -t[0] + y[0](t[0]+t[1]+t[2])\\
&= -(t[0] - y[0])
\end{align}
また同様に計算すれば
\begin{align}
\frac{\partial L_{cat}}{\partial z[1]} &= -(t[1] - y[1])\\
\frac{\partial L_{cat}}{\partial z[2]} &= -(t[2] - y[2])\\
\end{align}
従ってモデル重み$W$の更新量は連鎖律を使って
\begin{align}\frac{\partial L_{cat}}{\partial W} &= \frac{\partial L_{cat}}{\partial z[0]}\frac{\partial z[0]}{\partial W} + \frac{\partial L_{cat}}{\partial z[1]}\frac{\partial z[1]}{\partial W} + \frac{\partial L_{cat}}{\partial z[2]}\frac{\partial z[2]}{\partial W}\\
&=-((t[0] - y[0])(xu_0) + (t[1] - y[1])(xu_1) + (t[2] - y[2])(xu_2))
\end{align}
以上、3値以上の多分類においても同様に書ける。
#5.マルチラベル分類問題
問題:マルチラベル分類問題
活性化関数:sigmoid関数
損失関数:binary_crossentropy
#コード例:
y = Dense(2, activation='sigmoid')(x)
model.compile(loss='binary_crossentropy')
\begin{align}
L_{binary\_crossentropy} &= -t[0] \log y[0]-(1-t[0]) \log (1-y[0])\\
& -t[1] \log y[1]-(1-t[1]) \log (1-y[1])\\
y[0] &= \frac{1}{1+\exp(-z[0])} \qquad\cdots(sigmoid)\\
y[1] &= \frac{1}{1+\exp(-z[1])}\\
z &= W\times x
\end{align}
損失関数、活性化関数、直前の全結合計算が上記のように行われたとする。
この場合、各微分式は
\frac{\partial L_{bin}}{\partial y[0]} = \frac{-(t[0]-y[0])}{y[0](1-y[0])}\\
\frac{\partial L_{bin}}{\partial y[1]} = \frac{-(t[1]-y[1])}{y[1](1-y[1])}\\
\frac{\partial y[0]}{\partial z[0]} = y[0](1-y[0])\\
\frac{\partial y[1]}{\partial z[1]} = y[1](1-y[1])\\
\frac{\partial z[0]}{\partial W} = xu_0\\
\frac{\partial z[1]}{\partial W} = xu_1
からモデル重み$W$の更新量は連鎖律を使って
\begin{align}\frac{\partial L_{bin}}{\partial W} &= \frac{\partial L_{bin}}{\partial y[0]}\frac{\partial y[0]}{\partial z[0]}\frac{\partial z[0]}{\partial W} + \frac{\partial L_{bin}}{\partial y[1]}\frac{\partial y[1]}{\partial z[1]}\frac{\partial z[1]}{\partial W}\\
&=-((t[0] - y[0])(xu_0) + (t[1] - y[1])(xu_1))
\end{align}
となり、回帰問題における重み更新量と一致します。
#6.分類問題(pytorch)
問題:分類問題
活性化関数:log_softmax関数
損失関数:nllloss
pytorchにおける分類問題の活性化関数と損失関数はlog_softmaxとnlllossを使うそうです。
(もしくは活性化関数なしでpytorch流のcategory_crossentropyを使う)
pytorchはほとんど触ったことがないので適当ですが、log_softmaxをsoftmaxにlogを掛けたものとみなせば、
nlllossは
\begin{align}L_{nllloss} &= - \sum t * y\\
y[0] &= \log( \frac{\exp(z[0])}{ \sum \exp(z[0])+\exp(z[1])})
\end{align}
の場合、Kerasのcategory_crossentropyと等しい勾配を持ちます。
#7.回帰問題(損失関数MAEの場合)
問題:回帰問題
活性化関数:線形関数
損失関数:mean_absolute_error
#コード例:
y = Dense(1, activation='linear')(x) # activation='linear'は省略してもよい
model.compile(loss='mean_absolute_error')
L_{mean\_absolute\_error} = |t - y|\\
y = z\\
z = W\times x
損失関数、活性化関数、直前の全結合計算が上記のように行われたとする。
この場合、各微分式は
\frac{\partial L_{mae}}{\partial y} = -\frac{(t - y)}{|t-y|}\\
\frac{\partial y}{\partial z} = 1\\
\frac{\partial z}{\partial W} = x
からモデル重み$W$の更新量は連鎖律を使って
\begin{align}\frac{\partial L_{mae}}{\partial W} &= \frac{\partial L_{mae}}{\partial y}\frac{\partial y}{\partial z}\frac{\partial z}{\partial W}\\
&=-\frac{(t - y)}{|t-y|}x
\end{align}
となる。これは$y$が$t$に近い時、通常の回帰問題の重み更新量より更新量が大きくなる。
#8.二値分類問題(損失関数MSEの場合)
問題:二値分類問題
活性化関数:sigmoid関数
損失関数:mean_square_error
#コード例:
y = Dense(1, activation='sigmoid')(x)
model.compile(loss='mean_square_error')
\begin{align}
L_{mean\_square\_error} &= \frac{1}{2}(t - y)^2\\
y &= \frac{1}{1+\exp(-z)} \qquad\cdots(sigmoid)\\
z &= W\times x
\end{align}
損失関数、活性化関数、直前の全結合計算が上記のように行われたとする。
この場合、各微分式は
\frac{\partial L_{mse}}{\partial y} = -(t-y)\\
\frac{\partial y}{\partial z} = y(1-y)\\
\frac{\partial z}{\partial W} = x
からモデル重み$W$の更新量は連鎖律を使って
\begin{align}\frac{\partial L_{mse}}{\partial W} &= \frac{\partial L_{mse}}{\partial y}\frac{\partial y}{\partial z}\frac{\partial z}{\partial W}\\
&=-(t - y) \dot{} y(1-y) \dot{} x
\end{align}
となる。これは$y$が$0,1$に近い時、通常の分類問題の重み更新量より更新量が小さくなる。
これら通常の回帰問題における更新量比率を1と置いた場合の各活性化関数と損失関数を選んだ更新量比率は以下のようになります。
linear | sigmoid | |
---|---|---|
MSE | $1$ | $y(1-y)$:小 |
MAE | $\frac{1}{abs(t-y)}$:大 | $\frac{y(1-y)}{abs(t-y)}$:≃1 |
binary_crossentropy | $\frac{1}{y(1-y)}$:大 | $\frac{y(1-y)}{y(1-y)}=1$ |
#9.回帰問題(活性化関数tanhの場合)
問題:回帰問題
活性化関数:tanh関数
損失関数:mean_square_error??
活性化関数tanhの場合、この活性化関数後の出力が1~-1に出力されます。自分の知る限りこの活性化関数tanhが使われるケースはあまりありません。
実際にどういう場合に使われるのかというとゲームの強化学習で勝ちを1、負けを-1、引き分けを0とする場合に使われるのを見たことがあります。
問題はこの回帰問題を解く場合に損失関数はMSEでいいのか考えてみます。
L_{mean\_square\_error} = \frac{1}{2}(t - y)^2\\
y = tanh(x)\\
z = W\times x
この場合、各微分式からモデル重み$W$の更新量は
\frac{\partial L_{mse}}{\partial y} = -(t - y)\\
\frac{\partial y}{\partial z} = 1-y^2\\
\frac{\partial z}{\partial W} = x\\
\\
\begin{align}\frac{\partial L_{mse}}{\partial W} &= \frac{\partial L_{mse}}{\partial y}\frac{\partial y}{\partial z}\frac{\partial z}{\partial W}\\
&=-(t - y)(1-y^2)x
\end{align}
となる。これは$y$が$-1,1$に近い時、通常の回帰問題の重み更新量より更新量が小さくなる。
もし、$W$更新量を通常の回帰問題と同じにしたければ損失関数を
L_{tanh} = -\frac{1}{2}((1-t)\log(1-y)+(1+t)\log(1+y))\\
\frac{\partial L_{tanh}}{\partial y} = -\frac{1}{2}(-\frac{1-t}{1-y}+\frac{1+t}{1+y})\\
=-\frac{(t-y)}{(1-y^2)}\\
\begin{align}\frac{\partial L_{tanh}}{\partial W} &= \frac{\partial L_{tanh}}{\partial y}\frac{\partial y}{\partial z}\frac{\partial z}{\partial W}\\
&=-\frac{(t-y)}{(1-y^2)}(1-y^2)x\\
&=-(t - y)x
\end{align}
とすればいいことになります。
#まとめ
長々と書きましたが、普通の問題を解くだけなら、
・回帰問題の場合は活性化関数は線形関数、損失関数はmean_square_errorを使う。
・one-hotベクトルの分類問題では活性化関数はsoftmax、損失関数はcategory_crossentropyを使う。
・one-hotベクトルでないマルチラベル分類では活性化関数はsigmoid、損失関数はbinary_crossentropyを使うと良い。
を押さえておけばほぼ問題ないです。
#参考:
https://www.renom.jp/ja/notebooks/tutorial/basic_algorithm/lossfunction/notebook.html
https://qiita.com/43x2/items/50b55623c890564f1893
https://qiita.com/takurooo/items/e356dfdeec768d8f7146