表題の数式が思いついたのでプロットしてみたところ係数を$\frac{\sqrt{3}}{\pi}$とした$Gelu$関数+$Gauss$分布が$Softplus$関数に対し非常に良い近似となった。
この関係性について単なる偶然なのかそれとも論理的な必然性があるのか考えてみる。
import matplotlib.pyplot as plt
import numpy as np
from scipy import special
x = np.arange(-6, 6, 0.05)
beta = np.pi/np.sqrt(3)
y0 = np.array([i if i>0 else 0 for i in x]) # Relu
y1 = np.log(1+np.exp(x)) # Softplus
y2 = x*0.5*(1+special.erf(np.sqrt(0.5)*x)) \
+ 1/np.sqrt(2*np.pi)*np.exp(-x*x/2) # (Gelu + Gauss)1
y3 = x*0.5*(1+special.erf(np.sqrt(0.5)*x/beta)) \
+ beta/np.sqrt(2*np.pi)*np.exp(-x*x/2/beta**2) # (Gelu + Gauss)2
plt.plot(x, y1, label="Softplus")
plt.plot(x, y2, label="(Gelu + Gauss)1")
plt.plot(x, y3, label="(Gelu + Gauss)2")
plt.plot(x, y0, label="Relu")
plt.legend()
plt.show()
plt.plot(x, y1-y0, label="Softplus - Relu")
plt.plot(x, y2-y0, label="(Gelu + Gauss)1 - Relu")
plt.plot(x, y3-y0, label="(Gelu + Gauss)2 - Relu")
plt.legend()
plt.show()
#Softplus関数
一般に活性化関数$Softplus$関数は以下の様な$log$と$exp$関数を用いて表される。この関数は既存の$Relu$関数を滑らかにしたような形で表される。また、この微分は標準シグモイド関数で表されるという特徴がある。
この数式だけを見た感じ、$Gelu+Gauss$が$Softplus$関数に近づく根拠は見当たらない。
Softplus(x)=log(1+e^x) \\
Softplus'(x)=\frac{e^x}{1+e^x}=\frac{1}{1+e^{-x}}=\sigma(x)
#Gelu関数+Gauss分布
一方、$Gelu$関数は正規分布(ガウス分布)の累積分布関数$\Phi(x)$と$x$の積で表される。
$Gelu$関数はおそらく$Swish(SiLU)$の$x\cdot\sigma(x)$からの類推で求められており、$Softplus$と直接的な関係は全くない。
Gelu(x)=x\cdot \Phi (x)\\
\Phi (x)=\frac{1+erf(\sqrt{2}x)}{2}\\
Swish(x)=x\cdot\sigma(x)\\
Gauss(x)=\frac{1}{\sqrt{2\pi}}e^{\frac{-x^2}{2}}
ここで$Gelu$関数+$Gauss$分布の微分を計算してみたい。
ここで正規分布(ガウス分布)の累積分布関数の微分はガウス分布になる。
また、$Gauss$分布の微分は$Gauss$分布と$-x$の積になる。
Gelu(x)+Gauss(x)=x\cdot \Phi (x)+\frac{1}{\sqrt{2\pi}}e^{\frac{-x^2}{2}}\\
\Phi (x)=\int_{-\infty}^{x}Gauss(x)dx\\
\Phi' (x)=Gauss(x)\\
Gauss'(x)=-x\cdot Gauss(x)
従って$Gelu$関数+$Gauss$分布の微分を整理すると打ち消しあって正規分布(ガウス分布)の累積分布関数だけが残る。
Gelu'(x)+Gauss'(x)=\Phi (x) + x\cdot \Phi' (x)+Gauss'(x)\\
=\Phi (x) + x\cdot Gauss(x) - x\cdot Gauss(x)\\
=\Phi (x)
ここで$Softplus$の微分である標準シグモイド関数$\sigma(x)$と$Gelu$関数+$Gauss$分布の微分の正規分布(ガウス分布)の累積分布関数$\Phi(x)$を並べてみる。分布の係数を$\frac{\sqrt{3}}{\pi}$とすれば$\sigma(x)$と$\Phi(\frac{\sqrt{3}}{\pi}x)$がほとんど一致することが分かる。
x = np.arange(-6, 6, 0.05)
beta = np.pi/np.sqrt(3)
y1 = 1/(1+np.exp(-x)) # from Softplus
y2 = 0.5*(1+special.erf(np.sqrt(0.5)*x)) # from (Gelu+Gauss)1
y3 = 0.5*(1+special.erf(np.sqrt(0.5)*x/beta)) # from (Gelu+Gauss)2
y0 = 0.5*(1+np.sign(x)) # from Relu
plt.plot(x, y1, label="from Softplus")
plt.plot(x, y2, label="from (Gelu+Gauss)1")
plt.plot(x, y3, label="from (Gelu+Gauss)2")
plt.plot(x, y0, label="from Relu")
plt.legend()
plt.show()
#ロジスティック分布
正規分布(ガウス分布)の累積分布関数に対するガウス分布のように標準シグモイド関数に関する分布に対してはロジスティック分布という分布の名前で知られている。この分布の分散はなんと$\pi^2/3$である。
今回、$Gelu$関数+$Gauss$分布の係数を$\frac{\sqrt{3}}{\pi}$としたのはこの分散からなのである。
そして、その係数はちょうど良く一致した。
logistic分布(x)=\frac{e^{-x}}{(1+e^{-x})^2}
#Softplus関数-Gauss分布=Gelu関数?
微分式がおなじS型のシグモイド型をとるという事から$Gelu$関数+$Gauss$分布$≃$$Softplus$関数という関係であるということを示した。
さて、ここで等式がほぼ成り立つと仮定するなら今度は逆に$Softplus$関数から$Gauss$分布を引いたものが$Gelu$関数になっていると考えることもできる。つまり微分式がシグモイド型(任意の分布の累積分布関数)になる関数からもとの分布を引く形が典型的な$Gelu$型活性化関数の形状なのである。
なぜ正規分布を差し引くのかに関しては$L_1,L_2$正則化の$\lambda_1|x|,\lambda_2 x^2$を損失関数に加える理由と近いものがあると考えられる。また中間層の分布を平均値0、標準偏差1にする係数とバイアスを加える処理が$BatchNormalization$である。活性化関数から$Gauss$分布を引くことで中間層の出力分布はガウス分布に近づく事が期待されるのだろうか。
こう考えるなら$Softplus$関数-$logistic$分布が標準シグモイド関数を元とする$Gelu$型活性化関数となる。
これらを参考のためプロットすると以下の様になった。
余分な$1.2$の係数を与えたとすると$Softplus$関数-$logistic$分布、$Gelu$関数、$Swish$関数の比較は似たような形が得られた。この$1.2$はロジスティック分布の尖度を意識しているが単なる偶然かもしれない。無理やり近づけた何の意味もないプロットである可能性もある。
x = np.arange(-6, 6, 0.05)
beta = np.pi/np.sqrt(3*1.2)
y0 = np.array([i if i>0 else 0 for i in x]) # Relu
y1 = np.log(1+np.exp(x*beta))/beta \
- np.exp(-x*beta)/((1+np.exp(-x*beta))**2)*beta # (Softplus - logstic)
y2 = x*0.5*(1+special.erf(np.sqrt(0.5)*x)) # Gelu
y3 = x/(1+np.exp(-x*beta)) # Swish
plt.plot(x, y1, label="(Softplus - logstic)")
plt.plot(x, y2, label="Gelu")
plt.plot(x, y3, label="Swish")
plt.plot(x, y0, label="Relu")
plt.legend()
plt.show()
plt.plot(x, 1.2*(y1-y0), label="(Softplus - logstic) - Relu")
plt.plot(x, y2-y0, label="Gelu - Relu")
plt.plot(x, y3-y0, label="Swish(beta=1.656) - Relu")
plt.legend()
plt.show()
#シグモイド型関数と分布
幾つかの分布とその累積分布関数の関係を示す。
softplus型は微分した関数がsigmoid型となる関数を書いている。
tanh型 | sigmoid型 | 分布 | softplus型 |
---|---|---|---|
$sgn(x)$ | $\frac{1}{2}(1+sgn(x))$ | ディラックのデルタ関数 | $\frac{1}{2} (x+abs(x))$ |
$tanh(x/2)$ | $\frac{1}{2}(1+tanh(x/2))=\frac{1}{1+e^{-x}}=\sigma(x)$ | ロジスティック分布 | $log(1+e^x)$ |
$erf(x/\sqrt{2})$ | $\frac{1}{2}(1+erf(x/\sqrt{2}))$ | ガウス分布 | $x\cdot \Phi (x)+\frac{1}{\sqrt{2\pi}}e^{\frac{-x^2}{2}}$ |
$\frac{x}{\sqrt{1+x^2}}$ | $\frac{1}{2}(1+\frac{x}{\sqrt{1+x^2}})$ | $\nu=2$のt分布 | $\frac{1}{2}(x+\sqrt{1+x^2})$ |
$\frac{2}{\pi}gd(\frac{\pi}{2}x)$ | $\frac{1}{2}(1+\frac{2}{\pi}gd(\frac{\pi}{2}x))$ | 双曲線正割分布 | 不明 |
$\frac{2}{\pi}\arctan(x)$ | $\frac{1}{2}(1+\frac{2}{\pi}\arctan(x))$ | ローレンツ分布 | $\frac{x}{2}(1+\frac{2}{\pi}\arctan(x))-\frac{log(1+x^2)}{2\pi}+C$ |
$Hardtanh(x)$ | $Hardsigmoid(x)$ | 一様分布 | - |
$softsign(x)$ | $\frac{1}{2}(1+\frac{x}{1+abs(x)})$ | $\frac{1}{2(1+abs(x))^2}$ | - |
$sinh(x)+sgn(x)(1-cosh(x))$ | $\frac{1}{2}(1+sinh(x)+sgn(x)(1-cosh(x)))$ | ラプラス分布 | $\frac{1}{2}(x+cosh(x)+sgn(x)(x-sinh(x)))$ |
※追記:いくつか分布を追加 |
#まとめ
$Gelu$関数に$Gauss$分布を足すと$Softplus$型の関数になる。
ここから活性化関数の微分が広義のシグモイド型関数(標準シグモイド関数、正規分布の累積分布関数)となる$Softplus$型の活性関数からシグモイド型の微分にあたる分布関数(ロジスティック分布、正規分布)を差し引いたものが$Gelu$型関数として解釈できるのではと考察した。