この記事は中国のNLP研究者Jianlin Su氏が運営するブログ「科学空間」で掲載された解説記事の日本語訳です。
苏剑林. (Jun. 14, 2024). 《通向概率分布之路:盘点Softmax及其替代品 》[Blog post]. Retrieved from https://kexue.fm/archives/10145
基礎的な分類タスクから、今や至る所で見る注意力機構まで、確率分布の構築は決定的な役割を担っている。具体的には、$n$次元の任意ベクトルを、$n$要素の離散確率分布に変換するような処理である。ご存知のとおり、この処理の標準的な手法はSoftmaxだ。指数的正規化の形をとるSoftmaxは、比較的分かりやすく、また様々な優れた性質を持つことから、あらゆる場面における「標準装備」となっている。
それでも場合によっては、Softmaxも物足りない点がある。たとえば、スパース性が足りなかったり、ゼロ確率を表現できないなどの性質が挙げられる。そのため、これまで多くの代替案が提案されてきた。本記事では、Softmaxの性質を簡単に振り返り、一部の代替案を列挙して見比べてみようと思う。
Softmaxの振り返り
まずは記号を導入しておこう。$\boldsymbol{x}=(x_1, x_2,\cdots,x_n)\in\mathbb{R}^n$は確率分布に変換する前の$n$次元ベクトルである。ベクトルの成分は正でも負でもよく、上限・下限も特に限られていない。$\Delta^{n-1}$は$n$要素の離散確率分布全体の集合である。すなわち、
\Delta^{n-1}=\left\{\boldsymbol{p}=(p_1,p_2,\cdots,p_n)\mid p_1,p_2,\cdots,p_n\ge0,
\sum^n_{i=1}p_i=1\right\}
$n$ではなく$n-1$と表記したのは、制約$\sum_{i=1}^np_i=1$によって定義されるのは$n$次元空間内の$n-1$次元の部分平面であり、更に$p_i\geq0$の制約もあるので、$(p_1,p_2,\cdots,p_n)$の集合はその部分平面の部分集合に過ぎないからである。つまり、実際の次元数は$n-1$である。
上記の記号に基づけば、本記事が議論する対象は「$\mathbb{R}^n\mapsto\Delta^{n-1}$の写像」という形にまとめられる。$\boldsymbol{x}\in\mathbb{R}^n$は、一般的にLogitsあるいはScoresと呼ばれている。
基本定義
Softmaxの定義は至って単純である。
p_i = softmax(\boldsymbol{x})_i = \frac{e^{x_i}}{\sum\limits_{j=1}^n e^{x_j}}
Softmaxには、エネルギーモデル・統計力学・argmaxの連続的な近似など、数多くの源流や解釈がある。一番最初の出所を特定するのは困難だし、ここでそれを試みるつもりもない。温度パラメーターを追加し、$softmax(\boldsymbol{x}/\tau)$とすることも多いが、$\tau$も$\boldsymbol{x}$の定義に統合することができるので、ここでは特に$\tau$を分離せずに考えよう。
Softmaxの分母は、一般的に$Z(\boldsymbol{x})$と表記される。$Z(\boldsymbol{x})$の対数が、多くの深層学習ライブラリで実装されている$\text{logsumexp}$であり、$\max$の連続的な近似である。
\log Z(\boldsymbol{x}) = \log \sum\limits_{j=1}^n e^{x_j} = \text{logsumexp}(\boldsymbol{x})
\lim_{\tau\to 0^+} \tau\,\text{logsumexp}(\boldsymbol{x}/\tau) = \max(\boldsymbol{x})
$\tau=1$のとき、$\text{logsumexp}(\boldsymbol{x})\approx \max(\boldsymbol{x})$とすることができる。$\boldsymbol{x}$の分散が大きいほど、良い近似になる。
2つの性質
任意のベクトルを確率分布に変換できるほか、Softmaxは以下の性質も満たしている。
- 単調性:$\quad p_i > p_j \Leftrightarrow x_i > x_j,\quad p_i = p_j \Leftrightarrow x_i = x_j$
- 不変性:$\quad softmax(\boldsymbol{x}) = softmax(\boldsymbol{x} + c),\forall c\in\mathbb{R} $
単調性とはつまり、Softmax関数は順序を保存する性質を持つことを意味する。$\boldsymbol{x}$の最大値・最小値は$\boldsymbol{p}$の最大値・最小値とそれぞれ対応する。不変性とはつまり、$\boldsymbol{x}$の各成分に同じ定数を足しても、Softmaxの結果が変わらないことを意味する。これは$argmax$の性質とも一致する。すなわち:$\text{argmax}(\boldsymbol{x})=\text{argmax}(\boldsymbol{x}+c)$。
以上の2つの性質により、Softmaxは実質的に$\text{argmax}$の連続的な近似と見なすことができる(厳密には$ \text{onehot}( \text{argmax}(\cdot))$の近似)。即ち、
\lim_{\tau\to 0^+} softmax(\boldsymbol{x}/\tau) = \text{onehot}(\text{argmax}(\boldsymbol{x}))
おそらくこれがSoftmaxという名前の由来だろう。Softmaxは$\text{argmax}$の近似であり、$\max$と混同しないように気を付けよう。$\max$の近似は$\text{logsumexp}$である。
勾配計算
深層学習の観点から見て、関数の性質を知る重要な手段の一つはその勾配を知ることである。Softmaxに関しては、過去記事「勾配最大化の視点から考えるAttentionのスケーリング処理」で計算してみたことがあった。
\frac{\partial p_i}{\partial x_j} = p_i\delta_{i,j} - p_i p_j =
\begin{cases}
p_i-p_i^2, & i=j \\
-p_ip_j, & i\neq j
\end{cases}
これを行列形式に並べたものをSoftmaxのヤコビ行列(Jacobian Matrix)と呼ぶ。この行列のL1ノルムは、以下のような簡潔な形で書ける。
\frac{1}{2}\left\Vert\frac{\partial \boldsymbol{p}}{\partial \boldsymbol{x}}\right\Vert_1=\frac{1}{2}\sum_{i,j}\left|\frac{\partial p_i}{\partial x_j}\right|=\frac{1}{2}\sum_i (p_i - p_i^2) + \frac{1}{2}\sum_{i\neq j} p_i p_j = 1 - \sum_i p_i^2
$\boldsymbol{p}$がone hot分布であるとき、上の式は0になる。つまりSoftmaxの結果がone hotに近付くほど、勾配消失問題が顕著になるということだ。なので、少なくとも初期化段階では、Softmaxの初期値はone hotに近すぎないよう気を付ける必要がある。
参考実装
Softmaxの実装は簡単だ。$exp$を計算して正規化すればいい。Numpyで書き下ろすと:
def softmax(x):
y = np.exp(x)
return y / y.sum()
しかし、もし$\boldsymbol{x}$の成分が大きめの値を含んでいると、$exp$の計算でオーバーフローが起こりやすい。なので、Softmaxの不変性を利用し、$\boldsymbol{x}$の各値を全体の最大値で引いてから、Softmaxを計算するようにしている。こうすれば、$exp$演算される値は0より小さい値のみになり、オーバーフローの心配はなくなる。
def softmax(x):
y = np.exp(x - x.max())
return y / y.sum()
損失関数
確率分布を構築する主な目的の一つは、単一ラベル・多クラス分類タスクの出力として用いることである。即ち、とある$n$クラス分類タスクに対し、$\boldsymbol{x}$がモデルの出力であるならば、$\boldsymbol{p}=softmax(\boldsymbol{x})$で各クラスの確率を予測したいわけだ。このモデルを訓練するためには、損失関数が必要になる。目標ラベルが$t$としたとき、一般的な選択肢は交差エントロピー損失だろう。
\mathcal{L}_t = - \log p_t = - \log softmax(\boldsymbol{x})_t
この損失関数の勾配は以下の通りである。
-\frac{\partial \log p_t}{\partial x_j} = p_j - \delta_{t,j} =
\begin{cases}
p_t-1, & j=t \\
p_j, & j\neq t
\end{cases}
$t$は既知の値なので、$\delta_{t,j}$はすなわち目標分布$\text{onehot}(t)$である。$p_j$の全体は$\boldsymbol{p}$なので、上の式はこう書くこともできる。
-\frac{\partial \log p_t}{\partial x_j} =\boldsymbol{p}-\text{onehot}(t)
つまり、損失関数の勾配は目標分布と予測分布の差そのものである。両者が等しくならない限り、勾配は存在し続けるし、最適化も止まらない。これが交差エントロピーの長所である。もちろん、場合によっては短所にもなりうる。Softmaxはずっと完全なone hotにはならない、つまり最適化がずっと完全停止しないので、過度な最適化に陥ってしまう可能性があるのだ。これが、このあと議論するいくつかの代替案の動機でもある。
交差エントロピーのほかにも、いくつかの損失関数が考えられる。たとえば、$-p_t$は「正解率の連続的な近似の反数」と見なすことができる。ただ、この損失関数は勾配消失が起こりやすいので、最適化効率は交差エントロピーに及ばない場合が多く、ファインチューニングにのみ使用できる。
Softmaxの変形
Softmaxを一通り復習したので、本ブログで取り上げたことがあるSoftmaxの変形を見てみようMargin Softmax、Taylor Softmax、Sparse Softmaxは、それぞれSoftmaxから発展した関数で、スパース性やロングテール性など異なる性質の改善に重きを置いている。
Margin Softmax
まずは顔認証技術にて提案された、Margin Softmaxと呼ばれるSoftmax変形を紹介しよう。この変形は、後に自然言語処理のSentence Embeddingの訓練にも応用されている。
Margin SoftmaxはSoftmaxという名を持つが、どちらかというと(交差エントロピーのような)損失関数の改善案である。たとえばMargin Softmaxの一種であるAM-Softmaxは、以下のような特徴がある。
- $cos$関数でLogitsを構築している。即ち、$\boldsymbol{x} = [\cos(\boldsymbol{z},\boldsymbol{c}_1),\cos(\boldsymbol{z},\boldsymbol{c}_2),\cdots,\cos(\boldsymbol{z},\boldsymbol{c}_n)]/\tau$。$cos$の範囲は$[-1,1]$であり、確率の差が開かないため、ここでは温度パラメーター$\tau$は必須である
- 単に$-log p_t$を損失とするのではなく、更に強化した形式をとっている
式に書き下ろすと以下の通りである。
\mathcal{L} = - \log \frac{e^{[\cos(\boldsymbol{z},\boldsymbol{c}_t)-m]/\tau}}{e^{[\cos(\boldsymbol{z},\boldsymbol{c}_t)-m]/\tau} + \sum_{j\neq t} e^{\cos(\boldsymbol{z},\boldsymbol{c}_j)/\tau}}
交差エントロピー損失は、$x_t$が$\boldsymbol{x}$の最大値になることを目標としている。一方AM-Softmaxは$x_t$が最大値になるだけでなく、二番目に大きい値を少なくとも$m/\tau$上回ることを目標としている。この$m/\tau$がMarginである。
なぜ追加の目標を設定するのか?これは具体的なタスクの要求によるものである。先に言及したように、Margin Softmaxは顔認証やNLPの意味検索、つまり分類モデルを訓練し情報検索タスクで利用する場面において使われる手法である。単純な交差エントロピーで分類モデルを訓練しても、モデルからエンコードした特徴量が検索タスクの要求を満たさない場合、Marginを加えることで特徴量をよりタイトにさせることができるのだ。
Taylor Softmax
次に紹介するのは、過去記事「exp(x)のx=0地点の偶数階テイラー展開は常に正である」で議論したTaylor Softmaxだ。この関数は$exp(x)$のテイラー展開に関する面白い性質を利用している。
任意の実数$x$および偶数$k$に対し、$f_k(x)\triangleq\sum\limits_{m=0}^k \frac{x^m}{m!} > 0$、つまり$e^x$の$x=0$地点のまわりの偶数階テイラー展開は必ず正である。
この性質を利用し、以下のようなSoftmax変形を作ることができる($k>0$は任意の偶数)
taylor\text{-}softmax(\boldsymbol{x}, k)_i = \frac{f_k(x_i)}{\sum\limits_{j=1}^n f_k(x_j)}
この関数は$exp$のテイラー展開に基づくため、一定の範囲内においてTaylor SoftmaxとSoftmaxは近い結果になり、Taylor SoftmaxでSoftmaxを置き換えることができる。ではTaylor Softmaxの特徴は何かというと、ロングテール(long tail)性が強い点である。Taylor Softmaxは多項式関数の正規化であり、指数関数よりも減衰が緩やかであるため、低確率のクラスにも比較的高い確率を与える傾向がある。そのため、Softmaxにありがちな「自信過剰(over-confident)」現象を緩和することができる。
Sparse Softmax
Sparse Softmaxは筆者(記事原作者Jianlin Su氏)がCAIL20201で提案した、Softmaxの簡単な変形である。
テキスト生成では、決定的なビームサーチ法や、確率的なTopK/TopPサンプリングがデコード手法として使われている。これらのアルゴリズムは予測確率が大きいTokenを選んで探索あるいはサンプリングしており、つまりそのほかのTokenの確率は0と見なしている。一方で学習時に直接Softmaxで確率分布を作ると、ゼロ確率が存在しないため、学習と推論に不一致が起きてしまう。Sparse Softmaxはこの不一致に対処するために考えた手法である。考え方はシンプルで、学習時にTop-k以外のTokenの確率を0にするだけである。
Softmax | Sparse Softmax | |
---|---|---|
定義 | $ p_i = \frac{e^{x_i}}{\sum\limits_{j=1}^n e^{x_j}}$ | $p_i=\begin{cases}\frac{e^{x_i}}{\sum\limits_{i\in\Omega_k} e^{x_i}},&i\in\Omega_k \\ 0,&i\notin\Omega_k\end{cases}$ |
損失関数 | $log\left(\sum^n_{i=1}{e^{x_i}}-x_t\right)$ | $log\left(\sum\limits_{i\in\Omega_k} e^{x_i}\right) - x_t$ |
$\Omega_k$は$x_1, x_2,\cdots,x_n$を降順に並べたものの、最初の$k$個の要素の添え字の集合である。平たく言えば、学習段階で推論と一致した処理を強制するということである。$\Omega_k$はNucleus SamplingのTop-p方式で決めることもできる。ただし、Sparse Softmaxは残りの確率を強引に遮断したため、この部分のLogitsは逆伝播できなくなる。よってSparse Softmaxの学習効率はSoftmaxに及ばず、基本的にファインチューニングにのみ使える手法である。
(続く)
-
CAIL(China AI and Law Challenge):法律系タスクに対するAI技術の性能を競う中国のコンテスト。 ↩