「寡(すくな)きを患えずして均(ひと)しからざるを患う」:「論語」の一節。一国の為政者は富の量よりも富の平等を重視すべきだと説いている。
この記事は中国のNLP研究者Jianlin Su氏が運営するブログ「科学空間」で掲載された解説記事の日本語訳です。
苏剑林. (Feb. 21, 2025). 《MoE环游记:2、不患寡而患不均 》[Blog post]. Retrieved from https://kexue.fm/archives/10735
前回記事「MoE世界一周(1):幾何学的な意味を探る」では、MoEの幾何学的な解釈を通して、「Denseモデルの近似」という視点からMoEを構築する方法を紹介した。記事の最後に、MoEの定義は始まりに過ぎず、実際に有効なMoEを訓練するには多くの課題を解決する必要があることにも触れた。その一つがロードバランス(Load Balance)の問題である。
ロードバランスとは、「寡きを患えずして均しからざるを患う」という諺もあるように、すべてのExpertにちゃんと仕事をさせ、一部のExpertがリソースを無駄遣いするのを避けることだ。ロードバランスは計算リソースをフル活用するため、またMoEの大きなパラメーター量を活かすために必要な措置である。
課題の分析
MoEの基本形式は以下の通りだ。
\boldsymbol{y} = \sum_{i\in \mathop{\text{argtop}}_k \boldsymbol{\rho}} \rho_i \boldsymbol{e}_i
一般的なMoEにおいて、$\boldsymbol{\rho}$は確率分布(Roueter)、$\boldsymbol{e}_i=\boldsymbol{v}_i$、$\boldsymbol{v}_i$は小規模なFFN(Expert)の出力である。前回記事で議論した「幾何学的な」MoEでは、$\boldsymbol{\rho}$に正規化要件はなく、Expertのノルムを予測するための量である。一方$\boldsymbol{e}_i=\boldsymbol{v}_i/\Vert \boldsymbol{v}_i\Vert$は、Expertの方角を予測する量になる。
いずれのMoEも性能的には大差なく、単に視点の違いでしかない。ただし、MoEの数式は「Tokenが入力されるたびに、適切なExpertを選んで計算する」という順序で処理している風に見えるが、実際の学習時は逆である。すなわち、まず各Expertに計算リソースを分配してから、TokenをExpertに割り当て(Route)、並行的に計算を行っているのだ。
こうなると、仮にExpertの負担が不均等だと、色々まずい状況が起こる。たとえば、一部のExpertはほとんど放置され、リソースを無駄遣いする(Dead Expert)。一方で、一部のExpertに割り当てられたTokenが多すぎて、計算が間に合わず、一部のTokenをドロップせざるを得なくなる。Dead Expertが発生することは、実質的なMoEのパラメーター量が想定を下回っていることになるので、これでは折角多くのGPUメモリを用意しても小規模モデルの性能しか得られない。
学習効率と性能を最大化するためには、Expertの負担を均等化することが必須である。
補助的な損失
負担の均等化を促す一般的なアプローチは、そのための損失関数を追加することだ。このような損失関数をAux Loss (Auxiliary Loss)と呼ぶ。いま広く使われているAux Lossは、2020年の論文「GShard: Scaling Giant Models with Conditional Computation and Automatic Sharding」で提案されたものだ。
Aux Lossを紹介する前に、まずはいくつかの概念を導入しよう。先ほど、一般的なMoEにおいて、$\boldsymbol{\rho}$は確率分布である必要がないとしたが、$\boldsymbol{\rho}$を正規化したものを$\boldsymbol{p}=[p_1,p_2,\cdots,p_n]$とし、そのTop-kを$\boldsymbol{f}=[f_1,f_2,\cdots,f_n]$とする。つまり、
p_i = \frac{\rho_i}{\sum_{i=1}^n \rho_i},\qquad f_i = \left\{\begin{aligned}1/k, \quad i\in \mathop{\text{argtop}}\nolimits_k \boldsymbol{\rho} \\
0, \quad i\not\in \mathop{\text{argtop}}\nolimits_k \boldsymbol{\rho}\end{aligned}\right.
続いて、$\boldsymbol{P}=\mathbb{E}[\boldsymbol{p}],\boldsymbol{F}=\mathbb{E}[\boldsymbol{f}]$とする。ここで$\mathbb{E}$は全てのサンプルのTokenに対して平均をとる計算を指している。$\boldsymbol{F}$はExpertの現在の負担割合で、$\boldsymbol{P}$は$\boldsymbol{F}$の連続的な近似であることがわかる。
これらの記号を利用して、Aux Lossは以下のように定義されている。
\mathcal{L}_{aux}=\boldsymbol{F}\cdot\boldsymbol{P}=\sum_{i=1}^nF_iP_i
Aux Lossに$n$を乗算し、損失関数を$n\mathcal{L}_{aux}$としている文献も多い。大規模なMoEの場合、ノードごとにAux Lossを計算し、ノード内のバランスを保つことでノード間の通信を減らす、といったテクニックもある。一方最近の研究では、局所的なバランスを強制することは最終的な性能に悪影響を及ぼす可能性が高いことも示唆されている。
STE (Straight-Through Estimator)
Aux Lossについて、ひとつ疑問に思うところはないだろうか。Aux Lossを提案した最初の論文も、後続研究も、紹介記事も、とにかく筆者が読んだあらゆる文献において、Aux Lossは厳密な証明がないまま使用されている。まるでAux Lossが負担の均等化を促す効果があることは自明であるかのようだ。しかし、本当にそこまで自明な事実なのだろうか?
少なくとも筆者はそう思えない。なので、Aux Lossの定義を導出するひとつの考え方を提示してみようと思う。この考え方を利用すれば、ほかのAux Lossを定義することもできる。
まず、一様分布$\boldsymbol{Q}=(1/n,1/n,\cdots,1/n)$を定義する。$\boldsymbol{F}$は現在の負担割合なので、負担が完全に均等になった場合、$\boldsymbol{F}=\boldsymbol{Q}$になるだろう。となると、以下のようなAux Lossが考えられる。
\mathcal{L}_{aux}=\frac{1}{2}\Vert\boldsymbol{F}-\boldsymbol{Q}\Vert^2=
\frac{1}{2}\sum_{i=1}^n\Vert F_i-1/n\Vert^2
しかし、$\boldsymbol{F}$は$argtop_k$から算出されたので、上の式は微分可能な目標関数ではなく、直接利用することはできない。どうすればいいかというと、STE (Straight-Through Estimator) というテクニックを利用し、前向き計算と逆伝播計算の関数を別々に設計すればいい。$\boldsymbol{F}$は微分不可だが、その連続的な近似$\boldsymbol{P}$は微分可能だ。ならば、逆伝播の際は$\boldsymbol{F}$を$\boldsymbol{P}$に置き換えてしまえばいい。つまり、
\mathcal{L}_{aux} = \frac{1}{2}\Vert \boldsymbol{P} + \text{sg}[\boldsymbol{F}-\boldsymbol{P}] - \boldsymbol{Q}\Vert^2 = \frac{1}{2}\sum_{i=1}^n (P_i + \text{sg}[F_i - P_i] - 1/n)^2
$\text{sg}[]$はstop gradient演算子、つまり前向き計算の結果は変えず、勾配だけ強制的にゼロにする処理である。こうすれば、$\mathcal{L}_{aux}$は使える損失関数になる。
試しに、この関数の勾配を求めてみよう。
\begin{aligned}
\nabla_{\boldsymbol{\theta}}\mathcal{L}_{\text{aux}} =&\, \frac{1}{2}\nabla_{\boldsymbol{\theta}}\sum_{i=1}^n (P_i + \text{sg}[F_i - P_i] - 1/n)^2 \\
=&\, \sum_{i=1}^n (P_i + \text{sg}[F_i - P_i] - 1/n) \nabla_{\boldsymbol{\theta}}(P_i + \text{sg}[F_i - P_i] - 1/n)\\
=&\, \sum_{i=1}^n (F_i - 1/n) \nabla_{\boldsymbol{\theta}}P_i = \nabla_{\boldsymbol{\theta}}\sum_{i=1}^n (F_i - 1/n) P_i\\
=&\, \nabla_{\boldsymbol{\theta}}\left(\sum_{i=1}^n F_i P_i\right)
\end{aligned}\
$\boldsymbol{\theta}$はモデルパラメーターを指している。結果的に、筆者がSTEから導き出した損失関数の勾配は、最初に示した損失関数の勾配と一致することが分かった。つまり2つの損失関数は実質的に等価である。
一般的な形式
先ほどの導出方法は、Aux Lossを構築する一般的なアプローチを示している。すなわち、$\boldsymbol{F}$が満たすべき条件に基づいて損失関数を定義し、実装時は$\boldsymbol{F}$を$\boldsymbol{P}+\text{sg}[\boldsymbol{F}-\boldsymbol{P}]$に置き換える、という手法だ。
たとえば、エントロピー最大化も分布を均等化する効果があることから、エントロピーの反数でAux Lossを構築することもできる。
\mathcal{L}_{aux}=\sum_{i=1}^n(P_i+\text{sg}[F_i-P_i])\log(P_i+\text{sg}[F_i-P_i])
これでも簡単に実装することができるが、より簡潔な形式にしたければ、同じように勾配を計算してみればいい。
\nabla_{\boldsymbol{\theta}}\mathcal{L}_{\text{aux}} = \nabla_{\boldsymbol{\theta}}\sum_{i=1}^n(P_i + \text{sg}[F_i - P_i]) \log(P_i + \text{sg}[F_i - P_i]) = \nabla_{\boldsymbol{\theta}}\sum_{i=1}^n P_i \log F_i
勾配計算の際、どちらも以下の恒等式を利用している。
\sum^n_{i=1}\nabla_{\boldsymbol{\theta}}P_i=\nabla_{\boldsymbol{\theta}}\sum^n_{i=1}P_i=\nabla_{\boldsymbol{\theta}}1=0
これは$\boldsymbol{P}$が確率分布であることと、目標分布$\boldsymbol{Q}$が一様分布であることに依存している。もし簡潔化を追求せず、直接$\boldsymbol{F}=\boldsymbol{P}+\text{sg}[\boldsymbol{F}-\boldsymbol{P}]$を使うなら、この制約を満たす必要はない。
たとえば、$\boldsymbol{P}$を$\boldsymbol{F}$の近似として使った理由は、「$P_i$が大きければ$F_i$もだいたい大きい」という性質があるからに過ぎず、正規化されていない$\mathbb{E}[\boldsymbol{\rho}]$を$\boldsymbol{P}$としても問題ない。これは特定の場面(たとえば$\boldsymbol{\rho}$が負の値も取りうるので、確率分布に正規化できない場合)では重要になることもある。あるいは、$\Vert\boldsymbol{F} - \boldsymbol{Q}\Vert^2$の$\boldsymbol{Q}$を任意のベクトルに置き換えれば、$\boldsymbol{F}$を欲しい形に誘導することもできる。
まとめ
本記事はMoEのロードバランス問題を紹介し、Aux Lossを構築する一般的な考え方を示した。Aux Lossのほかにも、負担の均等化を促す手法はいくつかある。これは次回また紹介しよう。