0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

【翻訳転載】確率分布に通ずる道:Softmaxとその代替案たち(2)

Posted at

この記事は中国のNLP研究者Jianlin Su氏が運営するブログ「科学空間」で掲載された解説記事の日本語訳です。

苏剑林. (Jun. 14, 2024). 《通向概率分布之路:盘点Softmax及其替代品 》[Blog post]. Retrieved from https://kexue.fm/archives/10145

前回記事の続きです。

Perturb Max

この章では、確率分布を構築する新しい手法を紹介する。Perturb Maxと呼ばれる、Gumbel Maxを一般化したものだ。本サイトでは記事「Reparametrizationの視点から見る離散確率分布の構築」で初めて紹介した。論文「EXACT: How to Train Your Accuracy」でも議論されているが、これ以上前の出店は見つけていない。

問題の振り返り

$\mathbb{R}^n\mapsto\Delta^{n-1}$の写像を作ること自体は難しくない。$\mathbb{R}\mapsto\mathbb{R}^*$(実数から非負実数)な写像$f(x)$さえ用意できれば(たとえば:$f(x)=x^2$)、

$$
p_i = \frac{f(x_i)}{\sum\limits_{j=1}^n f(x_j)}
$$

これで写像の条件を満たすことができる。「単調性」を満たすのも難しくない。$\mathbb{R}\mapsto\mathbb{R}^*$の単調増加関数を使えばいい(たとえば:$\text{sigmoid}(x)$)。では「不変性」はどうだろうか?不変性を満たす$\mathbb{R}^n\mapsto\Delta^{n-1}$写像は簡単に書けるだろうか?(少なくとも私はできない)

こう思う読者もいるだろう。なぜ単調性と不変性を満たさなきゃいけないのか?確かに、確率分布を近似したいだけなら、この性質は不必要にも見える。どのみち大きなモデルで「ゴリ押し」すれば、近似できない分布なんて無いのだから。ただ、「Softmaxの代替品を作る」という目的を思い出すと、新しい分布はSoftmaxと同じく$\text{argmax}$の連続的な近似であってほしい。であれば、なるべく$\text{argmax}$と同じ性質を有してほしい。これが単調性と不変性にこだわる主な理由である。

Gumbel Max

Perturb MaxはGumbel Maxを一般化することで考案された分布だ。Gumbel Maxは、以下の事実に基づく分布である。

$$
P[\text{argmax}(\boldsymbol{x}+\boldsymbol{\varepsilon}) = i] = softmax(\boldsymbol{x})_i,\quad \boldsymbol{\varepsilon}\sim Gumbel\text{ }Noise
$$

$\boldsymbol{\varepsilon}$の各成分はGumbel分布からの独立的なサンプルである。$\boldsymbol{x}$は既知のベクトルなので、本来$\text{argmax}(\boldsymbol{x})$の値も決まっているが、ランダムノイズ$\boldsymbol{\varepsilon}$を足すことで$\text{argmax}(\boldsymbol{x+\varepsilon})$もランダムになり、各$i$に確率が付与される。そして、このランダムノイズがGumbel分布に従う場合、$i$の確率はちょうど$softmax(\boldsymbol{x})_i$になる。

Gumbel Maxの主な用途は、分布$softmax(\boldsymbol{x})$からサンプリングする手段である。もちろん、単にサンプリングするだけならもっと簡単な方法があるが、Gumbel Maxの最大の価値は、ランダム性をパラメーターを持つ変数$\boldsymbol{x}$からパラメーターを持たない変数$\boldsymbol{\varepsilon}$に移す「再パラメーター化(Reparameterization)」にある。加えて、Softmaxがargmaxの連続的な近似であるように、$softmax(\boldsymbol{x}+\boldsymbol{\varepsilon})$もGumbel Maxの連続的近似である。ゆえにGumbel Maxは、離散サンプリングモジュールを持つモデルを訓練する際に多用される道具である。

一般的なノイズ

Purturb MaxはGumbel Maxから着想を得た。Gumbel分布からSoftmaxを得ることができるなら、Gumbel分布を別の一般的な分布、たとえば正規分布に置き換えれば、新しい分布が得られるのではないだろうか?つまり、

p_i = P[\text{argmax}(\boldsymbol{x}+\boldsymbol{\varepsilon}) = i],\quad \varepsilon_1,\varepsilon_2,\cdots,\varepsilon_n\sim p(\varepsilon)

Gumbel Maxと同じような導出を辿ると、以下のような結果になる。

p_i = \int_{-\infty}^{\infty} p(\varepsilon_i)\left[\prod_{j\neq i} \Phi(x_i - x_j + \varepsilon_i)\right]d\varepsilon_i = \mathbb{E}_{\varepsilon}\left[\prod_{j\neq i} \Phi(x_i - x_j + \varepsilon)\right]

ここで$\Phi(\varepsilon)$は$p(\varepsilon)$の累積分布関数である。一般的な分布の場合、たとえ正規分布であっても、上記の分布を解析的に求めることは困難なので、大まかに推定するしかない。決定的な計算結果をえるためには、逆累積分布関数を利用するといい。つまり、まず$[0,1]$から均等にサンプル$t$を取り、$t=\Phi(\varepsilon)$を解けば$\varepsilon$が求まる。

Perturb Maxの定義、あるいは最終的な$p_i$の形から、Perturb Maxは単調性と不変性を満たすことは明らかである。この分布はどんな場面で役に立つのだろうか?正直よく分かっていない。論文「EXACT: How to Train Your Accuracy」は、新たな確率分布を構築し、正解率の連続的な近似の改善を試みたが、筆者自身で実験してみたところ特に効果は無かった。たぶん、何らかの再パラメーター化が必要な場面で役立つ場合があるのではないかと思う。

SparseMax

続いて紹介するのはSparsemaxと呼ばれる分布写像である。出処は2016年の論文「From Softmax to Sparsemax: A Sparse Model of Attention and Multi-Label Classification」で、筆者が提案したSparse Softmaxと同じくスパース性を得るための手法である。ただ、Sparsemax作者のモチベーションは、注意力機構において解釈性を向上させる所にある。Top-kの値を無理やり切り取るだけのSparse Softmaxと違い、Sparsemaxはより適応的なスパース分布の構築手法を提案している。

基本定義

論文ではSparsemaxを下記の最適化問題の解として定義した。

sparsemax(\boldsymbol{x}) =  \mathop{\text{argmin}}\limits_{\boldsymbol{p}\in\Delta^{n-1}}\Vert \boldsymbol{p} - \boldsymbol{x}\Vert^2

ラグランジュの未定乗数法を使えば、厳密解の式を求めることができる。ただ、この方法ではSoftmaxとの関係が分かりにくいので、筆者が構想したより簡潔な導出方法を紹介する。

まず、Softmaxは以下のように表すことができる。

$$
\boldsymbol{p}=softmax(\boldsymbol{x})=\text{exp}\left(\boldsymbol{x}-\lambda (\boldsymbol{x})\right)
$$

$\lambda (\boldsymbol{x})$は$\boldsymbol{p}$の成分の和が1になるような定数である。Softmaxの場合、$\lambda (\boldsymbol{x})=\log\sum_ie^{x_i}$である。

Taylor Softmaxを紹介した際、$\text{exp}(x)$の偶数階テイラー展開は常に正であることに触れた。この性質を利用し、偶数階テイラー展開でSoftmaxの変形を作ることができた。では、奇数階はどうか?たとえば$\text{exp}(x)\approx 1+x$は、明らかに非負とは限らない。しかし、$text{relu}$関数を使って、無理やり非負関数に変えてしまうことはできる。つまり、$\text{exp}(x)\approx \text{relu}(1+x)$。この近似式でさっきの$text{exp}$を置き換えてしまえば、Sparsemaxになる。

\boldsymbol{p} = sparsemax(\boldsymbol{x}) = \text{relu}(1+\boldsymbol{x} - \lambda(\boldsymbol{x}))

ここでも$\lambda (\boldsymbol{x})$は$\boldsymbol{p}$の成分の和が1になるような定数である。定数1は$\lambda(\boldsymbol{x})$に統合できるので、実際は以下の式になる。

\boldsymbol{p} = sparsemax(\boldsymbol{x}) = \text{relu}(\boldsymbol{x} - \lambda(\boldsymbol{x}))

定数項の解法

Sparsemaxの形式的な定義は分かったが、$\lambda(\boldsymbol{x})$の求め方がまだ分からないので、検討してみよう。ちなみに、Sparsemaxが単調性と不変性を満たすことは、定義を見れば理解し難くないはずだ。確信が持てない読者は、自ら証明してみればいい。

$\lambda(\boldsymbol{x})$の計算に入ろう。一般性を失わず、$\boldsymbol{x}$の成分は降順に並んでいると仮定する。つまり、$x_1\ge x_2\ge\cdots\ge x_n$。仮に$x_k\geq \lambda(\boldsymbol{x})\geq x_{k+1}$であれば、以下の通りになる。

sparsemax(\boldsymbol{x}) = [x_1 - \lambda(\boldsymbol{x}),\cdots,x_k - \lambda(\boldsymbol{x}),0,\cdots,0]

$\lambda(\boldsymbol{x})$の定義から、

\sum_{i=1}^k[x_i-\lambda(\boldsymbol{x})=1] \Rightarrow 1+k\lambda(\boldsymbol{x})=\sum_{i=1}^kx_i

これで$\lambda(\boldsymbol{x})$が求まる。もちろん、$x_k\geq \lambda(\boldsymbol{x})\geq x_{k+1}$を満たす$k$は分からないが、$k=1,2,\cdots,n$でそれぞれ$\lambda_k(\boldsymbol{x})$を求め、$x_k\geq \lambda_k(\boldsymbol{x})\geq x_{k+1}$を満たす$\lambda_k(\boldsymbol{x})$を探せばいい。

numpyで書くと以下の通りである。

def sparsemax(x):
    x_sort = np.sort(x)[::-1]
    x_lamb = (np.cumsum(x_sort) - 1) / np.arange(1, len(x) + 1)
    lamb = x_lamb[(x_sort >= x_lamb).argmin() - 1]
    return np.maximum(x - lamb, 0)

勾配の計算

便利のため、以下の記号を導入しよう。

\Omega (\boldsymbol{x})=\left\{k\vert x_k>\lambda(\boldsymbol{x})\right\}

するとsparsemaxは以下のように書ける。

\boldsymbol{p} = sparsemax(\boldsymbol{x}) = \left\{\begin{aligned} 
&x_i - \frac{1}{|\Omega(\boldsymbol{x})|}\left(-1 + \sum_{j\in\Omega(\boldsymbol{x})}x_j\right),\quad &i\in \Omega(\boldsymbol{x})\\ 
&0,\quad &i \not\in \Omega(\boldsymbol{x}) 
\end{aligned}\right.

この式のヤコビ行列は、

\frac{\partial p_i}{\partial x_j} = \left\{\begin{aligned} 
&1 - \frac{1}{|\Omega(\boldsymbol{x})|},\quad &i,j\in \Omega(\boldsymbol{x}),i=j\\[5pt] 
&- \frac{1}{|\Omega(\boldsymbol{x})|},\quad &i,j\in \Omega(\boldsymbol{x}),i\neq j\\[5pt] 
&0,\quad &i \not\in \Omega(\boldsymbol{x})\text{ or }j \not\in \Omega(\boldsymbol{x}) 
\end{aligned}\right.

ここから見て取れるように、$\Omega(\boldsymbol{x})$に含まれているクラスでは、勾配は常に定数なので、勾配消失が起こることはない。ただし、全体の勾配は$\Omega(\boldsymbol{x})$の要素数に左右されている。要素数が少ないほど分布は疎になり、勾配も疎になる。

損失関数

最後に、Sparsemaxを分類器の出力とする際の損失関数を考えてみよう。まず思い当たるのが、Softmaxと同じく交差エントロピー$-\log p_t$を使うことだ。Sparsemaxはゼロ確率を含むため、$\log$の数値エラーを防ぐために$\epsilon$を足し、最終的に交差エントロピーは$-\log\frac{p_t+\epsilon}{1+n\epsilon}$になる。ただ、これだと見栄えが良くないし、凸関数ではないので、理想的な選択ではない。

交差エントロピーがSoftmaxで機能するのは、勾配が$\boldsymbol{p}-\text{onehot}(t)$の形になっているからだ。そこでSparsemaxに対しても、同じように損失関数の勾配が$\boldsymbol{p}-\text{onehot}(t)$であると仮定し、そこから損失関数を逆算すればいい。結論から言うと、損失関数は以下の通りである。

\frac{\partial \mathcal{L}_t}{\partial \boldsymbol{x}} = \boldsymbol{p} - \text{onehot(t)}\quad\Rightarrow\quad \mathcal{L}_t = \frac{1}{2} - x_t + \sum_{i\in\Omega(\boldsymbol{x})}\frac{1}{2}\left(x_i^2 - \lambda^2(\boldsymbol{x})\right)

右の式から左を確かめるのは容易である。左から右を導出するのは少し面倒だが、まあ色々組み合わせれば辿り着くだろう。最初の定数$\frac{1}{2}$は損失関数が非負であることを保証するためのものである。極端な例で確かめてみよう。仮に完璧に最適化できた場合、$\boldsymbol{p}$もone hotになるので、$x_t\to\infty$、かつ$\lambda(\boldsymbol{x})=x_t-1$なので、

- x_t + \sum_{i\in\Omega(\boldsymbol{x})}\frac{1}{2}\left(x_i^2 - \lambda^2(\boldsymbol{x})\right) = -x_t + \frac{1}{2}x_t^2 - \frac{1}{2}(x_t - 1)^2 = -\frac{1}{2}

だから$\frac{1}{2}$を足す必要がある。

Entmax-α

Entmax-αはSparsemaxの一般化である。Sparsemaxは過度に疎な分布になりやすく、学習効率の低下につながる。Entmax-αはパラメーター$\alpha$を導入することで、Softmax($\alpha=1$)からSparsemax($\alpha=2$)へスムーズに変化する関数を提案した。Entmax-αの出処は論文「Sparse Sequence-to-Sequence Models」である。作者はSparsemaxと同じくAndre F.T. Martinsで、スパースなSoftmaxやスパースなAttentionに関する論文を多く出している。興味ある読者は彼のホームページで関連研究を調べてみるといい。

基本定義

Sparsemaxと同じく、論文はEntmax-αを最適化問題の解として定義した。しかしこの定義はTsallis entropyの概念が使われており(Entmax-αのEntもここから来ている)、同じくラグランジュの未定乗数法で解いているので、ここでもこの導入方式は採用しない。

先ほどと同じく、$\text{exp}(x)\approx \text{relu}(1+x)$に基づいて考えよう。まずはSoftmaxとSparsemaxの定義を振り返ってみる。

Softmax: \quad \text{exp}(\boldsymbol{x}-\lambda(\boldsymbol{x}))
Sparsemax: \quad \text{relu}(1+\boldsymbol{x}-\lambda(\boldsymbol{x}))

Sparsemaxが疎になり過ぎるのは、結局$\exp(x)\approx \text{relu}(1 + x)$の近似精度が足りないからだと考えることもできる。ならば、より高精度な近似を考えてみればいい。

\exp(x) = \exp(\beta x / \beta) = \exp^{1/\beta}(\beta x)\approx \text{relu}^{1/\beta}(1 + \beta x)

$0\leq\beta\leq1$であれば、右辺の式は$\text{relu}(1+x)$よりも良い近似になる。この新たな近似を使えばEntmax-αを構築できる。

{{Entmax\text{-}\alpha}}:\quad \text{relu}^{1/\beta}(1+\beta\boldsymbol{x} - \lambda(\boldsymbol{x}))

$\alpha=\beta+1$で置き換えれば論文通りの式になるが、$\beta$で表したほうが簡潔になる。定数1は同じく$\lambda(\boldsymbol{x})$に統合できるので、最終的な定義は以下の通りである。

Entmax_\alpha(\boldsymbol{x})=\text{relu}^{1/\beta}(\beta\boldsymbol{x}-\lambda(\boldsymbol{x}))

定数項の解法

一般的な$\beta$に対して、$\lambda(\boldsymbol{x})$を解くのは少し面倒で、二分法をつかうしかない。

まず、$\boldsymbol{z}=\beta\boldsymbol{x}$とする。一般性を失わず、$z_1\ge z_2\ge \cdots z_n$と仮定し、Entmax-αが不変性を満たすことを利用して更に一般性を失わず$z_1=1$と仮定する(もし$z_1=1$でなければ、各$z_i$から$z_1-1$を減算して新たな$\boldsymbol{z}$を定義し直せばいいから)。

$\lambda=0$のとき、$\text{relu}^{1/\beta}(\beta\boldsymbol{x}-\lambda)$のベクトル成分の和が1を下回らないことを確認でき、また$\lambda=1$のときは$\text{relu}^{1/\beta}(\beta\boldsymbol{x}-\lambda)$のベクトル成分の和が0になる。これにより、ベクトル成分の和が1になるような$\lambda(\boldsymbol{x})$は$[0,1)$の範囲内であることが分かる。ここから、二分法で徐々に最適な$\lambda(\boldsymbol{x})$に近付いていけばいい。

一部の特殊な$\beta$に対しては、直接$\lambda$を求めることもできる。Sparsemaxと等価である$\beta=1$の答えは既に示した。もう一つの例は$\beta=0.5$で、論文が主に取り扱ったケースでもある。特に断りがなければ、通常Entmax-αはEntmax-1.5を指す。Sparsemaxと同じく、$z_k\geq \lambda(\boldsymbol{x})\geq z_{k+1}$とすると、

\sum_{i=1}^k [z_i - \lambda(\boldsymbol{x})]^2 = 1

これはただの$\lambda(\boldsymbol{x})$の一元二次方程式に過ぎないので、解けば以下の通りになる。

\lambda(\boldsymbol{x}) = \mu_k - \sqrt{\frac{1}{k} - \sigma_k^2},\quad \mu_k = \frac{1}{k}\sum_{i=1}^k z_i,\quad\sigma_k^2 = \frac{1}{k}\left(\sum_{i=1}^k z_i^2\right) - \mu_k^2

$k=1,2,\cdots,n$でそれぞれ$\lambda_k(\boldsymbol{x})$を求め、$x_k\geq \lambda_k(\boldsymbol{x})\geq x_{k+1}$を満たす$\lambda_k(\boldsymbol{x})$を見つければいい。ただし、$x_k\ge\lambda_k(\boldsymbol{x})$を満たす最大の$k$を求める、とは異なることに気を付けたい。

def entmat(x):
    x_sort = np.sort(x / 2)[::-1]
    k = np.arange(1, len(x) + 1)
    x_mu = np.cumsum(x_sort) / k
    x_sigma2 = np.cumsum(x_sort**2) / k  - x_mu**2
    x_lamb = x_mu - np.sqrt(np.maximum(1. / k - x_sigma2, 0))
    x_sort_shift = np.pad(x_sort[1:], (0, 1), constant_values=-np.inf)
    lamb = x_lamb[(x_sort > x_lamb) & (x_lamb > x_sort_shift)]
    return np.maximum(x / 2 - lamb, 0)**2

その他

Entmax-αの勾配はSparsemaxと大差ないので、詳しく論じない。読者は自ら導出するか、論文を参照すればよい。損失関数も、同じく$\frac{\partial \mathcal{L}_t}{\partial \boldsymbol{x}} = \boldsymbol{p} - \text{onehot(t)}$から逆算すれば導出できるが、数式は少し複雑になる。興味のある読者は論文「Sparse Sequence-to-Sequence Models」および「Learning with Fenchel-Young Losses」を参照するとよい。

ただ筆者が思うに、stop gradientを用いて損失関数を定義したほうが簡潔で使いやすく、複雑な計算を避けることができる。

\mathcal{L}_t=(\boldsymbol{p}-\text{onehot}(t))\cdot\text{stop_gradient}(\boldsymbol{x})

ここの$\cdot$は内積である。このように損失関数を定義すれば、勾配ちょうど$\boldsymbol{p}-\text{onehot}(t)$になる。ただし、この損失関数は勾配にしか意味が無く、損失の値そのものに意味はない。正でも負でもよく、値が小さいほど良いわけでもない。学習の進捗や効果を評価したければ、交差エントロピーや正解率など、別の指標を求める必要がある。

まとめ

本記事ではSoftmaxとその代替案を振り返った。具体的には、Softmax、Margin Softmax、Taylor Softmax、Sparse Softmax、Perturb Max、Sparsemax、Entmax-αの定義と性質を紹介した。

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?