LoginSignup
48
21

More than 3 years have passed since last update.

Kerasのhard_sigmoidが max(0, min(1, (0.2 * x) + 0.5)) である話

Last updated at Posted at 2019-05-30

hard_sigmoid

Kerasにはhard_sigmoidという区分線形関数が用意されている。これは標準シグモイド関数
$f(x) = \frac{e^x}{e^x+1} $ を次の関数で近似するものである。

g(x) = \left\{
\begin{array}{ll}
0 & (x < -2.5) \\
0.2x + 0.5 & (-2.5 \leq x \leq 2.5) \\
1 & (2.5 < x)
\end{array}
\right.

指数関数の計算を必要としないため、如何にも速そうである。

ソースコード

TensorFlowバックエンドのソースコードは次の通り。

def hard_sigmoid(x):
    """Segment-wise linear approximation of sigmoid.

    Faster than sigmoid.
    Returns `0.` if `x < -2.5`, `1.` if `x > 2.5`.
    In `-2.5 <= x <= 2.5`, returns `0.2 * x + 0.5`.

    # Arguments
        x: A tensor or variable.

    # Returns
        A tensor.

    {{np_implementation}}
    """
    x = (0.2 * x) + 0.5
    zero = _to_tensor(0., x.dtype.base_dtype)
    one = _to_tensor(1., x.dtype.base_dtype)
    x = tf.clip_by_value(x, zero, one)
    return x

Theanoバックエンドのソースコードは次の通り。

def hard_sigmoid(x):
    """
    An approximation of sigmoid.

    More approximate and faster than ultra_fast_sigmoid.

    Approx in 3 parts: 0, scaled linear, 1.

    Removing the slope and shift does not make it faster.

    """
    # Use the same dtype as determined by "upgrade_to_float",
    # and perform computation in that dtype.
    out_dtype = scalar.upgrade_to_float(scalar.Scalar(dtype=x.dtype))[0].dtype
    slope = tensor.constant(0.2, dtype=out_dtype)
    shift = tensor.constant(0.5, dtype=out_dtype)
    x = (x * slope) + shift
    x = tensor.clip(x, 0, 1)
    return x

見た目に違いはあれど、どちらも「0.2倍して、0.5を足して、0と1の範囲内に収まるようにclipしている」ということには変わりがない。

なぜxの係数が0.2なのか

0.2を掛けた後に足す数が0.5であるのは、標準シグモイド関数の性質である「点$(0, \frac{1}{2})$ に関して点対称」を考えれば自然である。気になるのは、「なぜ掛ける数が0.2なのか」というところである。実際、標準シグモイド関数の級数展開はWolfram Alpha先生によると

$$ f(x) = \frac{1}{2} + \frac{x}{4} - \frac{x^3}{48} + \frac{x^5}{480} - \frac{17 x^7}{80640} + \dots $$

であるので、 $x$ が $0$ に近いときには $\frac{1}{2} + \frac{x}{4}$ で近似したほうが高い精度が出るはずである。

やってみよう。

標準シグモイド関数とmax(0, min(1, 1/2 + x/4))
desmos.comで生成

見てのとおり、0付近での精度は高いものの、±1.5 ~ 3のあたりでの精度の低さが気になる。
では、傾きを0.25ではなく0.2にしてみるとどうだろう。

標準シグモイド関数とmax(0, min(1, 1/2 + x/5))
desmos.comで生成

見てのとおり、比較的広い範囲にわたってシグモイドをそれなりによく近似できているように見える。

ここで終わってもいいのだが

さて、まあ見た目で納得してもらう、というのも手ではあるのだが、それでは0.2という値がマジックナンバーになってしまう。この値にどのような根拠があるのか、調べてみたいものである。

ということで、関数 $g(x)$ を

g(x) = \left\{
\begin{array}{ll}
0 & (x < -a/2) \\
\frac{x}{a} + \frac{1}{2} & (-a/2 \leq x \leq a/2) \\
1 & (a/2 < x)
\end{array}
\right.

と定義して、二乗誤差 $ \int_{-\infty}^\infty \left(f(x) - g(x) \right)^2 dx $ を最小化するような $a$ を求めよう。原点周りでの傾きが $0.2$ に近い値に、つまり、 $a$ が $5$ に近い値になれば成功である。

対称性より、$ \int_{-\infty}^0 \left(f(x) - g(x) \right)^2 dx $ は二乗誤差の半分である。以下これを計算する。

\begin{align}
\int_{-\infty}^0 \left(f(x) - g(x) \right)^2 dx &= \int_{-\infty}^0 \left(f(x)\right)^2 dx - 2\int_{-\infty}^0 f(x)g(x) dx + \int_{-\infty}^0 \left( g(x) \right)^2 dx \\
&= \int_{-\infty}^0 \left(f(x)\right)^2 dx - 2\int_{-a/2}^0 f(x)g(x) dx + \int_{-a/2}^0 \left( g(x) \right)^2 dx \\
&= \int_{-\infty}^0 \left(f(x)\right)^2 dx - 2\int_{-a/2}^0 \frac{e^x}{e^x+1} \left(\frac{x}{a} + \frac{1}{2}\right)dx + \int_{-a/2}^0 \left(\frac{x}{a} + \frac{1}{2}\right)^2 dx 
\end{align}

最終行で区分的に定義された関数が登場しなくなったので、これは $a > 0$ において $a$ で微分可能であると推察できる。

第3項 $\int_{-a/2}^0 \left(\frac{x}{a} + \frac{1}{2}\right)^2 dx $ は、 $y = x + \frac{a}{2} $ とすることで

\begin{align}
\int_{-a/2}^0 \left(\frac{x}{a} + \frac{1}{2}\right)^2 dx &= \int_{0}^{a/2} \left(\frac{y}{a} \right)^2 dy = \int_{0}^{a/2} \frac{y^2}{a^2} dy = \frac{1}{a^2} \frac{1}{3}\left(\frac{a}{2}\right)^3 = \frac{a}{24}
\end{align}

と計算できる。ということで、二乗誤差(の半分)を $a$ で微分すると、

\begin{align}
\frac{d}{da}\left( - 2\int_{-a/2}^0 \frac{e^x}{e^x+1} \left(\frac{x}{a} + \frac{1}{2}\right)dx\right) + \frac{1}{24}
\end{align}

これが $0$ となるような $a$ を求めていきたい。

ここで

$$ M(a,x) = 2 \left(\frac{x}{a} + \frac{1}{2}\right) \frac{e^x}{e^x+1} $$

と定義すると、Leibniz integral rule 1

$$ \frac{d}{da} \left(\int_{A(a)}^{B(a)} M(a,x) dx \right)= M\left(a,B(a)\right) \frac{d}{da} B(a) - M\left(a,A(a)\right) \frac{d}{da} A(a) + \int_{A(a)}^{B(a)}\frac{\partial}{\partial a} M(a,x) dx $$

より

$$ \frac{d}{da} \left (\int_{-a/2}^{0} M(a,x) dx \right)= - M\left(a,-\frac{a}{2}\right) \left(-\frac{1}{2}\right) + \int_{-a/2}^{0}\frac{\partial}{\partial a} M(a,x) dx $$

ここで $M\left(a,-\frac{a}{2}\right) = 2 \left(\frac{-a/2}{a} + \frac{1}{2}\right) \frac{e^{-\frac{a}{2}}}{e^{-\frac{a}{2}}+1} = 0$ なので

\begin{align}
\frac{d}{da} \left (\int_{-a/2}^{0} M(a,x) dx \right) &= \int_{-a/2}^{0}\frac{\partial}{\partial a} M(a,x) dx \\ 
&= \int_{-a/2}^{0} 2 \left(-\frac{x}{a^2} \right) \frac{e^x}{e^x+1} dx
\end{align}

ゆえに、二乗誤差(の半分)を $a$ で微分したものは

$$ \frac{1}{24} + 2\int_{-a/2}^{0} \frac{x}{a^2} \frac{e^x}{e^x+1} dx $$

と書けることが分かった。これの $a$ についてのグラフを書くと、

二乗誤差(の半分)を $a$ で微分したもの、のグラフ。<br>

たしかに $a=5$ 付近でゼロになっていることが分かる。しかも、一階微分が増加関数であることから、ここが極小であり最小であることも分かる。 2

さあ、あとはWolfram Alphaに突っ込んで $a$ を求めるだけ…

Try the following:

うーむ。

じゃあせめて積分だけでも…
Standard computation time exceeded...
(´・ω・`)

Wolfram Alphaが読んでくれるように式変形しよう

さて、ということで

$$ 0 = \frac{1}{24} + 2\int_{-a/2}^{0} \frac{x}{a^2} \frac{e^x}{e^x+1} dx $$

を式変形していこう。

$$ \frac{d}{dx} \log(1+e^x) = \frac{\frac{d}{dx} (1+e^x)}{1+e^x} = \frac{e^x}{e^x+1} $$

なので、部分積分により先程の式は

\begin{align}
0 &= \frac{1}{24} + 2 \left. \frac{x}{a^2}  \log(1+e^x)\right|_{x=-a/2}^{x=0} - 2 \int_{-a/2}^{0}  \frac{1}{a^2}  \log(1+e^x) dx \\
&= \frac{1}{24} + \frac{1}{a}  \log(1+e^{-a/2}) - \frac{2}{a^2} \int_{-a/2}^{0}  \log(1+e^x) dx
\end{align}

ここで $u = -e^{x}$ とおくと $du = -e^{x}dx$ なので

\begin{align}
\int_{-a/2}^{0} \log(1+e^x) dx = \int_{-e^{-a/2}}^{-1} \log(1-u) \frac{du}{u}
\end{align}

これは高校数学の範囲では処理できないが、多重対数関数 というものがあって、その中でも $\operatorname{Li}_2(x)$ と呼ばれる関数は負の実数$z$に対して

\operatorname{Li}_2(z) = -\lim_{h \to 0^{-}} \int_h^z \frac{\log(1-u)}{u} du

という性質を満たす。

ということで、解くべき式は

0 = \frac{1}{24} + \frac{1}{a}  \log(1+e^{-a/2}) + \frac{2}{a^2} \left( \operatorname{Li}_2(-1)- \operatorname{Li}_2(-e^{-a/2}) \right)

ということが分かった。ここまで解きほぐしてやればWolfram Alphaも処理することができ

$$ a = 5.19936381662864... $$

ということが分かる。よって、$x$ に掛ける定数を

$$ \frac{1}{a} = 0.19233122267801... $$

にすると標準シグモイド関数との二乗誤差が最小になることが分かった。

おまけ

hard_sigmoidは $h_1(x) = 1$ と $h_2(x) = 0$ の間を1つの一次関数で繋ぐことで 標準シグモイド関数を近似したが、2つの二次関数で繋いだらどうなるだろう。つまり、

g(x) = \left\{
\begin{array}{ll}
0 & (x < -a) \\
\frac{(x+a)^2}{2a^2} & (-a \leq x < 0) \\
1 - \frac{(x-a)^2}{2a^2} & (0 \leq x \leq a) \\
1 & (a < x)
\end{array}
\right.

とするのである。

この場合、原点での傾きを標準シグモイド関数に合わせる(これは $a = 4$ )のがほぼ最適解となる。詳細は省略するが、

\frac{17}{60} = \frac{4}{a^3}  \left(-\frac{a}{2} \operatorname{Li}_2(-e^a) - \frac{a}{2} \operatorname{Li}_2(-1) + \operatorname{Li}_3(-e^a) -  \operatorname{Li}_3(-1)\right)

解いて、 $a = 3.99197948719976...$ を得る。

参考文献

追記

Theanoでの原作者はhard_sigmoidの係数が0.2である真の理由を覚えていないらしい、という話があるそうだ。


  1. この定理は日本語圏で恐ろしく知名度がないようであり、検索しても日本語訳が全然ヒットしない。 

  2. お気づきかもしれないが、今回の問題設定上 $a > 0$ なので、負であるときの振る舞いは関係ない。というかそもそも $a=0$ のときはゼロ除算である。 

48
21
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
48
21