はじめに
前回の記事では、MoE (Mixture of Experts) の直感的なイメージについてまとめました。
エキスパートが大活躍! LLM の最新トレンド MoE モデルってなんだ?
今回は MoE の元論文について紹介します。この論文では
- モデル容量だけを増加させる
- 1 トークンあたりの計算量はあまり増やさない(学習時・推論時とも)
- GPU クラスタ上で高速に動かす
という、MoE を実務レベルまで落とし込んだアイデアと実装が紹介されています。少し長い内容になるので前半後半に分けて、本記事ではその前編として MoE レイヤの定義と学習を安定させるための理論を中心に取り上げます。
参考文献:
N. Shazeer, A. Mirhoseini, K. Maziarz, A. Davis, Q. V. Le, G. E. Hinton & J. Dean. "Outrageously Large Neural Networks: The Sparsely-Gated Mixture-of-Experts Layer." ICLR 2017
🐣 前回はふんわり系だったので今回はちゃんと論文を読んでみます
この論文でやっていること
論文の目的は
条件付き計算 (conditional computation) を使って、とにかく巨大なネットワークを動かしたい!
です。そのために
- 汎用コンポーネントとしての「MoE レイヤ」を設計
- ゲートの設計を工夫して疎な Top-k ルーティングを実装
- GPU クラスタ上での並列化・通信オーバーヘッドの問題を解析
- 1B Word LM、100B Word LM、機械翻訳 (WMT14, Google 本番) で検証
という理論から実装までを含む幅広い内容になっています。
前回の記事で説明した「FFN をたくさん並べて必要なものだけ通す」というイメージが、まさにここで具体化されているわけです!
Sparsely-Gated MoE レイヤの数式【学習時・推論時共通】
まず MoE レイヤの中身を数式で表してみましょう。以下の式は学習時・推論時ともに共通で使われます。
入力を $x$、エキスパート (FFN) を $E_1, \dots, E_n$、ゲートの出力を $G(x) \in \mathbb{R}^n$ とすると、MoE レイヤの出力 $y$ は
y = \sum_{i=1}^{n} G(x)_i \, E_i(x)
というシンプルな形になります。ここで
- 各 $E_i$ はそれぞれ別パラメータを持つ FFN
- $G(x)_i$ は「エキスパート $i$ をどれくらい使うか」の重み
- $G(x)$ の多くの要素は 0 になるように設計する (sparse)
というのがポイントです。
ここで重要なのは、$G(x)_i = 0$ になったエキスパート $E_i$ については「出力に 0 を掛ける」のではなく、実装上そもそも forward を計算しない形で完全にスキップされることです。式の上では全ての $E_i(x)$ が並んでいますが、実際の計算グラフ上では Top-$k$ 以外のエキスパートは呼び出されません。
推論時にも上位 $k$ 個のエキスパートしか活性化しないため、全パラメータを使う dense モデルに比べて計算量を大幅に削減できます。ただし、全エキスパートのパラメータはメモリ上にロードしておく必要があるため、VRAM 使用量は総パラメータ数に比例します。
🐣 計算量と異なり、メモリは節約されません
前回の記事で「FFN がエキスパート集合になる」と書きましたが、論文でも
- エキスパートは FFN
- MoE レイヤは Transformer などの中に挿せる汎用ブロック
というスタンスで設計されています。
ゲーティング: Noisy Top-k Softmax【学習時の工夫】
ゲートは「どのエキスパートにルーティングするか」を決める小さなネットワークです。論文では段階的に定義されています。学習時と推論時で処理が異なるため、明確に区別して追っていきます。
1. 通常の softmax ゲート【学習時・推論時共通の基本形】
まずはシンプルな softmax ゲートです。
G_\sigma(x) = \mathrm{softmax}(x W_g)
- $W_g$ はゲート用の行列
- 出力は確率ベクトルですが、このままだとすべてのエキスパートが少しずつ活性化してしまいます
2. Noisy Top-k Gating【学習時のみノイズを使用】
ここからが本題で、論文のコアアイデアの一つです。学習時には以下の処理を実行します。
- スコアにガウスノイズを足す
- 上位 $k$ 個だけ残して残りは $-\infty$ にする
- その上で softmax を取る
\begin{aligned}
H(x)_i &= (x W_g)_i + \mathrm{StandardNormal}() \cdot
\mathrm{Softplus}((x W_{\mathrm{noise}})_i) \\
G(x) &= \mathrm{softmax}(\mathrm{KeepTopK}(H(x), k))
\end{aligned}
- $W_{\mathrm{noise}}$ でノイズ量をトークンごと・エキスパートごとに学習
- $\mathrm{KeepTopK}$ は上位 $k$ 以外を $-\infty$ にする演算
推論時には、ノイズ項を除去し、決定論的にルーティングします。
\begin{aligned}
H(x)_i &= (x W_g)_i \\
G(x) &= \mathrm{softmax}(\mathrm{KeepTopK}(H(x), k))
\end{aligned}
学習時にノイズを入れる理由は「負荷分散を良くするため」です。
ゲートが決定論的すぎると、ある一部のエキスパートだけが永遠に選ばれてしまい、他が育たない問題が出るためです。推論時はすでに学習済みなので、安定した決定論的ルーティングの方が適切です。また、Softmax の前に $\mathrm{KeepTopK}$ で上位 $k$ 以外を $-\infty$ にしてから確率に変換することで、選ばれなかったエキスパートの確率が厳密に 0 になり、条件付き計算として不要な計算にリソースが割かれないようになっています。
🐣 学習時に摂動を加える考えはシミュレーティッドアニーリングの時代から今でも現役ですね
ロードバランス用の損失【学習時のみ】
それでも素直に学習させると「人気エキスパートにトークンが殺到する」問題が起きます。
そこで論文では、学習時のみゲートに対して 2 種類の補助損失を入れています。これらは推論時には使用しません。
1. 重要度のばらつきを抑える損失
バッチ $X$ に対して、エキスパート $i$ の「重要度」を
\mathrm{Importance}(X)_i = \sum_{x \in X} G(x)_i
と定義し、重要度ベクトルの変動係数 (coefficient of variation) の二乗を損失として追加します。
L_{\mathrm{importance}}(X) = w_{\mathrm{importance}} \cdot
\mathrm{CV}(\mathrm{Importance}(X))^2
すべてのエキスパートが同じくらいの合計ゲート値を持つように誘導するイメージです。
2. ロード (バッチ内のサンプル数) を揃える損失
重要度だけ揃えても
- 少数のサンプルに大きな重み
- 多数のサンプルに小さな重み
のような偏りが起きます。
GPU メモリ的には「何サンプル来るか」も重要なので、エキスパートごとの予想サンプル数を滑らかに近似した $\mathrm{Load}(X)$ を定義して、こちらも変動係数を抑える損失
L_{\mathrm{load}}(X) = w_{\mathrm{load}} \cdot \mathrm{CV}(\mathrm{Load}(X))^2
を追加しています。これらの補助損失は、学習時にエキスパート間の負荷を均等にするためのものです。推論時には損失計算は実行されません。学習済みのゲートによってバランスの良いルーティングが実現すると信じるわけです。
今回のまとめ
今回は MoE の元論文から、まずは数式定義とルーティングの仕組みについて解説しました。
しかし、MoE を実際に巨大なモデルとして動かすには、理論だけでなく「GPU 上でどう効率的に計算するか」という実装上の課題を解決しなければなりません。
次回は、この実装に関わる部分を見ていきたいと思います。
🐣 エンジニア的にはやっぱり泥臭い実装の話ははずせませんよね