この記事は中国のNLP研究者Jianlin Su氏が運営するブログ「科学空間」で掲載された解説記事の日本語訳です。
苏剑林. (Mar. 05, 2025). 《MoE环游记:3、换个思路来分配 》[Blog post]. Retrieved from https://kexue.fm/archives/10757
本記事では引き続きMoEの負担均等化について議論する。前回記事「MoE世界一周(2): 寡きを患えずして均しからざるを患う」では、Aux Lossを利用して負担均等化を促す手法を紹介した。Aux Lossは分かりやすいが、「重みの調整が難しい」という大きな欠点がある。重みが小さすぎると均等化が進まないし、大きすぎると学習ロスに響く。ゆえに、MoE界隈はAux Lossに替わる代替案を探し続けていた。
本記事で紹介するのは「Loss-Free」と呼ばれる手法で、DeepSeekが論文「Auxiliary-Loss-Free Load Balancing Strategy for Mixture-of-Experts」で提案した。DeepSeekが発表した数多の輝かしいオープンソースプロジェクトと比較して、あまり目立った成果ではないかもしれないが、この論文の潜在的なインパクトはほかの成果を遥かに上回るかもしれないと筆者は考える。提案された手法はシンプルかつ効果的であり、加えて極めて汎用性が高く、模範的な論文だと言える。
大まかな説明
負担を均等化するためにAux Lossがとった方法は、追加の損失を与えることでRouterが均等的なスコアを出力するよう誘導することである。一方Loss-Freeは分配の方法そのものを変える。つまりRouterが出力するスコアは変えず、$\text{argtop}_k\boldsymbol{\rho}$という分配方式を変えるのだ。
実は、既に似たような試みはされている。2021年にFacebookが提案したBASE Layerは、Expertの分配を線形割当問題と見なした。つまり、負担の均等化を制約条件とし、その制約条件下でRouterの総スコアがなるべく高くなるような分配方式を求める問題である。この問題はハンガリアン法で解くことができる。しかし、この手法は全体のTokenのスコアを先に知る必要はあるので、自己回帰的なLLMにだと学習時にしか使えず、推論は依然$\text{argtop}_k\boldsymbol{\rho}$を使うしかないので、学習と推論で不一致が生じる。しかも、アルゴリズムの制約上、$k=1$の場合にしか適用できない。
一方、Loss-Free手法は非常に簡潔で効果的だ。提案手法は、いかなる場合でも、$\text{argtop}_k\boldsymbol{\rho}+\boldsymbol{b}$の分配が均等になるようなバイアス項$\boldsymbol{b}$が存在する、という事実に着目した。そこで、MoEの数式を以下のように改造した。
\boldsymbol{y} = \sum_{i\in \mathop{\text{argtop}}_k \boldsymbol{\rho}} \rho_i \boldsymbol{e}_i\qquad\to\qquad \boldsymbol{y} = \sum_{i\in \mathop{\text{argtop}}_k \boldsymbol{\rho} + \boldsymbol{b}} \rho_i \boldsymbol{e}_i
$\boldsymbol{b}$は入力と無関係なベクトルである。学習を経て決定され、そのまま推論で使われるので、学習と推論過程は一致する。$\boldsymbol{e}_i$に乗算するのは依然$\rho_i$であって、$\rho_i+b_i$ではないことに注意したい。$\boldsymbol{b}$は分配にのみ関与し、MoEの前向き計算自体には関与しないのだ。なので、$\boldsymbol{b}$と$\boldsymbol{\rho}+\boldsymbol{b}$の符号にも特に制約はない。
「手作り」な勾配
では、$\boldsymbol{b}$はどう学習させればいいのだろうか。$\boldsymbol{b}$の最適化目標は、負担の均等化であることを思い出そう。前回記事に倣って、まず記号$\boldsymbol{f}=[f_1,f_2,\cdots,f_n]$を定義する。
f_i = \left\{\begin{aligned}1/k, \quad i\in \mathop{\text{argtop}}\nolimits_k \boldsymbol{\rho}+\boldsymbol{b} \\
0, \quad i\not\in \mathop{\text{argtop}}\nolimits_k \boldsymbol{\rho}+\boldsymbol{b}\end{aligned}\right.
また、$\boldsymbol{F}=\mathbb{E}[\boldsymbol{f}]$とする。$\boldsymbol{F}$は$\boldsymbol{b}$のバイアスを追加した条件下の、Expertの負担分布である。前回定義した記号$\boldsymbol{Q}=(1/n,1/n,\cdots,1/n)$を使えば、負担均等化とは以下の式を最小化することになる。
\mathcal{L}_{\text{aux}} = \frac{1}{2}\Vert\boldsymbol{F} - \boldsymbol{Q}\Vert^2 = \frac{1}{2}\sum_{i=1}^n (F_i - 1/n)^2
この式は微分不可だが、前回記事と同様、STE(Straight-Through Estimator)テクニックで解決できる。STEの鍵は、微分可能かつ$\boldsymbol{F}$と等しい増減傾向を持つ関数を$\boldsymbol{F}$の連続的な近似とすることだ。いま、我々が最適化したいパラメーターは$\boldsymbol{b}$のみで、しかも$\boldsymbol{b}$はちょうど望ましい性質を持ち合わせている($b_i$を大きくすれば、$i$が選ばれる確率は上がり、$F_i$も大きくなる)。であれば、答えは明らかだろう。
\mathcal{L}_{\text{aux}} = \frac{1}{2}\Vert\boldsymbol{b} + \text{sg}[\boldsymbol{F}-\boldsymbol{b}] - \boldsymbol{Q}\Vert^2 = \frac{1}{2}\sum_{i=1}^n (b_i + \text{sg}[F_i - b_i] - 1/n)^2
勾配は
\nabla_{\boldsymbol{b}}\mathcal{L}_{\text{aux}} = \frac{1}{2}\nabla_{\boldsymbol{b}}\Vert\boldsymbol{b} + \text{sg}[\boldsymbol{F}-\boldsymbol{b}] - \boldsymbol{Q}\Vert^2 = \boldsymbol{F} - \boldsymbol{Q}
なので、確率的勾配降下法の更新式は
\boldsymbol{b}\leftarrow\boldsymbol{b}-\alpha(\boldsymbol{F-Q})
$\alpha$は$\boldsymbol{b}$の学習率である。ただし、Loss-Free手法が最終的に採用した更新式は少し異なっており、符号勾配降下法を使っている。
\boldsymbol{b}\leftarrow\boldsymbol{b}-\alpha\text{sign}(\boldsymbol{F-Q})
手法自体は単純だ。$F_i$が$1/n$より大きければ$b_i$を下げ、逆ならば$b_i$を上げる、というだけの話である。
改良版
符号勾配降下法のほかに、筆者は$\boldsymbol{F}-\boldsymbol{Q}$にRMS Norm を加える手法(すなわちNormalized SGD)も、同じ$\alpha$でより良い効果が得られることに気付いた。
\boldsymbol{b}\leftarrow \boldsymbol{b} - \alpha\frac{\boldsymbol{F} - \boldsymbol{Q}}{\text{RMS}(\boldsymbol{F} - \boldsymbol{Q})}
$\text{RMS}$は"Root Mean Square"の略で、
\text{RMS}(\boldsymbol{F} - \boldsymbol{Q}) = \sqrt{\frac{1}{n}\sum_{i=1}^n (F_i - Q_i)^2}
$\text{sign}(\boldsymbol{F-Q})$と$\frac{\boldsymbol{F} - \boldsymbol{Q}}{\text{RMS}(\boldsymbol{F} - \boldsymbol{Q})}$のRMSはいずれも1であることから、両者の尺度は大体同じであることが分かるので、同じ$\alpha$を使うことができる。
分かりやすく説明すると、$\text{sign}$は$F_i$と$Q_i$間の距離に関わらず同じ更新幅を採用するため、既に$Q_i$に近付いている$F_i$が逆に元のバランスから脱してしまう可能性があった。一方でRMS Normは$F_i-Q_i$の距離を保持しており、より適応的に更新幅を設定するので、理屈上はよりバランスを促進する効果があるはずで、実際の実験でもより良い効果を示している。
同じ流れ
論文がLoss-Free手法を説明する際は、上記のようなAux Lossに基づいた導出を示したわけではなく、例の更新式を直接提示したので、$\boldsymbol{b}$の勾配$\text{sign}(\boldsymbol{F-Q})$を天下り的に「手作り」したような印象がある。これもLoss-Freeと呼ばれる所以である。
しかし、先ほどの導出が示しているように、更新式はAux Lossを基に導き出すこともできる。両者の考え方は一緒なのだ。Loss-Freeの主なメリットは、Aux Lossの重みを調整せずに済むことだが、それでも学習率$\alpha$はまだ残っている。論文では$\alpha=0.001$という値を与えてくれてはいるが、ハイパラを完全に消したわけではないと言わざるを得ない。
筆者から見て、Loss-Free論文の本質的な新規性はAux Lossを無くしたことではなく、Aux lossとLM Lossが最適化するパラメーターを分離することで、負担均衡とモデル性能を両立させたことである。ここで重要なのは、「バイアス項を足すだけで負担を均衡化できる」という事実に着目した点である。Aux Lossにはそのバイアス項$\boldsymbol{b}$を最適化させ、LM Lossはその他のパラメーターを最適化させることで、Aux LossによるLM Lossの副作用を最小限に抑えたのだ。
従来のAux Lossはモデルの全パラメーターを使い負担均衡を促す。LM Lossも全パラメーターで最適化するため、両者の最適化方向が相容れない場合もある。この時、最適なバランスを見つけることは困難だ。Loss-Free手法は「バイアス項だけで負担を均衡化できる」という発見に基づき、両損失関数が最適化するパラメーターを隔離する手法であり、負担均衡化問題の絶妙な解決策である。
細かい注意点
Loss-Free手法は十分簡潔な手法だが、それでも扱う際はいくつかの細かい点に注意する必要がある。
まず、ミニバッチが与えられたとき、まずLM Lossでモデルパラメーターを更新してから、$\boldsymbol{b}$を更新すべきである。$\boldsymbol{b}$の更新はToken全体の統計的な情報$\boldsymbol{F}$に依存しているので、先に$\boldsymbol{b}$を更新してしまうと、未来の情報をリークしてしまうおそれがある。たかがベクトル一つにリークされる情報量など大したことないと思われるかもしれないが、リスクがある以上はなるべく避けるべきだ。
論文では$\alpha=0.001$に設定されていることに触れたが、これは論文ではSigmoid関数をRouterの活性化関数に選んだことと関連していると思われる。Sigmoid関数を通すと、各$\rho_i$は比較的独立的な値になり、しかも$(0,1)$の範囲に限定されるので、$\alpha=0.001$という更新幅は上手く機能する。仮にSoftmaxやReLU、あるいはほかの活性化関数を使うなら、適切な$\alpha$を見つけ直す必要があるかもしれない。
筆者としては、GateとBiasに用いる活性化関数を分離することを勧めたい。つまり、
\boldsymbol{y} = \sum_{i\in \mathop{\text{argtop}}_k \boldsymbol{\rho} + \boldsymbol{b}} \rho_i \boldsymbol{e}_i\qquad\to\qquad \boldsymbol{y} = \sum_{i\in \mathop{\text{argtop}}_k \boldsymbol{\rho}^{(\sigma)} + \boldsymbol{b}} \rho_i^{(h)} \boldsymbol{e}_i
ここで$\boldsymbol{\rho}^{(\sigma)} = \sigma(\boldsymbol{x}\boldsymbol{W}^{(R)}), \boldsymbol{\rho}^{(h)} = h(\boldsymbol{x}\boldsymbol{W}^{(R)})$であり、$\sigma(\cdot)$はSigmoid関数、$h(\cdot)$は任意の単調かつ非負の関数である。つまり、Sigmoid関数で正規化した値に、$\boldsymbol{b}$を足すのである。Expertに乗算するGateは、Sigmoid関数の単調性と一致する任意の別の関数を使うことができる。
学習を回した結果、$b_i$の絶対値が1を上回ることがあり、絶対値が全体的に大きくなっていく可能性もあるが、モデルの効果自体には影響しない。すべての$b_i$に同じ定数を足しても、$\text{argtop}_k\boldsymbol{\rho +b}$の結果は変わらないからである。この$\boldsymbol{b}$の余分な自由度を利用して、とある「面白いこと」をすることもできるのだが、この話はまた次回に。
更なる延長
MoEの負担均等化以外にも、Loss-Freeの思想は色んな似たような問題に応用することができる。たとえば、VQ-VAEのコードブック崩壊(Codebook Collapse)も、より自然で普遍的な解決手法を提供できる。本記事の冒頭で、「Losss-Free手法の潜在的なインパクトはほかの成果を遥かに上回るかもしれない」と評したのも、Loss-Free手法には優れた普遍性があるからだ。
数学的な視点から見て、Loss-Free手法の主な貢献は、勾配降下法で線形分配問題を解く手法を示したことである。一般的な線形分配問題は以下の式で表される。
\min_f \sum_{i=1}^n c_{i, f(i)}
$c_{i,j}$は事前に与えられたコスト関数、$f$は集合$[1,2,\cdots,n]$から自身への全単射である。本記事の問題に当てはめると、$c_{i,j}$はまさに$n$個のトークンと$n$個のExpertに対するスコアで、$f$は一つの負担均衡化案にあたる。
この類の問題は、制約条件を満たす空間内でなるべく最適な解を捜索することが一般的な解法だが、Loss-Free手法はまず「最適だが制約条件を満たさない解」を設定する。
f(i)\mathop{\text{argmin}}_j c_{i,j}
この解は明らかに最適解だが、全単射の条件を満たすとは限らない。全単射を満たさない状態は、言い換えれば負担が不均衡な状態である。そこで、例のバイアス項を導入する。
f(i)\mathop{\text{argmin}}_j c_{i,j}+b_j
$b_j$の初期値はゼロとし、上述の更新式で更新し続ければいい。平たく言えば、$j$が出現する回数が多ければ、$b_j$の値をちょっと減らし、逆ならばちょっと増やす。これを全単射になるまで繰り返すだけだ。
まとめ
本記事はMoEの負担均衡化問題を解くためのLoss-Free手法を紹介した。DeepSeekによって提案された本手法の核心は、バイアス項の導入によって負担均衡化を実現することである。本記事はこの手法とAux Lossの関連を考察し、更に数学的な視点から潜在的な応用を掘り下げた。