LoginSignup
53
36

More than 5 years have passed since last update.

温度付きsoftmax 覚書

Last updated at Posted at 2018-07-31

温度付きsoftmax (softmax with temperature)

いつ使うか

モデルの蒸留なんかに出てくる損失関数 (多分他にも出てくるんだろうけどあまり知らない).
「ちょっと高い確率で出てきたクラスを重視して学習したい!」とか「低い確率のクラスを切り捨てずに学習したい!」ときに使われる.

数式

$$
S(y_i)=\frac{e^{\frac{y_i}{T}}}{\Sigma{e^{\frac{y_k}{T}}}}
$$

実装(Chainer)

import chainer
import numpy as np

def softmax_with_temperature(x, temperature: float) -> chainer.Variable:
    return F.softmax(x / temperature)

softmax_with_temperature(np.array([[-1, 0, 1]], dtype=np.float32), 0.5)

温度による変化

  • 温度($T$)が$1$より大きいとき: 低い確率を強調
  • 温度($T$)が$1$より小さいとき: 高い確率を強調
  • $y=e(x)$は$x$が大きくなると飛躍的に$y$が大きくなるので$T$が小さい→$x$が大きくなる→高い確率が強調される

import numpy as np
from scipy.stats import norm
from matplotlib import pyplot as plt

def plot():
    n_classes = 10
    left = np.linspace(0, n_classes-1, n_classes)
    # 適当に平均4, 分散1の正規分布をつくる.
    p = np.array([[norm.pdf(x=i, loc=4, scale=1) for i in range(n_classes)]], dtype=np.float32)
    softmax_p = softmax_with_temperature(p, 1)  # 普通のsoftmax
    softmax_p_t1 = softmax_with_temperature(p, 0.1)
    softmax_p_t2 = softmax_with_temperature(p, 0.5)
    softmax_p_t3 = softmax_with_temperature(p, 2)
    softmax_p_t4 = softmax_with_temperature(p, 10)
    ps = [softmax_p_t1, softmax_p_t2, softmax_p, softmax_p_t3, softmax_p_t4]
    labels = ['T='+str(i) for i in [0.1, 0.5, 1, 2, 10]]
    colors = ['red', 'orange', 'yellow', 'green', 'blue']
    for i, p in enumerate(ps):
        plt.bar(left, p.data.flatten(), color=colors[i], label=labels[i])
        plt.legend()
        plt.show()

温度が小さいとラベル4の確率が強調され,温度が大きいと低かった確率が伸びてくるのがわかる.
結構この温度うまいこと調整しないと意図した以上に強調が起きそう.

T0_1.png
T0_5.png
T1.png
T2.png
T10.png

53
36
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
53
36