初めに
松尾研LLM開発コンペ2025に参加した際、チームの取り組みの1つとしてVLD(Virtual Logical Depth:仮想論理深度) をQwen3に実装し検証を行いました。
ここではその知見共有が目的となります。
VLDとは
論文↓
ざっくり、
VLD(Virtual Logical Depth:仮想論理深度) とは、パラメータ数を増やさずに重みを再利用することで、モデルの効果的なアルゴリズム深度を増加させる新しいスケーリング手法です。
この論文では、大規模言語モデルのスケーリングにおける「第4の次元」として提案されています(従来の3次元は深さ、幅、パラメータ数)。
主な特徴として以下があげられます。:
- 固定されたパラメータ数でVLDを適用すると、知識容量はほぼ一定に保たれる一方で、推論能力が大幅に向上
つまりVLDは、
モデルを大きくせずに推論性能を向上させる効率的なスケーリング戦略として提案されています。
また、パラメータの再利用方法には何パターンかあり、この中でもCYCLEパターンが最も性能の向上が見られているそうです。
コードの全体像:VLD (cycleパターン) を実装したQwen3言語モデルについて
以下のコードは、Qwen3というTransformerベースの大規模言語モデルを改造し、仮想論理深度(Virtual Logical Depth: VLD) という技術を実装したものです。
Qwen3 VLD コード
from transformers import (
Qwen3ForCausalLM,
Qwen3Model,
Qwen3Config,
Qwen3PreTrainedModel
)
from transformers.models.qwen3.modeling_qwen3 import (
apply_rotary_pos_emb,
eager_attention_forward,
Qwen3Attention,
Qwen3DecoderLayer
)
import torch
from torch import nn
from typing import Optional,Unpack, Callable
from transformers.cache_utils import Cache, DynamicCache
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
from transformers.masking_utils import create_causal_mask, create_sliding_window_causal_mask
from transformers.utils import TransformersKwargs
from transformers.modeling_outputs import (
BaseModelOutputWithPast,
)
class VLDCycleQwen3Attention(Qwen3Attention):
def __init__(
self,
config: Qwen3Config,
layer_idx: int,
):
super().__init__(config, layer_idx)
def forward(
self,
hidden_states: torch.Tensor,
current_share_layer_idx: int,
position_embeddings: tuple[torch.Tensor, torch.Tensor],
attention_mask: Optional[torch.Tensor],
past_key_value: Optional[Cache] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs: Unpack[FlashAttentionKwargs],
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, -1, self.head_dim)
num_physical_layers = self.config.num_hidden_layers
query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
if past_key_value is not None:
# sin and cos are specific to RoPE models; cache_position needed for the static cache
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx+num_physical_layers*current_share_layer_idx, cache_kwargs)
attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
attn_output, attn_weights = attention_interface(
self,
query_states,
key_states,
value_states,
attention_mask,
dropout=0.0 if not self.training else self.attention_dropout,
scaling=self.scaling,
sliding_window=self.sliding_window, # diff with Llama
**kwargs,
)
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
attn_output = self.o_proj(attn_output)
return attn_output, attn_weights
class VLDCycleQwen3DecoderLayer(Qwen3DecoderLayer):
def __init__(self, config: Qwen3Config, layer_idx: int):
super().__init__(config, layer_idx)
self.self_attn = VLDCycleQwen3Attention(config=config, layer_idx=layer_idx)
def forward(
self,
hidden_states: torch.Tensor,
current_share_layer_idx: int,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
**kwargs: Unpack[TransformersKwargs],
) -> tuple[torch.Tensor]:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
# Self Attention
hidden_states, _ = self.self_attn(
hidden_states=hidden_states,
current_share_layer_idx=current_share_layer_idx,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
**kwargs,
)
hidden_states = residual + hidden_states
# Fully Connected
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
return hidden_states
class VLDCycleQwen3Model(Qwen3Model):
def __init__(
self,
config: Qwen3Config,
num_share_layers: int
):
super().__init__(config)
self.num_share_layers = num_share_layers
self.layers = nn.ModuleList(
[VLDCycleQwen3DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Cache] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs: Unpack[TransformersKwargs],
) -> BaseModelOutputWithPast:
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
if use_cache and past_key_values is None:
past_key_values = DynamicCache()
if cache_position is None:
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
cache_position = torch.arange(
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
)
if position_ids is None:
position_ids = cache_position.unsqueeze(0)
# It may already have been prepared by e.g. `generate`
if not isinstance(causal_mask_mapping := attention_mask, dict):
# Prepare mask arguments
mask_kwargs = {
"config": self.config,
"input_embeds": inputs_embeds,
"attention_mask": attention_mask,
"cache_position": cache_position,
"past_key_values": past_key_values,
"position_ids": position_ids,
}
# Create the masks
causal_mask_mapping = {
"full_attention": create_causal_mask(**mask_kwargs),
}
# The sliding window alternating layers are not always activated depending on the config
if self.has_sliding_layers:
causal_mask_mapping["sliding_attention"] = create_sliding_window_causal_mask(**mask_kwargs)
hidden_states = inputs_embeds
# create position embeddings to be shared across the decoder layers
position_embeddings = self.rotary_emb(hidden_states, position_ids)
for current_share_layer_idx in range(self.num_share_layers):
for decoder_layer in self.layers[: self.config.num_hidden_layers]:
hidden_states = decoder_layer(
hidden_states,
current_share_layer_idx=current_share_layer_idx,
attention_mask=causal_mask_mapping[decoder_layer.attention_type],
position_ids=position_ids,
past_key_value=past_key_values,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
**kwargs,
)
hidden_states = self.norm(hidden_states)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=past_key_values if use_cache else None,
)
class VLDCycleQwen3ForCausalLM(Qwen3ForCausalLM):
_no_split_modules = ["VLDCycleQwen3DecoderLayer"]
def __init__(
self,
config,
num_share_layers: int = 1
):
Qwen3PreTrainedModel.__init__(self,config)
self.model = VLDCycleQwen3Model(
config,
num_share_layers=num_share_layers
)
self.vocab_size = config.vocab_size
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
# Initialize weights and apply final processing
self.post_init()
具体的には、VLDの中でも cycle(サイクル) と呼ばれるパターンを適用しています。
この改造の目的は、「モデルが記憶できる知識の量(パラメータ数)を増やすことなく、計算の深さを増すことで、モデルの推論能力を向上させる」ことです。
以下にどのような改造が施されているかを解説していきます。
背景:TransformerアーキテクチャとQwen3モデル
このコードが何をしているかを完全に理解するためには、まずベースとなっているTransformerアーキテクチャについて理解することが重要です。
Transformerとは?
論文 → Attention Is All You Need
Transformerの基本的な構成要素は「Transformerブロック」と呼ばれる隠れ層であり、これは主に2つのサブレイヤーから構成されます。
(詳細はいろいろと調べてみてください。私もフワッとしか分かっておりません故…)
上図において、左半分がエンコーダ、右半分がデコーダ となっています。
-
マルチヘッド・セルフアテンション (Multi-Head Self-Attention):
文中の各単語が、他のすべての単語にどれだけ「注意」を向けるべきかを計算します。これにより、単語は文脈に応じた豊かな表現を獲得します。 -
位置ごとのフィードフォワードネットワーク (Position-wise Feed-Forward Network, FFN):
アテンション層からの出力を受け取り、各単語の表現を個別に非線形変換します。これにより、モデルの表現力が高まります。
Qwen3におけるTransformerの利用
Qwen3は、このTransformerアーキテクチャをベースに構築された、大規模言語モデルです。
特にQwen3は、GPTシリーズのように「デコーダオンリー(decoder-only)」と呼ばれる構造を採用しています。
これは、オリジナルのTransformerが持っていたエンコーダとデコーダのうち、デコーダ部分のみを積み重ねて構成されていることを意味します。
上図において、右半分のデコーダ のみを使用
このデコーダオンリーアーキテクチャは、主にテキスト生成タスクに特化しています。
モデルは、それまでに入力された単語のシーケンスを基に、次に来る単語を予測するという処理を繰り返します。
今回のコードで言及されている「64層の隠れ層」とは、このTransformerのデコーダブロックが64個積み重なっていることを指します。
各デコーダブロックには、前述のアテンション層とFFNが含まれますが、テキスト生成のためにアテンション層は「マスク付き(Masked)」になっており、未来の単語を参照できないようになっています。
このコードは、このQwen3の基本的な構造、特に各デコーダブロックの呼び出し方と、その内部にあるアテンション層の状態管理に手を加えることで、VLDという新しい技術を実装しています。
VLD cycleパターンの仕組み
元のQwen3-32Bの隠れ層が64層の場合を例に説明します。
cycleパターンでは、その物理的な64層(0層目から63層目まで)全体を一つの巨大なブロックとして扱います。そして、そのブロック全体をループで複数回実行します。
例えば、繰り返し回数を1回(合計2周)に設定すると、モデルの実効的な深さは
$64 * 2 = 128$
層になります。
計算の流れは以下の通りです。
- 1周目: 入力データが、物理的な0層目から63層目までを順番に通過します。
- 2周目: 63層目から出てきた出力が、再び0層目のパラメータを使って処理され、1層目、2層目...と、もう一度63層目までを通過します。
この仕組みにより、物理的な0層目のパラメータは、実効的な64層目の計算でも再利用されることになります。
これが「パラメータ共有」です。
コードの解説:何が、どこで、なぜ行われているのか?
この実装は、大きく分けて2つのレベルで標準のTransformerアーキテクチャに変更を加えています。
1. (マクロな変更)デコーダ層全体の「サイクル実行」
VLD cycleパターンの心臓部です。
Qwen3-32Bでは隠れ層が64層(ベース層)なので、上図の赤枠のデコーダブロックが64枚重なっている状態となります。
VLDではこのベース層を使いまわして、仮想的に層を深めていきますが各層のパラメータは使いまわすようにしています。
該当コード: VLDCycleQwen3Model クラスの forward(順伝搬) メソッド
class VLDCycleQwen3Model(Qwen3Model):
#...
def forward(self,...):
#...
# (A) VLDのサイクル(繰り返し)を制御する外側のループ
for current_share_layer_idx in range(self.num_share_layers):
# (B) 物理的なベース層(64層)を1サイクル分実行する内側のループ
for decoder_layer in self.layers[: self.config.num_hidden_layers]:
hidden_states = decoder_layer(
hidden_states,
current_share_layer_idx=current_share_layer_idx,
#...
)
#...
-
(A) 外側のループ:
num_share_layersで指定された回数だけ、64層のブロック全体を繰り返します。これがVLDの「繰り返し係数」に相当します。 -
(B) 内側のループ: 物理的に存在する64層のデコーダ層(
self.layers)を順番に実行します。外側のループが一周するたびに、全く同じパラメータを持つself.layersが再利用されます。
2. (ミクロな変更)デコーダ層内部のマルチヘッドアテンション機構の「状態管理」
しかし、単純に層を繰り返すだけでは、特にテキスト生成(推論)時に問題が発生します。
例えばまともな文章を生成できなくなるでしょう。
これはKVキャッシュまでも1週目と2週目で使いまわしてしまっていることが原因です。
KVキャッシュとは入力・生成されたトークンのキー(Key)とバリュー(Value)のベクトルを保存しておく一時的な記憶領域です。
これを使いまわすということは1週目の値を2週目の値で上書きされてしまい、「文脈情報が破壊され、時間的な順序が崩壊する」 ことになります。
そこで、マルチヘッドアテンション機構の内部に、非常に重要な変更を加える必要があります。
該当コード: VLDCycleQwen3Attention クラスの forward メソッド
class VLDCycleQwen3Attention(Qwen3Attention):
#...
def forward(self,..., current_share_layer_idx: int,...):
#...
if past_key_value is not None:
#...
# (C) KVキャッシュの保存先インデックスを動的に計算
key_states, value_states = past_key_value.update(
key_states,
value_states,
self.layer_idx + num_physical_layers * current_share_layer_idx,
cache_kwargs
)
#...
ここで、2週目以降の仮想隠れ層のKVキャッシュは1週目とは別場所(インデックス)に保存するようにしています。
「パラメータ共有」と「KVキャッシュ」の違い
ここで、以前の会話で触れた重要な区別を整理します。
-
パラメータ: モデルの学習済み知識である重みとバイアス
→ VLDではこれは共有する -
KVキャッシュ: テキスト生成を高速化するために、アテンション計算の中間結果(KeyとValueベクトル) を一時的に保存するメモリ領域
→ VLDでもこれは共有しない
変更されていない点:フィードフォワードネットワーク
Transformerのデコーダ隠れ層は、主に「Masked Multi-Head Self-Attention」と「Position-wise Feed-Forward Network (FFN)」で構成されています(decoder-onlyモデルの場合)。
このコードでは、FFN(コード中ではself.mlp)の内部ロジックには一切手が加えられていません。
FFNのパラメータ自体はマクロなループによって再利用されますが、KVキャッシュのような状態管理が不要なため、特別な変更は必要ないのです。
まとめ
今回のコードは、以下のことを行っています。
- 目的: Qwen3モデルのパラメータ数を増やすことなく、推論能力を向上させる。
-
手法: 仮想論理深度(VLD)の
cycleパターンを実装する。 -
実装(マクロ):
VLDCycleQwen3Modelクラスで、物理的な64層のデコーダ層全体を一つのブロックとみなし、ループ処理で複数回実行する。これにより、パラメータ(重み・バイアス) が共有される。 -
実装(ミクロ):
VLDCycleQwen3Attentionクラスで、推論時に各計算ステップの状態を正しく区別するため、サイクル数に応じてKVキャッシュの保存先インデックスを動的に計算する。
以上、論文で提案されたVLDのコンセプトを、実際のTransformerモデル(Qwen3)に適用した実装となります。
本プロジェクトは、国立研究開発法人新エネルギー・産業技術総合開発機構(以下「NEDO」)の「日本語版医療特化型LLMの社会実装に向けた安全性検証・実証」における基盤モデルの開発プロジェクトの一環として行われます。




