0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

出力の活性化関数が他の累積分布関数の損失関数

Posted at

はじめに

二値分類問題を解きたいとき、出力の活性化関数を標準シグモイド関数とし、損失関数をbinary_crossentropyとする事が良く知られている。
これは出力の活性化関数の微分と損失関数の微分の連鎖積が、活性化関数を恒等関数、損失関数を平均二乗誤差(MSE)の微分の連鎖積に等しいからである。(自記事参考)
回帰問題

z=Wx\\ 
y=z\\
L_{mse}=\frac{1}{2}(y-t)^2\\
\frac{\partial L_{mse}}{\partial W}=\frac{\partial L_{mse}}{\partial y}\frac{\partial y}{\partial z}\frac{\partial z}{\partial W}=(y-t)\cdot 1 \cdot x \\

二値分類問題

z=Wx\\ 
y=\frac{1}{1+e^{-z}}\\
L_{BCE}=-tlog(y)-(1-t)log(1-y)\\
\frac{\partial y}{\partial z} = \frac{e^{-z}}{(1+e^{-z})^2}=y(1-y)\\
\frac{\partial L_{BCE}}{\partial y} =- \frac{t}{y} + \frac{1-t}{1-y}=\frac{(y-t)}{y(1-y)}\\
\frac{\partial L_{BCE}}{\partial W}=\frac{\partial L_{BCE}}{\partial y}\frac{\partial y}{\partial z}\frac{\partial z}{\partial W}=\frac{(y-t)}{y(1-y)}\cdot y(1-y) \cdot x =(y-t)\cdot x\\

ここで標準シグモイド関数はロジスティック分布の累積分布関数だがシグモイド型の関数は任意分布の累積分布関数を考えれば他にもいくつか種類がある。
image.png
さて、出力の活性化関数を標準シグモイド関数($\frac{1}{1+e^{-x}}$)から正規分布の累積分布関数($\frac{1}{2}(1+erf(\frac{x}{\sqrt{2}}))$)に変えた場合、損失関数をbinary_crossentropyからどのような関数に変えるべきかを考える。

関係表

そもそも今回の考察を考え始めた経緯について説明する。
シグモイド型の活性化関数を微分すると任意の分布関数が現れる。標準シグモイド関数はロジスティック分布という分布を元にしている。その微分を活性化関数の出力$y$で表現し、$(t-y)$をその微分形で割った値を不定積分すると任意分布の累積分布関数を出力の活性化関数とした場合の適正な損失関数を求められる事に気付いた。現にbinary_crossentropyの式$(Loss_{BCE}=-tlog(y)-(1-t)log(1-y))$に関しては、$\frac{(t-y)}{y(1-y)}$の不定積分を計算すれば求まる。微分形の$\frac{dy}{dx}=y(1-y)$は標準シグモイド関数から求められる。なお、これらの数式の計算はwolfram alphaに頼った。
しかし、活性化関数と損失関数がいかなる形になっても結局平均二乗誤差(MSE)の微分と等しい形に持っていくなら、出力の活性化関数と損失関数の組み合わせはあまりどうでもいい事になる。出力の活性化関数と損失関数の連鎖積は結局同じになるよう損失関数を定義するからである。

image.png
出力の活性化関数を正規分布の累積分布関数とした時、不定積分を整理した正しい損失関数は損失関数1の方だが、あえてガウス分布を無視した損失関数2を選ぶとする。これは活性化関数のGelu関数+Gauss分布=Softplus関数?で示した(Softplus関数-Gauss分布)=Gelu関数という関係と若干近い。活性化関数としてGeluが優秀なのはSoftplus関数型からGauss分布引くからだと解釈することもできる。
正規分布の累積分布関数の場合の$(t-y)$を$dy/dx$で割った式を不定積分する部分を拡大すると以下であるから
image.png
下記のように整理できる。

活性化関数(正規分布の累積分布関数):y=\frac{1}{2}(1+erf(\frac{x}{\sqrt{2}}))\\
分布関数(正規分布):\frac{dy}{dx}=\frac{e^{-x^2/2}}{\sqrt{2\pi}}\\
\frac{x}{\sqrt{2}}=erf^{-1}(2y-1)\\
\frac{dy}{dx}=\frac{e^{-(erf^{-1}(2y-1))^{2}}}{\sqrt{2\pi}}\\
損失関数1:L=x(t-y)-\frac{e^{-x^2/2}}{\sqrt{2\pi}}\\
損失関数2:L=x(t-y)

custom_sigmoid

MNISTで0~4の時を0、5~9の時を1とする二値分類器を作ってみる。
この時、前述する正規分布の累積分布関数を活性化関数にするカスタム活性化関数と、連鎖積を平均二乗誤差(MSE)にする損失関数からガウス分布を引いたカスタム損失関数を考える。
非常に単純なモデルでカスタム活性化関数カスタム損失関数の結果と、標準シグモイド関数binary_crossentropyの結果を比べるとカスタム関数の結果が良い事が多かった。(やるたびに結果は変わるが大体カスタム関数の結果が良い)
epochs=10の場合で30回学習させたがval_acc平均と標準偏差は以下の通りだった。平均も改善しているし、標準偏差も小さくなる。

30回平均
custom_sigmoid 0.98515±0.00091
default_sigmoid 0.98298±0.00180
import tensorflow as tf
import numpy as np
import math as m

from tensorflow.keras.datasets import mnist
from tensorflow.keras.layers import Input, Dense
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Activation
from tensorflow.keras.utils import get_custom_objects

(x_train, y_train), (x_test, y_test) = mnist.load_data()

x_train  = x_train.reshape(60000, 784)
x_test   = x_test.reshape(10000, 784)
x_train  = x_train.astype('float32')
x_test   = x_test.astype('float32')

x_train /= 255.0
x_test /= 255.0
y_train = np.where(y_train>4.5, 1.0, 0.0)
y_test  = np.where(y_test >4.5, 1.0, 0.0)

def custom_loss(x, y, t):
    loss = x * (y - t)
    #loss = x * (y - t) + tf.math.exp(-x*x/2.0)/tf.math.sqrt(2.0*tf.constant(m.pi))
    return loss

class Custom_sigmoid(Activation):
    def __init__(self, activation, **kwargs):
        super(Custom_sigmoid, self).__init__(activation, **kwargs)
        self.__name__ = 'custom_sigmoid'

def custom_sigmoid(x):
    return (1.0 + tf.math.erf(x/tf.math.sqrt(2.0)))/ 2.0

get_custom_objects().update({'custom_sigmoid': Custom_sigmoid(custom_sigmoid)})

##################################################################

print('custom sigmoid')
input1 = Input(shape=(784,))
input2 = Input(shape=(1,))
x = Dense(256, activation='relu')(input1)
x = Dense(256, activation='relu')(x)
x = Dense(1)(x)
y = Activation('custom_sigmoid')(x)
model = Model(inputs=[input1,input2], outputs=y)

model.add_loss(custom_loss(x, y, input2))
model.compile(optimizer='adam', metrics=['accuracy'])
history = model.fit([x_train, y_train], y_train, batch_size=128, epochs=10, verbose=1, validation_data=([x_test, y_test], y_test))

##################################################################

print('default sigmoid')
input = Input(shape=(784,))
x = Dense(256, activation='relu')(input)
x = Dense(256, activation='relu')(x)
x = Dense(1)(x)
y = Activation('sigmoid')(x)
model2 = Model(inputs=input, outputs=y)

model2.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
history = model2.fit(x_train, y_train, batch_size=128, epochs=10, verbose=1, validation_data=(x_test, y_test))

image.png

二値分類のsoftmaxの正規分布の累積分布化

$y=softmax(x)$は二値分類なら以下の様に$tanh$関数を使って整理できる。

y_1=\frac{e^{x_1}}{e^{x_1}+e^{x_2}}\\
y_2=\frac{e^{x_2}}{e^{x_1}+e^{x_2}}\\
y_1=\frac{1}{2}(1+tanh(\frac{x_1-x_2}{2}))\\
y_2=\frac{1}{2}(1-tanh(\frac{x_1-x_2}{2}))

ここで$tanh(\frac{x}{2})→erf(\frac{x}{\sqrt{2}})$に変換すれば標準シグモイド関数を正規分布の累積分布関数的なsoftmaxを考える事が出来るのではないかと考えられる。

y_1=\frac{1}{2}(1+erf(\frac{x_1-x_2}{\sqrt{2}}))\\
y_2=\frac{1}{2}(1-erf(\frac{x_1-x_2}{\sqrt{2}}))

三値分類のsoftmaxの正規分布の累積分布化

同様に三値分類の$softmax$を$tanh$関数で表現できる事のみ示しておく。

y_1=\frac{e^{x_1}}{e^{x_1}+e^{x_2}+e^{x_3}}\\
y_2=\frac{e^{x_2}}{e^{x_1}+e^{x_2}+e^{x_3}}\\
y_3=\frac{e^{x_3}}{e^{x_1}+e^{x_2}+e^{x_3}}\\
y_1=\frac{1}{2}(1+tanh(\frac{x_1-log(e^{x_2}+e^{x_3})}{2}))\\
y_2=\frac{1}{2}(1+tanh(\frac{x_2-log(e^{x_1}+e^{x_3})}{2}))\\
y_3=\frac{1}{2}(1+tanh(\frac{x_3-log(e^{x_1}+e^{x_2})}{2}))\\

wolfram alphaで上記がsoftmaxに等しいのは確認した。
image.png
しかし、$tanh(\frac{x}{2})→erf(\frac{x}{\sqrt{2}})$に変換した時、$y_1+y_2+y_3=1$を満たすかは確認できなかった。

まとめ

正規分布の累積分布関数をカスタム活性化関数にして正当な損失関数からガウス分布を引いたカスタム損失関数が、標準シグモイド関数とbinary_crossentropyの組み合わせよりも性能が高かった。
$Normalization$の全く入ってないバニラなNNだからたまたま効果が見えるだけかもしれない。正規化や正則化を沢山入れたNNでは性能差はない可能性はある。
softmaxをロジスティック分布以外の累積分布関数に出来るかに関しては途中でよく分からなくなった。

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?