はじめに
昨今、Mixture of Experts (MoE) アーキテクチャの流行により、複数の特化型モデル(専門家)を統合して1つの巨大なシステムとして動かすアプローチが注目されています。
例えば「フランス語・ドイツ語から英語への翻訳に特化したモデル」と「スペイン語から英語への翻訳に特化したモデル」を統合したい場合、入力を自動判別して適切なモデルに処理を流す 「上位ルーター(Super Router)」 が必要になります。
しかし、単純なSoftmaxによるアンサンブルでは「すべてのモデルを計算しなければならない」ため計算コストが跳ね上がります。
本記事では、不要なモデルの計算を完全に省く Hard Routing(Top-1 Gating) をPyTorchで実装し、どの数理的手法が最も完璧にルーティングを学習できるのかを実験・検証しました。
ルーティングの罠:なぜ単純なSoftmaxではダメなのか?
複数のモデルの出力を合成する際、一番簡単なのはSoftmaxで重みを計算する Soft Routing です。
# (B, T, 2) の確率ベクトル
weights = F.softmax(router_logits, dim=-1)
# 最終出力 = モデル1の出力 × w1 + モデル2の出力 × w2
final_out = out1 * w1 + out2 * w2
しかしこの手法には 「Lazy Routing(ルーターの怠け)」 という問題があります。
片方のモデルが圧倒的に正しい出力(高いLogits)を出す場合、ルーターはわざわざ確率を 1.0 : 0.0 に尖らせなくても、0.5 : 0.5 のままでLossを下げることができてしまいます。結果として、推論時には使わないはずのモデルまでメモリに載せて計算しなければならなくなります。
これを防ぐためには、強制的にトップ1だけを選ぶ Hard Routing が必要です。
比較実験:3つのルーティング手法
今回は以下の3手法を実装し、言語(仏・独・西)によるルーティングの「確率分布」がどう変化するかを実験しました。
- Soft Routing: 通常のSoftmaxアンサンブル。
- Straight-Through Estimator (STE): Forwardは強制Top-1選択、BackwardはSoftmaxの勾配を騙して流すハック的手法。
-
Gumbel-Softmax: ガンベルノイズを加え、微分可能な疑似ハードルーティングを行う手法。PyTorch標準の
F.gumbel_softmax(..., hard=True)を使用。
Gumbel-Softmaxの実装例 (PyTorch)
if self.training:
# 学習時はノイズを乗せて微分可能にHard選択
routing_weights = F.gumbel_softmax(router_logits, tau=1.0, hard=True, dim=-1)
else:
# 推論時は決定論的にargmax
top_idx = router_logits.argmax(dim=-1, keepdim=True)
routing_weights = torch.zeros_like(router_logits).scatter_(-1, top_idx, 1.0)
驚くべき実験結果
実験のログ出力結果は以下のようになりました。
(※Model 1 = 仏・独用、Model 2 = 西用)
【Soft Routing の結果】
✅ FRA -> [Routed to Model 1] probs: [0.909, 0.091]
✅ SPA -> [Routed to Model 2] probs: [0.041, 0.959]
【STE (Straight-Through Estimator) の結果】
✅ FRA -> [Routed to Model 1] probs: [0.589, 0.411]
✅ SPA -> [Routed to Model 2] probs: [0.437, 0.563]
【Gumbel-Softmax の結果】
✅ FRA -> [Routed to Model 1] probs: [1.0, 0.0]
✅ SPA -> [Routed to Model 2] probs: [0.0, 1.0]
結果の考察
-
Soft:
0.9という高い確率を出せましたが、残りの0.1がノイズとして残り、計算コストの完全削減には使えません。 -
STE: ターゲットのモデルを選ぶことには成功していますが、確率が
0.6 : 0.4と非常に濁っています。これは正解のモデルを選べた時点で勾配が消失し、確率を尖らせる学習が止まるためです。 -
Gumbel-Softmax: 完璧な
1.0 : 0.0のスパースルーティングを達成しました! 学習時のノイズにより、ルーターは「確実に正解を選ぶために確率分布を極端に尖らせる」ように強制され、推論時に一切の迷いがなくなりました。
まとめ
複数の独立したLLMやモデルを束ね、計算コストを抑えながら動的に切り替える「上位ルーター」を構築する場合、Gumbel-Softmax を用いた Hard Routing が圧倒的に強力であることが分かりました。
(※ちなみに本記事の知見は、現在筆者が開発中の SRA(Synaptic Routing Architecture) という「シナプス単位で動的にモデルを拡張・入れ替えできる次世代アーキテクチャ」の検証実験中に得られたものです。MoEやプラグイン型のアーキテクチャ設計に興味がある方は、そちらもぜひチェックしてみてください!)
以下は、今回の実験・検証のノートブックです。