この記事は中国のNLP研究者Jianlin Su氏が運営するブログ「科学空間」で掲載された解説記事の日本語訳です。
苏剑林. (Jun. 14, 2024). 《通向概率分布之路:盘点Softmax及其替代品 》[Blog post]. Retrieved from https://kexue.fm/archives/10145
前回記事の続きです。
Perturb Max
この章では、確率分布を構築する新しい手法を紹介する。Perturb Maxと呼ばれる、Gumbel Maxを一般化したものだ。本サイトでは記事「Reparametrizationの視点から見る離散確率分布の構築」で初めて紹介した。論文「EXACT: How to Train Your Accuracy」でも議論されているが、これ以上前の出店は見つけていない。
問題の振り返り
$\mathbb{R}^n\mapsto\Delta^{n-1}$の写像を作ること自体は難しくない。$\mathbb{R}\mapsto\mathbb{R}^*$(実数から非負実数)な写像$f(x)$さえ用意できれば(たとえば:$f(x)=x^2$)、
$$
p_i = \frac{f(x_i)}{\sum\limits_{j=1}^n f(x_j)}
$$
これで写像の条件を満たすことができる。「単調性」を満たすのも難しくない。$\mathbb{R}\mapsto\mathbb{R}^*$の単調増加関数を使えばいい(たとえば:$\text{sigmoid}(x)$)。では「不変性」はどうだろうか?不変性を満たす$\mathbb{R}^n\mapsto\Delta^{n-1}$写像は簡単に書けるだろうか?(少なくとも私はできない)
こう思う読者もいるだろう。なぜ単調性と不変性を満たさなきゃいけないのか?確かに、確率分布を近似したいだけなら、この性質は不必要にも見える。どのみち大きなモデルで「ゴリ押し」すれば、近似できない分布なんて無いのだから。ただ、「Softmaxの代替品を作る」という目的を思い出すと、新しい分布はSoftmaxと同じく$\text{argmax}$の連続的な近似であってほしい。であれば、なるべく$\text{argmax}$と同じ性質を有してほしい。これが単調性と不変性にこだわる主な理由である。
Gumbel Max
Perturb MaxはGumbel Maxを一般化することで考案された分布だ。Gumbel Maxは、以下の事実に基づく分布である。
$$
P[\text{argmax}(\boldsymbol{x}+\boldsymbol{\varepsilon}) = i] = softmax(\boldsymbol{x})_i,\quad \boldsymbol{\varepsilon}\sim Gumbel\text{ }Noise
$$
$\boldsymbol{\varepsilon}$の各成分はGumbel分布からの独立的なサンプルである。$\boldsymbol{x}$は既知のベクトルなので、本来$\text{argmax}(\boldsymbol{x})$の値も決まっているが、ランダムノイズ$\boldsymbol{\varepsilon}$を足すことで$\text{argmax}(\boldsymbol{x+\varepsilon})$もランダムになり、各$i$に確率が付与される。そして、このランダムノイズがGumbel分布に従う場合、$i$の確率はちょうど$softmax(\boldsymbol{x})_i$になる。
Gumbel Maxの主な用途は、分布$softmax(\boldsymbol{x})$からサンプリングする手段である。もちろん、単にサンプリングするだけならもっと簡単な方法があるが、Gumbel Maxの最大の価値は、ランダム性をパラメーターを持つ変数$\boldsymbol{x}$からパラメーターを持たない変数$\boldsymbol{\varepsilon}$に移す「再パラメーター化(Reparameterization)」にある。加えて、Softmaxがargmaxの連続的な近似であるように、$softmax(\boldsymbol{x}+\boldsymbol{\varepsilon})$もGumbel Maxの連続的近似である。ゆえにGumbel Maxは、離散サンプリングモジュールを持つモデルを訓練する際に多用される道具である。
一般的なノイズ
Purturb MaxはGumbel Maxから着想を得た。Gumbel分布からSoftmaxを得ることができるなら、Gumbel分布を別の一般的な分布、たとえば正規分布に置き換えれば、新しい分布が得られるのではないだろうか?つまり、
p_i = P[\text{argmax}(\boldsymbol{x}+\boldsymbol{\varepsilon}) = i],\quad \varepsilon_1,\varepsilon_2,\cdots,\varepsilon_n\sim p(\varepsilon)
Gumbel Maxと同じような導出を辿ると、以下のような結果になる。
p_i = \int_{-\infty}^{\infty} p(\varepsilon_i)\left[\prod_{j\neq i} \Phi(x_i - x_j + \varepsilon_i)\right]d\varepsilon_i = \mathbb{E}_{\varepsilon}\left[\prod_{j\neq i} \Phi(x_i - x_j + \varepsilon)\right]
ここで$\Phi(\varepsilon)$は$p(\varepsilon)$の累積分布関数である。一般的な分布の場合、たとえ正規分布であっても、上記の分布を解析的に求めることは困難なので、大まかに推定するしかない。決定的な計算結果をえるためには、逆累積分布関数を利用するといい。つまり、まず$[0,1]$から均等にサンプル$t$を取り、$t=\Phi(\varepsilon)$を解けば$\varepsilon$が求まる。
Perturb Maxの定義、あるいは最終的な$p_i$の形から、Perturb Maxは単調性と不変性を満たすことは明らかである。この分布はどんな場面で役に立つのだろうか?正直よく分かっていない。論文「EXACT: How to Train Your Accuracy」は、新たな確率分布を構築し、正解率の連続的な近似の改善を試みたが、筆者自身で実験してみたところ特に効果は無かった。たぶん、何らかの再パラメーター化が必要な場面で役立つ場合があるのではないかと思う。
SparseMax
続いて紹介するのはSparsemaxと呼ばれる分布写像である。出処は2016年の論文「From Softmax to Sparsemax: A Sparse Model of Attention and Multi-Label Classification」で、筆者が提案したSparse Softmaxと同じくスパース性を得るための手法である。ただ、Sparsemax作者のモチベーションは、注意力機構において解釈性を向上させる所にある。Top-kの値を無理やり切り取るだけのSparse Softmaxと違い、Sparsemaxはより適応的なスパース分布の構築手法を提案している。
基本定義
論文ではSparsemaxを下記の最適化問題の解として定義した。
sparsemax(\boldsymbol{x}) = \mathop{\text{argmin}}\limits_{\boldsymbol{p}\in\Delta^{n-1}}\Vert \boldsymbol{p} - \boldsymbol{x}\Vert^2
ラグランジュの未定乗数法を使えば、厳密解の式を求めることができる。ただ、この方法ではSoftmaxとの関係が分かりにくいので、筆者が構想したより簡潔な導出方法を紹介する。
まず、Softmaxは以下のように表すことができる。
$$
\boldsymbol{p}=softmax(\boldsymbol{x})=\text{exp}\left(\boldsymbol{x}-\lambda (\boldsymbol{x})\right)
$$
$\lambda (\boldsymbol{x})$は$\boldsymbol{p}$の成分の和が1になるような定数である。Softmaxの場合、$\lambda (\boldsymbol{x})=\log\sum_ie^{x_i}$である。
Taylor Softmaxを紹介した際、$\text{exp}(x)$の偶数階テイラー展開は常に正であることに触れた。この性質を利用し、偶数階テイラー展開でSoftmaxの変形を作ることができた。では、奇数階はどうか?たとえば$\text{exp}(x)\approx 1+x$は、明らかに非負とは限らない。しかし、$text{relu}$関数を使って、無理やり非負関数に変えてしまうことはできる。つまり、$\text{exp}(x)\approx \text{relu}(1+x)$。この近似式でさっきの$text{exp}$を置き換えてしまえば、Sparsemaxになる。
\boldsymbol{p} = sparsemax(\boldsymbol{x}) = \text{relu}(1+\boldsymbol{x} - \lambda(\boldsymbol{x}))
ここでも$\lambda (\boldsymbol{x})$は$\boldsymbol{p}$の成分の和が1になるような定数である。定数1は$\lambda(\boldsymbol{x})$に統合できるので、実際は以下の式になる。
\boldsymbol{p} = sparsemax(\boldsymbol{x}) = \text{relu}(\boldsymbol{x} - \lambda(\boldsymbol{x}))
定数項の解法
Sparsemaxの形式的な定義は分かったが、$\lambda(\boldsymbol{x})$の求め方がまだ分からないので、検討してみよう。ちなみに、Sparsemaxが単調性と不変性を満たすことは、定義を見れば理解し難くないはずだ。確信が持てない読者は、自ら証明してみればいい。
$\lambda(\boldsymbol{x})$の計算に入ろう。一般性を失わず、$\boldsymbol{x}$の成分は降順に並んでいると仮定する。つまり、$x_1\ge x_2\ge\cdots\ge x_n$。仮に$x_k\geq \lambda(\boldsymbol{x})\geq x_{k+1}$であれば、以下の通りになる。
sparsemax(\boldsymbol{x}) = [x_1 - \lambda(\boldsymbol{x}),\cdots,x_k - \lambda(\boldsymbol{x}),0,\cdots,0]
$\lambda(\boldsymbol{x})$の定義から、
\sum_{i=1}^k[x_i-\lambda(\boldsymbol{x})=1] \Rightarrow 1+k\lambda(\boldsymbol{x})=\sum_{i=1}^kx_i
これで$\lambda(\boldsymbol{x})$が求まる。もちろん、$x_k\geq \lambda(\boldsymbol{x})\geq x_{k+1}$を満たす$k$は分からないが、$k=1,2,\cdots,n$でそれぞれ$\lambda_k(\boldsymbol{x})$を求め、$x_k\geq \lambda_k(\boldsymbol{x})\geq x_{k+1}$を満たす$\lambda_k(\boldsymbol{x})$を探せばいい。
numpyで書くと以下の通りである。
def sparsemax(x):
x_sort = np.sort(x)[::-1]
x_lamb = (np.cumsum(x_sort) - 1) / np.arange(1, len(x) + 1)
lamb = x_lamb[(x_sort >= x_lamb).argmin() - 1]
return np.maximum(x - lamb, 0)
勾配の計算
便利のため、以下の記号を導入しよう。
\Omega (\boldsymbol{x})=\left\{k\vert x_k>\lambda(\boldsymbol{x})\right\}
するとsparsemaxは以下のように書ける。
\boldsymbol{p} = sparsemax(\boldsymbol{x}) = \left\{\begin{aligned}
&x_i - \frac{1}{|\Omega(\boldsymbol{x})|}\left(-1 + \sum_{j\in\Omega(\boldsymbol{x})}x_j\right),\quad &i\in \Omega(\boldsymbol{x})\\
&0,\quad &i \not\in \Omega(\boldsymbol{x})
\end{aligned}\right.
この式のヤコビ行列は、
\frac{\partial p_i}{\partial x_j} = \left\{\begin{aligned}
&1 - \frac{1}{|\Omega(\boldsymbol{x})|},\quad &i,j\in \Omega(\boldsymbol{x}),i=j\\[5pt]
&- \frac{1}{|\Omega(\boldsymbol{x})|},\quad &i,j\in \Omega(\boldsymbol{x}),i\neq j\\[5pt]
&0,\quad &i \not\in \Omega(\boldsymbol{x})\text{ or }j \not\in \Omega(\boldsymbol{x})
\end{aligned}\right.
ここから見て取れるように、$\Omega(\boldsymbol{x})$に含まれているクラスでは、勾配は常に定数なので、勾配消失が起こることはない。ただし、全体の勾配は$\Omega(\boldsymbol{x})$の要素数に左右されている。要素数が少ないほど分布は疎になり、勾配も疎になる。
損失関数
最後に、Sparsemaxを分類器の出力とする際の損失関数を考えてみよう。まず思い当たるのが、Softmaxと同じく交差エントロピー$-\log p_t$を使うことだ。Sparsemaxはゼロ確率を含むため、$\log$の数値エラーを防ぐために$\epsilon$を足し、最終的に交差エントロピーは$-\log\frac{p_t+\epsilon}{1+n\epsilon}$になる。ただ、これだと見栄えが良くないし、凸関数ではないので、理想的な選択ではない。
交差エントロピーがSoftmaxで機能するのは、勾配が$\boldsymbol{p}-\text{onehot}(t)$の形になっているからだ。そこでSparsemaxに対しても、同じように損失関数の勾配が$\boldsymbol{p}-\text{onehot}(t)$であると仮定し、そこから損失関数を逆算すればいい。結論から言うと、損失関数は以下の通りである。
\frac{\partial \mathcal{L}_t}{\partial \boldsymbol{x}} = \boldsymbol{p} - \text{onehot(t)}\quad\Rightarrow\quad \mathcal{L}_t = \frac{1}{2} - x_t + \sum_{i\in\Omega(\boldsymbol{x})}\frac{1}{2}\left(x_i^2 - \lambda^2(\boldsymbol{x})\right)
右の式から左を確かめるのは容易である。左から右を導出するのは少し面倒だが、まあ色々組み合わせれば辿り着くだろう。最初の定数$\frac{1}{2}$は損失関数が非負であることを保証するためのものである。極端な例で確かめてみよう。仮に完璧に最適化できた場合、$\boldsymbol{p}$もone hotになるので、$x_t\to\infty$、かつ$\lambda(\boldsymbol{x})=x_t-1$なので、
- x_t + \sum_{i\in\Omega(\boldsymbol{x})}\frac{1}{2}\left(x_i^2 - \lambda^2(\boldsymbol{x})\right) = -x_t + \frac{1}{2}x_t^2 - \frac{1}{2}(x_t - 1)^2 = -\frac{1}{2}
だから$\frac{1}{2}$を足す必要がある。
Entmax-α
Entmax-αはSparsemaxの一般化である。Sparsemaxは過度に疎な分布になりやすく、学習効率の低下につながる。Entmax-αはパラメーター$\alpha$を導入することで、Softmax($\alpha=1$)からSparsemax($\alpha=2$)へスムーズに変化する関数を提案した。Entmax-αの出処は論文「Sparse Sequence-to-Sequence Models」である。作者はSparsemaxと同じくAndre F.T. Martinsで、スパースなSoftmaxやスパースなAttentionに関する論文を多く出している。興味ある読者は彼のホームページで関連研究を調べてみるといい。
基本定義
Sparsemaxと同じく、論文はEntmax-αを最適化問題の解として定義した。しかしこの定義はTsallis entropyの概念が使われており(Entmax-αのEntもここから来ている)、同じくラグランジュの未定乗数法で解いているので、ここでもこの導入方式は採用しない。
先ほどと同じく、$\text{exp}(x)\approx \text{relu}(1+x)$に基づいて考えよう。まずはSoftmaxとSparsemaxの定義を振り返ってみる。
Softmax: \quad \text{exp}(\boldsymbol{x}-\lambda(\boldsymbol{x}))
Sparsemax: \quad \text{relu}(1+\boldsymbol{x}-\lambda(\boldsymbol{x}))
Sparsemaxが疎になり過ぎるのは、結局$\exp(x)\approx \text{relu}(1 + x)$の近似精度が足りないからだと考えることもできる。ならば、より高精度な近似を考えてみればいい。
\exp(x) = \exp(\beta x / \beta) = \exp^{1/\beta}(\beta x)\approx \text{relu}^{1/\beta}(1 + \beta x)
$0\leq\beta\leq1$であれば、右辺の式は$\text{relu}(1+x)$よりも良い近似になる。この新たな近似を使えばEntmax-αを構築できる。
{{Entmax\text{-}\alpha}}:\quad \text{relu}^{1/\beta}(1+\beta\boldsymbol{x} - \lambda(\boldsymbol{x}))
$\alpha=\beta+1$で置き換えれば論文通りの式になるが、$\beta$で表したほうが簡潔になる。定数1は同じく$\lambda(\boldsymbol{x})$に統合できるので、最終的な定義は以下の通りである。
Entmax_\alpha(\boldsymbol{x})=\text{relu}^{1/\beta}(\beta\boldsymbol{x}-\lambda(\boldsymbol{x}))
定数項の解法
一般的な$\beta$に対して、$\lambda(\boldsymbol{x})$を解くのは少し面倒で、二分法をつかうしかない。
まず、$\boldsymbol{z}=\beta\boldsymbol{x}$とする。一般性を失わず、$z_1\ge z_2\ge \cdots z_n$と仮定し、Entmax-αが不変性を満たすことを利用して更に一般性を失わず$z_1=1$と仮定する(もし$z_1=1$でなければ、各$z_i$から$z_1-1$を減算して新たな$\boldsymbol{z}$を定義し直せばいいから)。
$\lambda=0$のとき、$\text{relu}^{1/\beta}(\beta\boldsymbol{x}-\lambda)$のベクトル成分の和が1を下回らないことを確認でき、また$\lambda=1$のときは$\text{relu}^{1/\beta}(\beta\boldsymbol{x}-\lambda)$のベクトル成分の和が0になる。これにより、ベクトル成分の和が1になるような$\lambda(\boldsymbol{x})$は$[0,1)$の範囲内であることが分かる。ここから、二分法で徐々に最適な$\lambda(\boldsymbol{x})$に近付いていけばいい。
一部の特殊な$\beta$に対しては、直接$\lambda$を求めることもできる。Sparsemaxと等価である$\beta=1$の答えは既に示した。もう一つの例は$\beta=0.5$で、論文が主に取り扱ったケースでもある。特に断りがなければ、通常Entmax-αはEntmax-1.5を指す。Sparsemaxと同じく、$z_k\geq \lambda(\boldsymbol{x})\geq z_{k+1}$とすると、
\sum_{i=1}^k [z_i - \lambda(\boldsymbol{x})]^2 = 1
これはただの$\lambda(\boldsymbol{x})$の一元二次方程式に過ぎないので、解けば以下の通りになる。
\lambda(\boldsymbol{x}) = \mu_k - \sqrt{\frac{1}{k} - \sigma_k^2},\quad \mu_k = \frac{1}{k}\sum_{i=1}^k z_i,\quad\sigma_k^2 = \frac{1}{k}\left(\sum_{i=1}^k z_i^2\right) - \mu_k^2
$k=1,2,\cdots,n$でそれぞれ$\lambda_k(\boldsymbol{x})$を求め、$x_k\geq \lambda_k(\boldsymbol{x})\geq x_{k+1}$を満たす$\lambda_k(\boldsymbol{x})$を見つければいい。ただし、$x_k\ge\lambda_k(\boldsymbol{x})$を満たす最大の$k$を求める、とは異なることに気を付けたい。
def entmat(x):
x_sort = np.sort(x / 2)[::-1]
k = np.arange(1, len(x) + 1)
x_mu = np.cumsum(x_sort) / k
x_sigma2 = np.cumsum(x_sort**2) / k - x_mu**2
x_lamb = x_mu - np.sqrt(np.maximum(1. / k - x_sigma2, 0))
x_sort_shift = np.pad(x_sort[1:], (0, 1), constant_values=-np.inf)
lamb = x_lamb[(x_sort > x_lamb) & (x_lamb > x_sort_shift)]
return np.maximum(x / 2 - lamb, 0)**2
その他
Entmax-αの勾配はSparsemaxと大差ないので、詳しく論じない。読者は自ら導出するか、論文を参照すればよい。損失関数も、同じく$\frac{\partial \mathcal{L}_t}{\partial \boldsymbol{x}} = \boldsymbol{p} - \text{onehot(t)}$から逆算すれば導出できるが、数式は少し複雑になる。興味のある読者は論文「Sparse Sequence-to-Sequence Models」および「Learning with Fenchel-Young Losses」を参照するとよい。
ただ筆者が思うに、stop gradientを用いて損失関数を定義したほうが簡潔で使いやすく、複雑な計算を避けることができる。
\mathcal{L}_t=(\boldsymbol{p}-\text{onehot}(t))\cdot\text{stop_gradient}(\boldsymbol{x})
ここの$\cdot$は内積である。このように損失関数を定義すれば、勾配ちょうど$\boldsymbol{p}-\text{onehot}(t)$になる。ただし、この損失関数は勾配にしか意味が無く、損失の値そのものに意味はない。正でも負でもよく、値が小さいほど良いわけでもない。学習の進捗や効果を評価したければ、交差エントロピーや正解率など、別の指標を求める必要がある。
まとめ
本記事ではSoftmaxとその代替案を振り返った。具体的には、Softmax、Margin Softmax、Taylor Softmax、Sparse Softmax、Perturb Max、Sparsemax、Entmax-αの定義と性質を紹介した。