0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

Qwen3.5のモデル構造からMoEを理解する

0
Posted at

はじめに

最近リリースされる大規模言語モデル(LLM)では、Mixture of Experts(MoE)と呼ばれる構造がよく採用されています。

MoEは、モデル内部に複数のネットワーク、すなわちExpertを持ち、入力トークンごとにその一部のExpertだけを選択して計算する仕組みです。これにより、モデル全体のパラメータ数を大きく保ちながら、推論時に実際に使う計算量を抑えることができます。

本記事では、Qwen3.5のMoEモデルを題材に、MoEがソースコード上でどのような構造として実装されているのか、また入力に応じてどのようにExpertが選択されるのかを解説します。

想定読者

本記事は、MoEの概念はある程度知っているものの、実際のモデル実装ではどのように動いているのかがまだイメージしづらい方を主な読者として想定しています。

そのため、MoEの基本概念そのものについては詳しく扱いません。MoEについて初めて学ぶ方は、まず以下の記事などを読んでおくと、本記事の内容を追いやすくなると思います。

また、ソースコードを読みながら解説するため、PythonおよびPyTorchの基本的な読み書きに慣れている方を前提としています。

なお、Qwen3.5 MoEには、Gated DeltaNetによるlinear attentionや、full attentionとのハイブリッド構成など、MoE以外にも多くの工夫が導入されています。ただし、本記事の主題はあくまでMoE部分の実装理解であるため、それらの仕組みについては必要な範囲で触れるにとどめ、詳細な解説は行いません。

前提

まず、Qwen3.5には、MoEモデルである Qwen3.5-35B-A3B / Qwen3.5-122B-A10B / Qwen3.5-397B-A17B と、denseモデルである Qwen3.5-9B / Qwen3.5-27B が基盤モデルとして公開されています。

MoEモデル名に含まれる ○○B-AxxB は、○○B がモデル全体のパラメータ数、AxxB が推論時にアクティブになるパラメータ数の目安を表しています。たとえば Qwen3.5-35B-A3B であれば、モデル全体では35B規模のパラメータを持ちますが、各トークンの計算で主に使用されるのは約3B分のパラメータです。

本記事では、このうち Qwen3.5-35B-A3B を前提に解説します。また、ソースコードは以下のHugging Face Transformersの実装をもとにします。

この実装には、テキスト生成を扱う Qwen3_5MoeForCausalLM と、画像・動画入力も扱えるマルチモーダル用の Qwen3_5MoeForConditionalGeneration があります。本記事では、MoEの構造を追いやすくするため、テキスト生成用の Qwen3_5MoeForCausalLM に絞って解説します。

Qwen3.5 MoEの全体像

上図は、Qwen3.5-35B-A3B の全体構造を簡略化したものです。

モデルの中心となるのは DecoderLayer であり、Qwen3.5-35B-A3B では合計40層のDecoderLayerが積み重ねられています。ただし、すべての層が同じ構造になっているわけではありません。内部では、linear_attention layer が3層続いた後に full_attention layer が1層配置される構成を1グループとしており、このグループが10回繰り返されます。

つまり、全体としては次のような構造になります。

1 group = 3 × linear_attention layer + 1 × full_attention layer
10 groups × 4 layers = 40 layers

各DecoderLayerは、大きく見ると「Attention系のToken Mixer」と「Sparse MoE Block」から構成されます。linear_attention layer ではGated DeltaNetが使われ、full_attention layer では通常のSelf-Attentionに近い処理が使われます。一方で、後段にはどちらの層でもSparse MoE Blockが配置されます。

Configから構造を確認する

次に、Qwen3_5MoeTextConfig からモデル構造に関わる主要な設定を確認します。

src/transformers/models/qwen3_5_moe/configuration_qwen3_5_moe.py
class Qwen3_5MoeTextConfig(PreTrainedConfig):

    # 一部のパラメータは省略

    vocab_size: int = 248320                   # 語彙数。lm_headの出力次元で、次トークン候補の総数
    hidden_size: int = 2048                    # 隠れ状態ベクトルの次元数。各トークンを2048次元で表現する
    num_hidden_layers: int = 40                # DecoderLayerの層数。ここでは40層
    num_attention_heads: int = 16              # Full Attentionで使うQueryヘッド数
    num_key_value_heads: int = 2               # Full Attentionで使うKey/Valueヘッド数。GQA用の設定
    hidden_act: str = "silu"                   # MLPやMoE内部で使う活性化関数
    head_dim: int = 256                        # Full Attentionの各ヘッドの次元数
    linear_conv_kernel_dim: int = 4            # Gated DeltaNet内の短い畳み込みに使うカーネルサイズ
    linear_key_head_dim: int = 128             # Linear Attention側のKeyヘッド1つあたりの次元数
    linear_value_head_dim: int = 128           # Linear Attention側のValueヘッド1つあたりの次元数
    linear_num_key_heads: int = 16             # Linear Attention側のKeyヘッド数
    linear_num_value_heads: int = 32           # Linear Attention側のValueヘッド数
    moe_intermediate_size: int = 512           # 各MoE expert内部の中間層次元数
    shared_expert_intermediate_size: int = 512 # 全トークンで共有されるshared expertの中間層次元数
    num_experts_per_tok: int = 8               # 1トークンあたり選択されるexpert数。Top-k routingのk
    num_experts: int = 256                     # MoEに用意されているexpertの総数

このConfigを見ると、モデル全体の大まかな構造を把握できます。

まず、vocab_size248320 に設定されています。これは、モデルが出力候補として扱うトークンIDの総数です。最終的な lm_head は、各トークン位置の隠れ状態をこの語彙数ぶんのlogitに変換します。

また、hidden_size2048 です。これは、各トークンがモデル内部で2048次元のベクトルとして表現されることを意味します。DecoderLayerの数は num_hidden_layers = 40 であり、先ほどの図で示したように、40層のDecoderLayerが積み重ねられます。

Attentionに関する設定としては、full attention用の num_attention_headsnum_key_value_headshead_dim に加えて、linear attention用の linear_key_head_dimlinear_value_head_dimlinear_num_key_headslinear_num_value_heads などが定義されています。ここから、Qwen3.5 MoEではfull attentionだけでなく、Gated DeltaNetによるlinear attentionも併用されていることが分かります。

さらに、本記事で主に扱うMoEに関する設定として、num_experts = 256num_experts_per_tok = 8 が定義されています。これは、モデル内に256個のExpertが用意されており、各トークンに対してそのうち8個のExpertが選択されることを意味します。

つまり、このConfigからは、Qwen3.5-35B-A3B が「40層のDecoderLayer」「full attentionとlinear attentionのハイブリッド構成」「256個のExpertからトークンごとに8個を選択するSparse MoE構造」を持つモデルであることが確認できます。

実装を追う

ここからは、実際のソースコードを上位のクラスから順に見ていきます。

流れとしては、まずユーザーが直接呼び出す Qwen3_5MoeForCausalLM を確認し、その内部で使われる Qwen3_5MoeTextModel、さらにその中に積み重ねられている Qwen3_5MoeDecoderLayer へと進みます。

Qwen3_5MoeForCausalLM
  ↓
Qwen3_5MoeTextModel
  ↓
Qwen3_5MoeDecoderLayer
  ↓
Qwen3_5MoeSparseMoeBlock

最終的には、本記事の主題である Qwen3_5MoeSparseMoeBlock の内部を詳しく見ていきます。

テキスト生成モデルの入口:Qwen3_5MoeForCausalLM

まずは、ユーザーが直接呼び出すモデルである Qwen3_5MoeForCausalLM から見ていきます。

src/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py
class Qwen3_5MoeForCausalLM(Qwen3_5MoePreTrainedModel, GenerationMixin):
    def __init__(self, config):
        super().__init__(config)

        self.model = Qwen3_5MoeTextModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        self.router_aux_loss_coef = config.router_aux_loss_coef
        self.num_experts = config.num_experts
        self.num_experts_per_tok = config.num_experts_per_tok

        self.post_init()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        position_ids=None,
        past_key_values=None,
        inputs_embeds=None,
        labels=None,
        use_cache=None,
        output_router_logits=None,
        logits_to_keep=0,
        **kwargs,
    ) -> MoeCausalLMOutputWithPast:

        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_router_logits=output_router_logits,
            **kwargs,
        )

        hidden_states = outputs.last_hidden_state

        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        logits = self.lm_head(hidden_states[:, slice_indices, :])

        loss = None
        if labels is not None:
            loss = self.loss_function(logits, labels, self.vocab_size, **kwargs)

        aux_loss = None
        if output_router_logits:
            aux_loss = load_balancing_loss_func(
                outputs.router_logits,
                self.num_experts,
                self.num_experts_per_tok,
                attention_mask,
            )
            if labels is not None:
                loss += self.router_aux_loss_coef * aux_loss.to(loss.device)

        return MoeCausalLMOutputWithPast(
            loss=loss,
            aux_loss=aux_loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            router_logits=outputs.router_logits,
        )

Qwen3_5MoeForCausalLM は、テキスト生成用モデルの入口にあたるクラスです。

大きな流れは次のとおりです。

input_ids
  ↓
Qwen3_5MoeTextModel
  ↓
hidden_states
  ↓
lm_head
  ↓
logits

まず、ユーザーから渡された input_idsQwen3_5MoeTextModel に入力されます。Qwen3_5MoeTextModel は、Embedding、DecoderLayer、RMSNormなどを通して、各トークンの特徴量である hidden_states を出力します。

その後、lm_head によって hidden_states が語彙数ぶんのスコアに変換されます。

self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

今回のConfigでは、hidden_size = 2048vocab_size = 248320 です。そのため、各トークン位置の2048次元ベクトルは、248,320個の語彙トークンそれぞれに対するスコアへ変換されます。

このスコアが logits です。

hidden_states: [batch_size, seq_len, hidden_size]
logits:        [batch_size, seq_len, vocab_size]

また、labels が与えられている場合は言語モデリング用の loss が計算されます。さらに、output_router_logits=True の場合は、MoEのExpert選択が偏りすぎないようにするための補助損失 aux_loss も計算されます。

最後に、lossaux_losslogitsrouter_logits などが MoeCausalLMOutputWithPast にまとめられて返されます。

MoeCausalLMOutputWithPast は、通常のテンソルそのものではなく、モデル出力を名前付きでまとめるためのデータクラス系オブジェクトです。

言語モデル本体:Qwen3_5MoeTextModel

次に、Qwen3_5MoeTextModel を見ていきます。ここが、EmbeddingやDecoderLayerを含む言語モデル本体です。

src/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py
class Qwen3_5MoeTextModel(Qwen3_5MoePreTrainedModel):
    def __init__(self, config: Qwen3_5MoeTextConfig):
        super().__init__(config)

        self.embed_tokens = nn.Embedding(
            config.vocab_size,
            config.hidden_size,
            config.pad_token_id,
        )

        self.layers = nn.ModuleList(
            [
                Qwen3_5MoeDecoderLayer(config, layer_idx)
                for layer_idx in range(config.num_hidden_layers)
            ]
        )

        self.norm = Qwen3_5MoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.rotary_emb = Qwen3_5MoeTextRotaryEmbedding(config=config)

        self.post_init()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        position_ids=None,
        past_key_values=None,
        inputs_embeds=None,
        use_cache=None,
        **kwargs,
    ) -> BaseModelOutputWithPast:

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        causal_mask = create_causal_mask(
            config=self.config,
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            position_ids=position_ids,
        )

        linear_attn_mask = self._update_linear_attn_mask(
            attention_mask,
            past_key_values,
        )

        hidden_states = inputs_embeds
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        for i, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]):
            layer_mask = (
                linear_attn_mask
                if self.config.layer_types[i] == "linear_attention"
                else causal_mask
            )

            hidden_states = decoder_layer(
                hidden_states,
                position_embeddings=position_embeddings,
                attention_mask=layer_mask,
                position_ids=position_ids,
                past_key_values=past_key_values,
                use_cache=use_cache,
                **kwargs,
            )

        hidden_states = self.norm(hidden_states)

        return Qwen3_5MoeModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values,
        )

Qwen3_5MoeTextModel では、まず input_ids がEmbedding層に入力されます。

inputs_embeds = self.embed_tokens(input_ids)

これにより、各トークンIDは2048次元のベクトルに変換されます。

input_ids:     [batch_size, seq_len]
inputs_embeds: [batch_size, seq_len, hidden_size]

ここでの hidden_size はConfigで指定されている 2048 です。つまり、各トークンはモデル内部では2048次元の特徴ベクトルとして扱われます。

続いて、Attentionに使うマスクが作成されます。

causal_mask = create_causal_mask(...)
linear_attn_mask = self._update_linear_attn_mask(...)

Qwen3.5 MoEでは、層によって full_attentionlinear_attention が切り替わるため、それぞれに応じたマスクが使われます。

また、位置情報としてRotary Position Embedding、いわゆるRoPEが計算されます。

position_embeddings = self.rotary_emb(hidden_states, position_ids)

RoPEは、トークンの順序や相対的な位置関係をAttention計算に反映するための位置埋め込みです。
その後、DecoderLayer を順番に通していきます。

for i, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]):
    layer_mask = (
        linear_attn_mask
        if self.config.layer_types[i] == "linear_attention"
        else causal_mask
    )

    hidden_states = decoder_layer(...)

Decoderでは、self.config.layer_types[i] によって、各層が linear_attention なのか full_attention なのかを判定しています。linear_attention の層では linear_attn_mask が使われ、full_attention の層では causal_mask が使われます。

すべてのDecoderLayerを通った後、最後にRMSNormが適用されます。

hidden_states = self.norm(hidden_states)

そして、最終的な hidden_stateslast_hidden_state として返されます。

DecoderLayerの構造

続いて、Qwen3.5 MoEのメインブロックである DecoderLayer を見ていきます。

ここまでで、Qwen3.5 MoEでは linear_attentionfull_attention が層ごとに使い分けられていることを確認しました。

DecoderLayer では、前半でトークン間の情報を混ぜる Token Mixer が使われ、後半で Sparse MoE Block によるFFN処理が行われます。Qwen3.5 MoEでは、この Token Mixer として、層によって linear_attention または full_attention が使い分けられています。

src/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py
class Qwen3_5MoeDecoderLayer(GradientCheckpointingLayer):
    def __init__(self, config: Qwen3_5MoeTextConfig, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.layer_type = config.layer_types[layer_idx]

        if self.layer_type == "linear_attention":
            self.linear_attn = Qwen3_5MoeGatedDeltaNet(config, layer_idx)
        elif self.layer_type == "full_attention":
            self.self_attn = Qwen3_5MoeAttention(config, layer_idx)

        self.mlp = Qwen3_5MoeSparseMoeBlock(config)
        self.input_layernorm = Qwen3_5MoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = Qwen3_5MoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> torch.FloatTensor:
        residual = hidden_states

        hidden_states = self.input_layernorm(hidden_states)

        # Token Mixer
        if self.layer_type == "linear_attention":
            hidden_states = self.linear_attn(
                hidden_states=hidden_states,
                cache_params=past_key_values,
                attention_mask=attention_mask,
                **kwargs,
            )
        elif self.layer_type == "full_attention":
            hidden_states, _ = self.self_attn(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_values=past_key_values,
                position_embeddings=position_embeddings,
                **kwargs,
            )

        hidden_states = residual + hidden_states

        # Sparse MoE Block
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)

        if isinstance(hidden_states, tuple):
            hidden_states, _ = hidden_states

        hidden_states = residual + hidden_states

        return hidden_states

このコードを見ると、Qwen3_5MoeDecoderLayer は大きく次の2つの部分に分けられます。

1. Token Mixer
   - linear_attention: Qwen3_5MoeGatedDeltaNet
   - full_attention: Qwen3_5MoeAttention

2. Sparse MoE Block
   - Qwen3_5MoeSparseMoeBlock

まず、入力された hidden_statesinput_layernorm によって正規化されます。

hidden_states = self.input_layernorm(hidden_states)

その後、self.layer_type に応じて、linear_attention または full_attention のどちらかに入力されます。

if self.layer_type == "linear_attention":
    hidden_states = self.linear_attn(...)
elif self.layer_type == "full_attention":
    hidden_states, _ = self.self_attn(...)

各層でどちらのAttentionを使うかは、Config内の layer_types によって決まります。

src/transformers/models/qwen3_5_moe/configuration_qwen3_5_moe.py
def __post_init__(self, **kwargs):
    kwargs.setdefault("partial_rotary_factor", 0.25)

    if self.layer_types is None:
        interval_pattern = kwargs.pop("full_attention_interval", 4)
        self.layer_types = [
            "linear_attention" if bool((i + 1) % interval_pattern) else "full_attention"
            for i in range(self.num_hidden_layers)
        ]

    super().__post_init__(**kwargs)

このコードでは、full_attention_interval のデフォルト値が 4 になっています。そのため、4層に1回だけ full_attention が使われ、それ以外の層では linear_attention が使われます。

linear_attention
linear_attention
linear_attention
full_attention

つまり、Qwen3.5-35B-A3B では、この4層のパターンが10回繰り返され、合計40層のDecoderLayerを構成しています。

full_attentionlinear_attention の違いについて

本記事はMoEの解説を主題としているため、それぞれのAttention機構については深く扱いません。ただし、Qwen3.5 MoEでは full_attentionlinear_attention が併用されているため、ここでは両者の違いを簡単に整理します。

full_attention:
通常のSelf-Attentionに近く、各トークンが他のトークンとのAttentionスコアを計算します。そのため、文脈中のトークン間の関係を直接的に捉えやすい一方で、系列長が長くなるほど計算量やメモリ使用量が大きくなります。

Qwen3.5 MoEの full_attention では、通常の q, k, v に加えてAttention出力を調整するgateが導入されています。また、Key/Valueヘッド数をQueryヘッド数より少なくするGrouped Query Attention(GQA)も使われており、KV cacheの増大を抑えたり、推論時のK/V読み出しコストを下げたりする工夫が含まれています。

linear_attention:
full_attention では全トークン間のAttentionスコアを計算するため、系列長を N とすると計算量が概ね O(N^2) に増えます。これに対して、linear_attention は全トークン間のAttention行列を明示的に作らず、系列情報をより効率的に扱うための仕組みです。

Qwen3.5 MoEでは、linear_attention として Gated DeltaNet が使われています。Gated DeltaNetは、通常のAttentionとは異なり、系列情報を状態として扱いながらトークン間の情報を混ぜるLinear Attention系のToken Mixerです。これにより、長い系列に対する計算量やメモリ使用量を抑えやすくなります。

このように、Qwen3.5 MoEでは、効率性を重視した linear_attention と、通常のSelf-Attentionに近い full_attention を組み合わせることで、計算量を抑えつつ文脈情報を扱う構成になっています。

どちらのToken Mixerを通った場合でも、その後段には共通して Qwen3_5MoeSparseMoeBlockが配置されています。

self.mlp = Qwen3_5MoeSparseMoeBlock(config)

そのため、DecoderLayerの流れは次のように整理できます。

ここまでで、入力が Qwen3_5MoeForCausalLM から Qwen3_5MoeTextModel、さらに Qwen3_5MoeDecoderLayer を通り、その中で Qwen3_5MoeSparseMoeBlock に到達するまでの流れを確認しました。

以降では、本記事の主題である Qwen3_5MoeSparseMoeBlock の内部、特にRouterによるExpert選択、Expert計算、出力の集約について詳しく見ていきます。

Sparse MoE Blockの内部

Sparse MoE Block全体の流れ

ここからは、本記事のメインであるMoEブロックの内部を見ていきます。

Attentionブロックを通った hidden_states は、Qwen3_5MoeSparseMoeBlock に入力されます。このブロックでは、入力トークンごとに使用するExpertを選択し、その出力を集約することで、最終的なMoEブロックの出力を作ります。

まずは、全体の流れを確認します。

src/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py
class Qwen3_5MoeSparseMoeBlock(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.gate = Qwen3_5MoeTopKRouter(config)
        self.experts = Qwen3_5MoeExperts(config)
        self.shared_expert = Qwen3_5MoeMLP(
            config,
            intermediate_size=config.shared_expert_intermediate_size,
        )
        self.shared_expert_gate = torch.nn.Linear(
            config.hidden_size,
            1,
            bias=False,
        )

    def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        batch_size, sequence_length, hidden_dim = hidden_states.shape

        hidden_states_reshaped = hidden_states.view(-1, hidden_dim)

        shared_expert_output = self.shared_expert(hidden_states_reshaped)

        _, routing_weights, selected_experts = self.gate(hidden_states_reshaped)

        expert_output = self.experts(
            hidden_states_reshaped,
            selected_experts,
            routing_weights,
        )

        shared_expert_output = (
            F.sigmoid(self.shared_expert_gate(hidden_states_reshaped))
            * shared_expert_output
        )

        expert_output = expert_output + shared_expert_output
        expert_output = expert_output.reshape(
            batch_size,
            sequence_length,
            hidden_dim,
        )

        return expert_output

このコードの流れを大きく見ると、次のようになります。

ここで Bbatch_sizeSsequence_lengthHhidden_dim を表します。

最初に、入力された hidden_states[B, S, H] から [B*S, H] に変形されます。これは、MoEでは各トークンごとにExpertを選択するため、バッチ方向と系列長方向をまとめて「全トークンを1行ずつ並べたテンソル」として扱うためです。

その後、処理は大きく2つの経路に分かれます。1つ目は、Routerによって選択されたExpertを通る経路です。2つ目は、すべてのトークンに共通して適用されるshared expertの経路です。

最後に、この2つの経路の出力を足し合わせ、形を [B, S, H] に戻すことで、MoEブロックの出力が得られます。

以降では、まずRouterによるExpert選択、次に選択されたExpertによる計算、最後にshared expertを含む出力の集約という順番で詳しく見ていきます。

選択されたExpertによる計算:Qwen3_5MoeExperts

Routerによって各トークンに対して使用するExpertが決まると、次に Qwen3_5MoeExperts で実際のExpert計算が行われます。

src/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py
class Qwen3_5MoeExperts(nn.Module):
    """Collection of expert weights stored as 3D tensors."""

    def __init__(self, config):
        super().__init__()
        self.num_experts = config.num_experts
        self.hidden_dim = config.hidden_size
        self.intermediate_dim = config.moe_intermediate_size

        self.gate_up_proj = nn.Parameter(
            torch.empty(
                self.num_experts,
                2 * self.intermediate_dim,
                self.hidden_dim,
            )
        )
        self.down_proj = nn.Parameter(
            torch.empty(
                self.num_experts,
                self.hidden_dim,
                self.intermediate_dim,
            )
        )
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(
        self,
        hidden_states: torch.Tensor,
        top_k_index: torch.Tensor,
        top_k_weights: torch.Tensor,
    ) -> torch.Tensor:
        final_hidden_states = torch.zeros_like(hidden_states)

        with torch.no_grad():
            expert_mask = torch.nn.functional.one_hot(
                top_k_index,
                num_classes=self.num_experts,
            )
            expert_mask = expert_mask.permute(2, 1, 0)
            expert_hit = torch.greater(
                expert_mask.sum(dim=(-1, -2)),
                0,
            ).nonzero()

        for expert_idx in expert_hit:
            expert_idx = expert_idx[0]

            if expert_idx == self.num_experts:
                continue

            top_k_pos, token_idx = torch.where(expert_mask[expert_idx])

            current_state = hidden_states[token_idx]

            gate, up = nn.functional.linear(
                current_state,
                self.gate_up_proj[expert_idx],
            ).chunk(2, dim=-1)

            current_hidden_states = self.act_fn(gate) * up

            current_hidden_states = nn.functional.linear(
                current_hidden_states,
                self.down_proj[expert_idx],
            )

            current_hidden_states = (
                current_hidden_states
                * top_k_weights[token_idx, top_k_pos, None]
            )

            final_hidden_states.index_add_(
                0,
                token_idx,
                current_hidden_states.to(final_hidden_states.dtype),
            )

        return final_hidden_states

ここでは、まず top_k_index をもとに、今回の入力で実際に使われるExpertだけを抽出しています。

with torch.no_grad():
    expert_mask = torch.nn.functional.one_hot(
        top_k_index,
        num_classes=self.num_experts,
    )
    expert_mask = expert_mask.permute(2, 1, 0)
    expert_hit = torch.greater(
        expert_mask.sum(dim=(-1, -2)),
        0,
    ).nonzero()

top_k_index は、各トークンに対して選ばれたExpert IDを表します。これをone-hot化し、Expert方向を先頭に並べ替えることで、各Expertがどのトークンに選ばれたのかを扱いやすくしています。

その後、expert_hit によって、今回の入力で1回以上選ばれたExpertだけを取り出します。つまり、256個すべてのExpertを毎回計算するのではなく、実際に使われるExpertだけを処理対象にしています。

次に、Expertごとにループし、そのExpertに割り当てられたトークンを取り出します。

top_k_pos, token_idx = torch.where(expert_mask[expert_idx])

current_state = hidden_states[token_idx]

token_idx は、現在のExpertに割り当てられたトークンの位置を表します。
そのため、hidden_states[token_idx] によって、現在のExpertが処理すべきトークンの特徴量だけを取り出しています。

Expert内部のMLP計算では、まず gate_up_proj によって線形変換を行います。

gate, up = nn.functional.linear(
    current_state,
    self.gate_up_proj[expert_idx],
).chunk(2, dim=-1)

gate_up_proj はExpertごとに異なる重みを持っています。
Qwen3.5-35B-A3Bでは、num_experts = 256hidden_dim = 2048intermediate_dim = 512 なので、gate_up_proj の形は次のようになります。

[num_experts, 2 * intermediate_dim, hidden_dim]
= [256, 1024, 2048]

実際には、self.gate_up_proj[expert_idx] によって、現在のExpertに対応する重みだけを取り出して使います。

self.gate_up_proj[expert_idx]: [1024, 2048]

current_state の形を [n_tokens_for_expert, hidden_dim] とすると、線形変換後の出力は次の形になります。

current_state:
  [n_tokens_for_expert, 2048]

linear output:
  [n_tokens_for_expert, 1024]

この出力を .chunk(2, dim=-1) で2分割し、gateup に分けています。

gate: [n_tokens_for_expert, 512]
up:   [n_tokens_for_expert, 512]

この分岐と再合流の流れを図にすると、次のようになります。

ここで行われている中心的な処理は、次の部分です。

current_hidden_states = self.act_fn(gate) * up

これは、通常の単純なMLPではなく、SwiGLUに近いゲート付きMLP構造です。
gate 側に活性化関数をかけたものが、up 側の特徴量をどの程度通すかを制御します。

gate:
  どの成分をどの程度通すかを制御する

up:
  変換後の特徴量本体

act_fn(gate) * up:
  gateによって制御された中間表現

その後、down_proj によって中間表現を再び hidden_dim に戻します。

current_hidden_states = nn.functional.linear(
    current_hidden_states,
    self.down_proj[expert_idx],
)

down_proj もExpertごとに異なる重みを持っています。
形は次のようになります。

[num_experts, hidden_dim, intermediate_dim]

Qwen3.5-35B-A3Bでは、具体的には次の形です。

[256, 2048, 512]

現在のExpertに対応する重みだけを見ると、次の形になります。

self.down_proj[expert_idx]: [2048, 512]

current_hidden_states[n_tokens_for_expert, 512] なので、down_proj を通すことで次のように戻ります。

[n_tokens_for_expert, 512]
  ↓
[n_tokens_for_expert, 2048]

ここまでで、現在のExpertによる出力が得られます。

次に、この出力にRouterで計算された重みを掛けます。

current_hidden_states = (
    current_hidden_states
    * top_k_weights[token_idx, top_k_pos, None]
)

top_k_weights は、選ばれたExpertの出力をどの程度反映するかを表す重みです。各トークンはTop-k個のExpertに通されるため、それぞれのExpert出力に対応する重みを掛けてから集約します。

最後に、index_add_ によって、対応するトークン位置へ出力を加算します。

final_hidden_states.index_add_(
    0,
    token_idx,
    current_hidden_states.to(final_hidden_states.dtype),
)

同じトークンは複数のExpertに割り当てられるため、各Expertから得られた出力を同じトークン位置に足し合わせる必要があります。index_add_ は、そのための加算処理です。

つまり、Qwen3_5MoeExperts では、Routerで選ばれたExpertごとに対象トークンを集め、そのExpert専用のゲート付きMLPで変換し、Routerの重みを掛けたうえで、最終的な出力テンソルに加算しています。

Shared Expertによる共通経路と出力の集約

Qwen3.5 MoEでは、256個のExpertの中から各トークンごとにTop-k個のExpertを選択します。一方で、それとは別に、すべてのトークンが必ず通る共通経路として shared_expert も用意されています。

該当する処理は、Qwen3_5MoeSparseMoeBlockforward 内の次の部分です。

hidden_states_reshaped = hidden_states.view(-1, hidden_dim)

shared_expert_output = self.shared_expert(hidden_states_reshaped)

shared_expert_output = (
    F.sigmoid(self.shared_expert_gate(hidden_states_reshaped))
    * shared_expert_output
)

expert_output = expert_output + shared_expert_output

まず、Attentionブロックを通った hidden_states は、MoEブロック内で [batch_size × sequence_length, hidden_dim] の形に変形されます。

hidden_states_reshaped = hidden_states.view(-1, hidden_dim)

その後、この hidden_states_reshapedshared_expert に入力されます。

shared_expert_output = self.shared_expert(hidden_states_reshaped)

shared_expert は、Routerによって選択される通常のExpertとは異なり、すべてのトークンに共通して適用されるMLPです。つまり、Top-k RouterでどのExpertが選ばれたかに関係なく、各トークンは必ずこの共通経路を通ります。

さらに、shared expertの出力には shared_expert_gate によるゲートが掛けられます。

shared_expert_output = (
    F.sigmoid(self.shared_expert_gate(hidden_states_reshaped))
    * shared_expert_output
)

ここでは、shared_expert_gate によって各トークンごとにスカラー値を計算し、sigmoid を通してから shared_expert_output に掛けています。これにより、shared expertの出力をどの程度反映するかをトークンごとに調整しています。

最後に、Routerで選ばれたExpertによる出力 expert_output と、共通経路である shared_expert_output を足し合わせます。

expert_output = expert_output + shared_expert_output

つまり、Qwen3_5MoeSparseMoeBlock の最終出力は、次の2つを合成したものです。

1. Routerで選ばれたTop-k Expertの出力
2. すべてのトークンが通るshared expertの出力

このように、Qwen3.5 MoEでは、トークンごとに選択されるSparseなExpert経路に加えて、全トークン共通のshared expert経路も組み合わせています。これにより、Expertごとの専門的な変換だけでなく、全トークンに共通する変換も同時に利用できる構造になっています。

最終的に得られたMoEブロックの出力は、DecoderLayer内で残差接続された後、次のDecoderLayer、または最終層であれば Final RMSNormLM Head に渡されます。そして、LM Head によって語彙数ぶんのlogitsに変換され、次トークン予測に使われます。

【補足】補助損失 aux_loss

MoEでは、各トークンごとに使用するExpertをRouterが選択します。しかし、学習中にRouterの出力が偏ると、一部のExpertばかりが使われ、他のExpertがほとんど使われない状態になる可能性があります。

このような偏りを抑えるために、Qwen3.5 MoEでは補助損失として aux_loss が計算されます。これは、通常の言語モデリングlossとは別に、Expertの使用が特定のExpertに集中しすぎないようにするためのlossです。

該当する処理は、Qwen3_5MoeForCausalLMforward 内にあります。

aux_loss = None
if output_router_logits:
    aux_loss = load_balancing_loss_func(
        outputs.router_logits,
        self.num_experts,
        self.num_experts_per_tok,
        attention_mask,
    )
    if labels is not None:
        loss += self.router_aux_loss_coef * aux_loss.to(loss.device)

output_router_logits=True の場合、各MoE層から出力された router_logits を使って load_balancing_loss_func が呼び出されます。
そして、labels が与えられている学習時には、通常の言語モデリングlossに対して、係数 router_aux_loss_coef を掛けた aux_loss が加えられます。

最終的なloss
= 言語モデリングloss + router_aux_loss_coef × aux_loss

load_balancing_loss_func の大まかな処理は次のとおりです。

routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1)

_, selected_experts = torch.topk(routing_weights, top_k, dim=-1)

expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts)

まず、各層の router_logits を結合し、softmaxによってExpert方向の確率分布 routing_weights を計算します。
その後、Top-kによって実際に選ばれるExpertを取得し、one-hot化します。

ここで重要なのは、次の2つの値です。

tokens_per_expert = torch.mean(expert_mask.float(), dim=0)

router_prob_per_expert = torch.mean(routing_weights, dim=0)

tokens_per_expert は、Top-kによって実際に各Expertがどれくらい選ばれたかを表します。
つまり、Routerの確率分布から実際に採用されたExpertの使用率です。

一方、router_prob_per_expert は、Top-kでExpertを選ぶ前のRouter確率を平均したものです。
これは、Routerが各Expertに平均的にどれくらいの確率を割り当てていたかを表します。

整理すると、次のようになります。

tokens_per_expert
  = Top-k後の実際のExpert使用率

router_prob_per_expert
  = Top-k前のRouter確率の平均

最後に、この2つを掛け合わせて合計します。

overall_loss = torch.sum(
    tokens_per_expert * router_prob_per_expert.unsqueeze(0)
)

return overall_loss * num_experts

この式では、実際によく選ばれているExpertであり、かつRouterが高い確率を割り当てているExpertがあると、そのExpertに対応する積が大きくなります。
つまり、特定のExpertにルーティングが集中しているほど aux_loss が大きくなります。

直感的には、次のようなExpertの偏りを罰しています。

Expert 0ばかりが選ばれる
RouterもExpert 0に高い確率を出し続ける
→ aux_lossが大きくなる

逆に、Expertが比較的均等に使われていれば、特定のExpertだけに大きな値が集中しにくくなるため、aux_loss は小さくなります。

なお、attention_mask が与えられている場合は、padding tokenを除外して tokens_per_expertrouter_prob_per_expert を計算します。padding部分のRouter結果までExpert使用率に含めてしまうと、実際の有効トークンに対するExpert使用状況を正しく評価できないためです。

このように、aux_loss はMoEモデルにおいて、Expertが偏って使われることを防ぐための補助的な正則化項として機能します。

まとめ

本記事では、Qwen3.5 MoEの実装を、Qwen3_5MoeForCausalLM から Qwen3_5MoeSparseMoeBlock まで順に追いながら、MoEがソースコード上でどのように動いているのかを確認しました。

Qwen3.5 MoEでは、DecoderLayerの後半に Qwen3_5MoeSparseMoeBlock が配置されており、通常のDenseなFFNの代わりに、Routerで選択されたExpertによる変換が行われます。

Sparse MoE Blockの中心となる処理は、次の3つです。

Qwen3_5MoeTopKRouter
  → 各トークンに対して使用するExpertを選択する

Qwen3_5MoeExperts
  → 選択されたExpertで実際に特徴量を変換する

shared_expert
  → すべてのトークンが通る共通経路を提供する

つまり、Qwen3.5 MoEのSparse MoE Blockは、単に複数のMLPを並べた構造ではなく、RouterによるExpert選択、Expertごとの計算、shared expertによる共通経路を組み合わせた構造になっています。

また、学習時にはExpertの使用が偏りすぎないように、補助的に aux_loss も使われます。

この流れを押さえることで、「入力トークンごとにExpertを選び、必要なExpertだけを使って計算する」というMoEの仕組みを、Qwen3.5の実装と対応づけて理解しやすくなります。

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?