温度付きsoftmax (softmax with temperature)
いつ使うか
モデルの蒸留なんかに出てくる損失関数 (多分他にも出てくるんだろうけどあまり知らない).
「ちょっと高い確率で出てきたクラスを重視して学習したい!」とか「低い確率のクラスを切り捨てずに学習したい!」ときに使われる.
数式
$$
S(y_i)=\frac{e^{\frac{y_i}{T}}}{\Sigma{e^{\frac{y_k}{T}}}}
$$
実装(Chainer)
import chainer
import numpy as np
def softmax_with_temperature(x, temperature: float) -> chainer.Variable:
return F.softmax(x / temperature)
softmax_with_temperature(np.array([[-1, 0, 1]], dtype=np.float32), 0.5)
温度による変化
- 温度($T$)が$1$より大きいとき: 低い確率を強調
- 温度($T$)が$1$より小さいとき: 高い確率を強調
- $y=e(x)$は$x$が大きくなると飛躍的に$y$が大きくなるので$T$が小さい→$x$が大きくなる→高い確率が強調される
例
import numpy as np
from scipy.stats import norm
from matplotlib import pyplot as plt
def plot():
n_classes = 10
left = np.linspace(0, n_classes-1, n_classes)
# 適当に平均4, 分散1の正規分布をつくる.
p = np.array([[norm.pdf(x=i, loc=4, scale=1) for i in range(n_classes)]], dtype=np.float32)
softmax_p = softmax_with_temperature(p, 1) # 普通のsoftmax
softmax_p_t1 = softmax_with_temperature(p, 0.1)
softmax_p_t2 = softmax_with_temperature(p, 0.5)
softmax_p_t3 = softmax_with_temperature(p, 2)
softmax_p_t4 = softmax_with_temperature(p, 10)
ps = [softmax_p_t1, softmax_p_t2, softmax_p, softmax_p_t3, softmax_p_t4]
labels = ['T='+str(i) for i in [0.1, 0.5, 1, 2, 10]]
colors = ['red', 'orange', 'yellow', 'green', 'blue']
for i, p in enumerate(ps):
plt.bar(left, p.data.flatten(), color=colors[i], label=labels[i])
plt.legend()
plt.show()
温度が小さいとラベル4の確率が強調され,温度が大きいと低かった確率が伸びてくるのがわかる.
結構この温度うまいこと調整しないと意図した以上に強調が起きそう.