3
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

小型LLM同士を"橋渡し"して性能向上!Gated Cross-Attention Bridgeの実装と効果

Posted at

はじめに

「2つの小型LLMを組み合わせたら、お互いの強みを活かせるのでは?」

この記事では、Qwen2.5-3BGemma-2-2BGated Cross-Attentionで接続し、両者の長所を組み合わせる手法を紹介します。学習不要で、推論時のみ動作するため、すぐに試せます。

TL;DR

  • 🎯 Qwenの生成力 × Gemmaの知識 = より賢い出力
  • 🔧 学習不要、推論時のみBridgeモジュールを挿入
  • 📊 ベイズの定理の説明タスクで明確な改善を確認
  • ⚡ SDPA対応でA100で快適に動作

何が問題だったのか?

小型LLMには、それぞれ得意・不得意があります:

モデル 強み 弱み
Qwen-3B 簡潔で自然な文章生成 知識に誤りがある場合も
Gemma-2B 正確な知識、豊富な説明 冗長で構成が弱い

「片方の知識をもう片方が参照できたら...?」

解決策:Gated Cross-Attention Bridge

アーキテクチャ概要

ユーザー質問
    ↓
Gemma(理解)→ メモリ化(256トークン)
    ↓
Qwen(生成)← メモリを参照しながら出力
    ↑
Gated Cross-Attention(4層に挿入)

核心:Gated Cross-Attention

class GatedCrossAttention(nn.Module):
    def __init__(self, d_model, n_heads, init_gate=0.4):
        super().__init__()
        self.nh = n_heads
        self.d = d_model // n_heads
        # Q: Qwenが「何を知りたいか」
        self.q_proj = nn.Linear(d_model, d_model, bias=False)
        # K,V: Gemmaのメモリ
        self.k_proj = nn.Linear(d_model, d_model, bias=False)
        self.v_proj = nn.Linear(d_model, d_model, bias=False)
        self.o_proj = nn.Linear(d_model, d_model, bias=False)
        # ゲート:Gemmaの影響度を調整
        self.gate = nn.Parameter(torch.full((1,), float(init_gate)))
    
    def forward(self, x, mem):
        B, T, _ = x.shape; S = mem.shape[1]
        # Qwenの状態からQuery
        q = self.q_proj(x).view(B,T,self.nh,self.d).transpose(1,2)
        # GemmaのメモリからKey, Value
        k = self.k_proj(mem).view(B,S,self.nh,self.d).transpose(1,2)
        v = self.v_proj(mem).view(B,S,self.nh,self.d).transpose(1,2)
        # Attention計算
        y = F.scaled_dot_product_attention(q, k, v, is_causal=False)
        y = y.transpose(1,2).contiguous().view(B,T,-1)
        y = self.o_proj(y)
        # 元の状態 + ゲート調整された外部知識
        return x + torch.clamp(self.gate, 0, 1) * y

ポイント

  • Qwenが「必要な情報だけ」Gemmaから取得
  • ゲートで影響度を調整(0〜1の範囲)
  • 元の情報を保持しながら知識を追加

4層への戦略的配置

# Qwenの総層数をLとして
insert_idx = {L//4, L//2, (3*L)//4, L-1}

# 上層ほどゲートを強く
init_gate = {
    insert_idx[0]: 0.30,  # 下層:控えめ
    insert_idx[1]: 0.40,
    insert_idx[2]: 0.50,
    insert_idx[3]: 0.60   # 上層:強め
}

なぜ4層?

  • 下層:基本概念の理解
  • 中層:詳細知識の統合
  • 上層:論理構造の洗練
  • 最終層:出力の最終調整

メモリの構築

@torch.no_grad()
def build_memory_with_gemma(user_text: str, max_mem_tokens: int = 256):
    # Gemmaでユーザー質問を処理
    msgs = [{"role": "user", "content": user_text}]
    prompt = gem_tok.apply_chat_template(msgs, tokenize=False, 
                                         add_generation_prompt=True)
    enc = gem_tok(prompt, return_tensors="pt").to(device)
    
    # 最終層の隠れ状態を取得
    out = gem(**enc, output_hidden_states=True)
    h = out.hidden_states[-1][:, -max_mem_tokens:]
    
    # Qwenの次元に変換して正規化
    h = proj_ext2main(h.to(dtype=dtype))
    h = mem_norm(h)
    return h

256トークンのメモリで、Gemmaの理解を十分に保持

実験結果

タスク:「ベイズの定理を日本語で一段落で説明してください」

❌ Qwen-3B(単体)

ベイズの定理は、事実が得られた後に確率を更新する方法を提供する数学的原理です。これは、ある事象Aが起こった後に、その事象Bが起こる可能性が高まるか低まるかを評価する際に有用です。公式的には、P(B|A) = P(A|B) × P(B) / P(A) と表され、P(B|A)は事象Bが事象Aが起こった条件下での確率を指し、他のパラメータはそれぞれ事象BとAが独立に起こる確率を示します

問題点:「独立に起こる確率」という致命的な誤り

△ Gemma-2B-IT(単体)

ベイズの定理は、確率論において、ある事象の発生確率を計算する際に、既知の情報と新しい情報との関係性を考慮する際に用いられる定理です。簡単に言うと、新しい情報が得られた後、その情報に基づいて事象の発生確率を更新する、という考え方です。

具体的には、ある事象Aが起こる確率をP(A)とします。この確率は、事象Aが起こる可能性のある事象Bと、事象Bが起こる確率P(B)を考慮して計算されます。そして、新しい情報Cが得られた後、事象Aが起こる確率を更新する公式は、P(A|C) = [P(C|A) * P(A)] / P(C) となります。

つまり、新しい情報Cが得られた後、事象Aが起こる確率は、事象Aが起こる可能性のある事象Bと、事象Bが起こる確率P(B)を考慮した上で、新しい情報Cが得られた確率P(C)と、事象Aが起こる確率P(A)の積で計算されます。

ベイズの定理は、様々な分野で応用されており、例えば、医学、機械学習、統計学など、様々な分野で重要な役割を果たしています。

問題点:正確だが冗長で、説明の流れが不明瞭

✅ BRIDGE(Qwen ← Gemma)

ベイズの定理は、事実が得られた後に確率を更新する方法を提供する数学的原理です。この定理は、ある事象Aが起こった後に事象Bが起こる確率を計算する際に有用です。公式はP(B|A) = P(A|B) × P(B) / P(A)と表され、ここでP(B|A)は「事象Bが起こる確率(事象Aが起こった後に)」、P(A|B)は「事象Aが起こる確率(事象Bが起こった後に)」、P(B)は「事象Bが起こる確率」と、P(A)は「事象Aが起こる確率」を意味します。ベイズの定理は、新しい情報に基づいて既存の知識を修正し、予測を改善するための重要なツールとして知られています。

改善点

  • ✅ 各項の意味を正確に説明
  • ✅ 独立性の誤りを修正
  • ✅ 簡潔で分かりやすい構成
  • ✅ 定理の本質(知識の更新)を明確に記述

実装のポイント

1. 決定的な実行(再現性の確保)

SEED = 12345
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

2. 効率的なAttention(SDPA使用)

torch.backends.cuda.sdp_kernel(
    enable_flash=True, 
    enable_math=False, 
    enable_mem_efficient=True
)

3. 動的なforward置き換え

def make_patched_forward(block, gca):
    orig = block.forward
    def patched(self, hidden_states, *a, **kw):
        out = orig(hidden_states, *a, **kw)
        hs, others = (out[0], out[1:]) if isinstance(out, tuple) else (out, ())
        mem = getattr(self, "_ext_mem", None)
        if mem is not None:
            hs = gca(hs, mem)  # メモリがあれば統合
        return (hs,) + others if others else hs
    return patched

元のモデルの重みは一切変更せず、forwardメソッドのみを置き換え

必要な環境

pip install torch transformers accelerate
  • GPU: A100推奨(V100でも動作可能)
  • VRAM: 約16GB(fp16使用時)
  • HuggingFace Tokenへのログインが必要(Gemma利用のため)

完全なコード

# ===========================================
# Qwen(3B) × Gemma2(2B-IT) - Gated Cross-Attention Bridge
# ===========================================
import torch, torch.nn as nn, torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer

# 再現性の確保
SEED = 12345
torch.manual_seed(SEED); torch.cuda.manual_seed_all(SEED)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16

# モデルのロード
qwen_id = "Qwen/Qwen2.5-3B-Instruct"
gemma_id = "google/gemma-2-2b-it"

qwen_tok = AutoTokenizer.from_pretrained(qwen_id, use_fast=True)
gem_tok = AutoTokenizer.from_pretrained(gemma_id, use_fast=True)

qwen = AutoModelForCausalLM.from_pretrained(
    qwen_id, torch_dtype=dtype, device_map="cuda", attn_implementation="sdpa"
)
gem = AutoModelForCausalLM.from_pretrained(
    gemma_id, torch_dtype=dtype, device_map="cuda", attn_implementation="sdpa"
)
qwen.eval(); gem.eval()

# Gated Cross-Attentionの定義
class GatedCrossAttention(nn.Module):
    def __init__(self, d_model, n_heads, init_gate=0.4):
        super().__init__()
        self.nh = n_heads
        self.d = d_model // n_heads
        self.q_proj = nn.Linear(d_model, d_model, bias=False)
        self.k_proj = nn.Linear(d_model, d_model, bias=False)
        self.v_proj = nn.Linear(d_model, d_model, bias=False)
        self.o_proj = nn.Linear(d_model, d_model, bias=False)
        self.gate = nn.Parameter(torch.full((1,), float(init_gate)))
    
    def forward(self, x, mem):
        B, T, _ = x.shape; S = mem.shape[1]
        q = self.q_proj(x).view(B,T,self.nh,self.d).transpose(1,2)
        k = self.k_proj(mem).view(B,S,self.nh,self.d).transpose(1,2)
        v = self.v_proj(mem).view(B,S,self.nh,self.d).transpose(1,2)
        y = F.scaled_dot_product_attention(q, k, v, is_causal=False)
        y = y.transpose(1,2).contiguous().view(B,T,-1)
        y = self.o_proj(y)
        return x + torch.clamp(self.gate, 0, 1) * y

# Bridgeモジュールの設定
d_model = qwen.config.hidden_size
d_ext = gem.config.hidden_size
n_heads = qwen.config.num_attention_heads
L = qwen.config.num_hidden_layers

proj_ext2main = nn.Linear(d_ext, d_model, bias=False).to(device, dtype=dtype)
mem_norm = nn.LayerNorm(d_model, elementwise_affine=False).to(device, dtype=dtype)

insert_idx = sorted({max(0, L//4), L//2, (3*L)//4, L-1})
gca_modules = {}

def make_patched_forward(block, gca):
    orig = block.forward
    def patched(self, hidden_states, *a, **kw):
        out = orig(hidden_states, *a, **kw)
        hs, others = (out[0], out[1:]) if isinstance(out, tuple) else (out, ())
        mem = getattr(self, "_ext_mem", None)
        if mem is not None:
            hs = gca(hs, mem)
        return (hs,) + others if others else hs
    return patched

for idx in insert_idx:
    init_gate = {insert_idx[0]:0.30, insert_idx[1]:0.40, 
                 insert_idx[2]:0.50, insert_idx[3]:0.60}[idx]
    gca = GatedCrossAttention(d_model, n_heads, init_gate=init_gate).to(device, dtype=dtype)
    gca_modules[idx] = gca
    blk = qwen.model.layers[idx]
    blk.forward = make_patched_forward(blk, gca).__get__(blk, blk.__class__)

def set_external_memory(mem):
    for idx in insert_idx:
        setattr(qwen.model.layers[idx], "_ext_mem", mem)

# メモリ構築
@torch.no_grad()
def build_memory_with_gemma(user_text: str, max_mem_tokens: int = 256):
    msgs = [{"role": "user", "content": user_text}]
    prompt = gem_tok.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)
    enc = gem_tok(prompt, return_tensors="pt").to(device)
    out = gem(**enc, output_hidden_states=True)
    h = out.hidden_states[-1][:, -max_mem_tokens:]
    h = proj_ext2main(h.to(dtype=dtype))
    h = mem_norm(h)
    return h

# 生成関数
@torch.no_grad()
def generate_bridge(user_text: str):
    mem = build_memory_with_gemma(user_text, max_mem_tokens=256)
    set_external_memory(mem)
    
    msgs = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": user_text}
    ]
    prompt = qwen_tok.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)
    enc = qwen_tok(prompt, return_tensors="pt").to(device)
    in_len = enc.input_ids.shape[1]
    
    out = qwen.generate(
        **enc,
        max_new_tokens=1024,
        do_sample=False,
        num_beams=1,
        use_cache=True,
        eos_token_id=qwen_tok.eos_token_id,
        pad_token_id=qwen_tok.eos_token_id,
    )
    return qwen_tok.decode(out[0, in_len:], skip_special_tokens=True)

# 実行
user_prompt = "ベイズの定理を日本語で一段落で説明してください。"
print(generate_bridge(user_prompt))

今後の改善案

  1. 動的ゲート調整: タスクに応じてゲート強度を自動調整
  2. 複数モデル対応: 3つ以上のLLMを接続
  3. メモリの選択的取得: Attentionスコアに基づいて必要な情報だけ取得
  4. ファインチューニング: ゲート・プロジェクタを少量データで学習

まとめ

Gated Cross-Attention Bridgeにより:

  • 🎯 学習不要で2つのLLMを接続
  • 💡 お互いの強みを活かした出力が可能
  • ⚡ 推論時のみの動作で実装が容易

小型LLMでも、組み合わせ方次第で性能向上が期待できます。ぜひ試してみてください!

3
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
3
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?