4
3

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

Qwen3にVLD(Virtual Logical Depth:仮想論理深度)を実装する

Last updated at Posted at 2025-11-10

初めに

松尾研LLM開発コンペ2025に参加した際、チームの取り組みの1つとしてVLD(Virtual Logical Depth:仮想論理深度) をQwen3に実装し検証を行いました。
ここではその知見共有が目的となります。

VLDとは

論文↓

ざっくり、
VLD(Virtual Logical Depth:仮想論理深度) とは、パラメータ数を増やさずに重みを再利用することで、モデルの効果的なアルゴリズム深度を増加させる新しいスケーリング手法です。

この論文では、大規模言語モデルのスケーリングにおける「第4の次元」として提案されています(従来の3次元は深さ、幅、パラメータ数)。

主な特徴として以下があげられます。:

  • 固定されたパラメータ数でVLDを適用すると、知識容量はほぼ一定に保たれる一方で、推論能力が大幅に向上

つまりVLDは、
モデルを大きくせずに推論性能を向上させる効率的なスケーリング戦略として提案されています。

また、パラメータの再利用方法には何パターンかあり、この中でもCYCLEパターンが最も性能の向上が見られているそうです。

image.png

コードの全体像:VLD (cycleパターン) を実装したQwen3言語モデルについて

以下のコードは、Qwen3というTransformerベースの大規模言語モデルを改造し、仮想論理深度(Virtual Logical Depth: VLD) という技術を実装したものです。

Qwen3 VLD コード
from transformers import (
    Qwen3ForCausalLM,
    Qwen3Model,
    Qwen3Config,
    Qwen3PreTrainedModel
)
from transformers.models.qwen3.modeling_qwen3 import (
    apply_rotary_pos_emb, 
    eager_attention_forward,
    Qwen3Attention,
    Qwen3DecoderLayer
)
import torch
from torch import nn
from typing import Optional,Unpack, Callable
from transformers.cache_utils import Cache, DynamicCache
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
from transformers.masking_utils import create_causal_mask, create_sliding_window_causal_mask
from transformers.utils import TransformersKwargs
from transformers.modeling_outputs import (
    BaseModelOutputWithPast,
)

class VLDCycleQwen3Attention(Qwen3Attention):
    def __init__(
        self, 
        config: Qwen3Config, 
        layer_idx: int,
    ):
        super().__init__(config, layer_idx)

    def forward(
        self,
        hidden_states: torch.Tensor,
        current_share_layer_idx: int,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor],
        past_key_value: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
        
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)

        num_physical_layers = self.config.num_hidden_layers 

        query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
        key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_value is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx+num_physical_layers*current_share_layer_idx, cache_kwargs)

        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            sliding_window=self.sliding_window,  # diff with Llama
            **kwargs,
        )

        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        attn_output = self.o_proj(attn_output)

        return attn_output, attn_weights
    

class VLDCycleQwen3DecoderLayer(Qwen3DecoderLayer):
    def __init__(self, config: Qwen3Config, layer_idx: int):
        super().__init__(config, layer_idx)
        self.self_attn = VLDCycleQwen3Attention(config=config, layer_idx=layer_idx)

    def forward(
        self,
        hidden_states: torch.Tensor,
        current_share_layer_idx: int,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[torch.Tensor]:
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        # Self Attention
        hidden_states, _ = self.self_attn(
            hidden_states=hidden_states,
            current_share_layer_idx=current_share_layer_idx,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            use_cache=use_cache,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
            **kwargs,
        )
        hidden_states = residual + hidden_states

        # Fully Connected
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states
        return hidden_states

class VLDCycleQwen3Model(Qwen3Model):
    def __init__(
        self, 
        config: Qwen3Config,
        num_share_layers: int
    ):
        super().__init__(config)
        self.num_share_layers = num_share_layers
        self.layers = nn.ModuleList(
            [VLDCycleQwen3DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        ) 

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> BaseModelOutputWithPast:
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        if use_cache and past_key_values is None:
            past_key_values = DynamicCache()

        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )

        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        # It may already have been prepared by e.g. `generate`
        if not isinstance(causal_mask_mapping := attention_mask, dict):
            # Prepare mask arguments
            mask_kwargs = {
                "config": self.config,
                "input_embeds": inputs_embeds,
                "attention_mask": attention_mask,
                "cache_position": cache_position,
                "past_key_values": past_key_values,
                "position_ids": position_ids,
            }
            # Create the masks
            causal_mask_mapping = {
                "full_attention": create_causal_mask(**mask_kwargs),
            }
            # The sliding window alternating layers are not always activated depending on the config
            if self.has_sliding_layers:
                causal_mask_mapping["sliding_attention"] = create_sliding_window_causal_mask(**mask_kwargs)

        hidden_states = inputs_embeds

        # create position embeddings to be shared across the decoder layers
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        for current_share_layer_idx in range(self.num_share_layers):
            for decoder_layer in self.layers[: self.config.num_hidden_layers]:
                hidden_states = decoder_layer(
                    hidden_states,
                    current_share_layer_idx=current_share_layer_idx,
                    attention_mask=causal_mask_mapping[decoder_layer.attention_type],
                    position_ids=position_ids,
                    past_key_value=past_key_values,
                    use_cache=use_cache,
                    cache_position=cache_position,
                    position_embeddings=position_embeddings,
                    **kwargs,
                )

        hidden_states = self.norm(hidden_states)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values if use_cache else None,
        )
    
class VLDCycleQwen3ForCausalLM(Qwen3ForCausalLM):    
    _no_split_modules = ["VLDCycleQwen3DecoderLayer"]
    
    def __init__(
        self, 
        config,
        num_share_layers: int = 1
    ):
        Qwen3PreTrainedModel.__init__(self,config)
        self.model = VLDCycleQwen3Model(
            config,
            num_share_layers=num_share_layers
        )
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

具体的には、VLDの中でも cycle(サイクル) と呼ばれるパターンを適用しています。

この改造の目的は、「モデルが記憶できる知識の量(パラメータ数)を増やすことなく、計算の深さを増すことで、モデルの推論能力を向上させる」ことです。

以下にどのような改造が施されているかを解説していきます。

背景:TransformerアーキテクチャとQwen3モデル

このコードが何をしているかを完全に理解するためには、まずベースとなっているTransformerアーキテクチャについて理解することが重要です。

Transformerとは?

論文 → Attention Is All You Need

Transformerの基本的な構成要素は「Transformerブロック」と呼ばれる隠れ層であり、これは主に2つのサブレイヤーから構成されます。
(詳細はいろいろと調べてみてください。私もフワッとしか分かっておりません故…)

image (4).png

上図において、左半分がエンコーダ、右半分がデコーダ となっています。

  1. マルチヘッド・セルフアテンション (Multi-Head Self-Attention):
    文中の各単語が、他のすべての単語にどれだけ「注意」を向けるべきかを計算します。これにより、単語は文脈に応じた豊かな表現を獲得します。
  2. 位置ごとのフィードフォワードネットワーク (Position-wise Feed-Forward Network, FFN):
    アテンション層からの出力を受け取り、各単語の表現を個別に非線形変換します。これにより、モデルの表現力が高まります。

Qwen3におけるTransformerの利用

Qwen3は、このTransformerアーキテクチャをベースに構築された、大規模言語モデルです。

特にQwen3は、GPTシリーズのように「デコーダオンリー(decoder-only)」と呼ばれる構造を採用しています。
これは、オリジナルのTransformerが持っていたエンコーダとデコーダのうち、デコーダ部分のみを積み重ねて構成されていることを意味します。

image (5).png

上図において、右半分のデコーダ のみを使用

このデコーダオンリーアーキテクチャは、主にテキスト生成タスクに特化しています。
モデルは、それまでに入力された単語のシーケンスを基に、次に来る単語を予測するという処理を繰り返します。

今回のコードで言及されている「64層の隠れ層」とは、このTransformerのデコーダブロックが64個積み重なっていることを指します。

各デコーダブロックには、前述のアテンション層とFFNが含まれますが、テキスト生成のためにアテンション層は「マスク付き(Masked)」になっており、未来の単語を参照できないようになっています。

このコードは、このQwen3の基本的な構造、特に各デコーダブロックの呼び出し方と、その内部にあるアテンション層の状態管理に手を加えることで、VLDという新しい技術を実装しています。

VLD cycleパターンの仕組み

元のQwen3-32Bの隠れ層が64層の場合を例に説明します。

cycleパターンでは、その物理的な64層(0層目から63層目まで)全体を一つの巨大なブロックとして扱います。そして、そのブロック全体をループで複数回実行します。
例えば、繰り返し回数を1回(合計2周)に設定すると、モデルの実効的な深さは

$64 * 2 = 128$
層になります。

計算の流れは以下の通りです。

  1. 1周目: 入力データが、物理的な0層目から63層目までを順番に通過します。
  2. 2周目: 63層目から出てきた出力が、再び0層目のパラメータを使って処理され、1層目、2層目...と、もう一度63層目までを通過します。

この仕組みにより、物理的な0層目のパラメータは、実効的な64層目の計算でも再利用されることになります。
これが「パラメータ共有」です。

コードの解説:何が、どこで、なぜ行われているのか?

この実装は、大きく分けて2つのレベルで標準のTransformerアーキテクチャに変更を加えています。

1. (マクロな変更)デコーダ層全体の「サイクル実行」

image (6).png

VLD cycleパターンの心臓部です。
Qwen3-32Bでは隠れ層が64層(ベース層)なので、上図の赤枠のデコーダブロックが64枚重なっている状態となります。

VLDではこのベース層を使いまわして、仮想的に層を深めていきますが各層のパラメータは使いまわすようにしています。
該当コード: VLDCycleQwen3Model クラスの forward(順伝搬) メソッド

class VLDCycleQwen3Model(Qwen3Model):
    #...
    def forward(self,...):
        #...
        # (A) VLDのサイクル(繰り返し)を制御する外側のループ
        for current_share_layer_idx in range(self.num_share_layers):
            # (B) 物理的なベース層(64層)を1サイクル分実行する内側のループ
            for decoder_layer in self.layers[: self.config.num_hidden_layers]:
                hidden_states = decoder_layer(
                    hidden_states,
                    current_share_layer_idx=current_share_layer_idx,
                    #...
                )
        #...

  • (A) 外側のループ: num_share_layers で指定された回数だけ、64層のブロック全体を繰り返します。これがVLDの「繰り返し係数」に相当します。
  • (B) 内側のループ: 物理的に存在する64層のデコーダ層(self.layers)を順番に実行します。外側のループが一周するたびに、全く同じパラメータを持つ self.layers が再利用されます。

2. (ミクロな変更)デコーダ層内部のマルチヘッドアテンション機構の「状態管理」

image (7).png

しかし、単純に層を繰り返すだけでは、特にテキスト生成(推論)時に問題が発生します。
例えばまともな文章を生成できなくなるでしょう。

これはKVキャッシュまでも1週目と2週目で使いまわしてしまっていることが原因です。
KVキャッシュとは入力・生成されたトークンのキー(Key)とバリュー(Value)のベクトルを保存しておく一時的な記憶領域です。

これを使いまわすということは1週目の値を2週目の値で上書きされてしまい、「文脈情報が破壊され、時間的な順序が崩壊する」 ことになります。

そこで、マルチヘッドアテンション機構の内部に、非常に重要な変更を加える必要があります。
該当コード: VLDCycleQwen3Attention クラスの forward メソッド

class VLDCycleQwen3Attention(Qwen3Attention):
    #...
    def forward(self,..., current_share_layer_idx: int,...):
        #...
        if past_key_value is not None:
            #...
            # (C) KVキャッシュの保存先インデックスを動的に計算
            key_states, value_states = past_key_value.update(
                key_states,
                value_states,
                self.layer_idx + num_physical_layers * current_share_layer_idx,
                cache_kwargs
            )
        #...

ここで、2週目以降の仮想隠れ層のKVキャッシュは1週目とは別場所(インデックス)に保存するようにしています。

「パラメータ共有」と「KVキャッシュ」の違い

ここで、以前の会話で触れた重要な区別を整理します。

  • パラメータ: モデルの学習済み知識である重みとバイアス
    → VLDではこれは共有する
  • KVキャッシュ: テキスト生成を高速化するために、アテンション計算の中間結果(KeyとValueベクトル) を一時的に保存するメモリ領域
    → VLDでもこれは共有しない

変更されていない点:フィードフォワードネットワーク

Transformerのデコーダ隠れ層は、主に「Masked Multi-Head Self-Attention」と「Position-wise Feed-Forward Network (FFN)」で構成されています(decoder-onlyモデルの場合)。
このコードでは、FFN(コード中ではself.mlp)の内部ロジックには一切手が加えられていません
FFNのパラメータ自体はマクロなループによって再利用されますが、KVキャッシュのような状態管理が不要なため、特別な変更は必要ないのです。

まとめ

今回のコードは、以下のことを行っています。

  1. 目的: Qwen3モデルのパラメータ数を増やすことなく、推論能力を向上させる。
  2. 手法: 仮想論理深度(VLD)のcycleパターンを実装する。
  3. 実装(マクロ): VLDCycleQwen3Modelクラスで、物理的な64層のデコーダ層全体を一つのブロックとみなし、ループ処理で複数回実行する。これにより、パラメータ(重み・バイアス) が共有される。
  4. 実装(ミクロ): VLDCycleQwen3Attentionクラスで、推論時に各計算ステップの状態を正しく区別するため、サイクル数に応じてKVキャッシュの保存先インデックスを動的に計算する。

以上、論文で提案されたVLDのコンセプトを、実際のTransformerモデル(Qwen3)に適用した実装となります。


本プロジェクトは、国立研究開発法人新エネルギー・産業技術総合開発機構(以下「NEDO」)の「日本語版医療特化型LLMの社会実装に向けた安全性検証・実証」における基盤モデルの開発プロジェクトの一環として行われます。

4
3
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
4
3

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?