TL;DR
- Base と Donor が 別アーキでも、共有語彙を足場に埋め込み空間を Procrustes(回転)+ Umeyama(スケール)で整合 → 埋め込み(必須)+ lm_head(任意)を選択的に行単位ブレンドします。
- マージ中に PPL ガードを回し、閾値超過ならヘッド段を スキップ/ロールバック。対称 KL も近似でモニタ。
- 性能(品質)を本当に上げたいなら、パラメータ調整が相当必要。既定値は“安全寄りの無難設定”で、ドメインによっては効果が薄いです。 Optuna での探索や、プロンプト集合の作り込みがカギ。
GitHub: https://github.com/Ono-Katsuki/cross-architecture-llm-merge
何ができるのか
- クロスアーキ対応:埋め込み次元・ボキャ差異を吸収して Donor→Base に射影(CPU float32)。
- デバイス/精度セーフ:dtype/device を境界で強制整合、index tensor もデバイス合わせ。
- 品質ガード:PPL 比(相対PPL)で異常増を検知→ヘッド段を抑制/巻き戻し。KL(Base‖Merged)とKL(Merged‖Base)の平均(対称KL)を top-K 近似で併用。
- ディスク安全な探索:Optuna 実行でも 小さな JSON 統計のみ保存。最終モデルは 1回だけ保存(任意)。
-
決定論評価:CuBLAS 決定論モード+
torch.use_deterministic_algorithms(True)+Greedy 生成で再現性を担保。
前提
-
デフォルトのモデル以外では1から探索が必要です。
-
既定値のままだと “安全に壊しにくい” 代わりに、改善幅は控えめになりがちです。
-
期待する改善を得るには、以下の2点が必須です:
-
代表プロンプト集合の設計(
--prompts-fileでPPL/選抜に使う) - ハイパラ探索(Optuna)
-
代表プロンプト集合の設計(
「マージが効く問題設定」をプロンプトで定義し、その上で
min_tau_*/alpha_*/target_*を回す。
インストール
# Python 3.10+ 推奨
pip install torch transformers optuna numpy
実行
git clone https://github.com/Ono-Katsuki/cross-architecture-llm-merge
cd cross-architecture-llm-merge
python merge_embed_head_optuna.py \
--base Qwen/Qwen2.5-3B-Instruct \
--donor google/gemma-2-2b \
--out ./merged_out \
--dtype bf16 \
--donor-on-cpu \
--ppl-cap 3.0
./merged_outに Base 互換のマージ済みモデルを保存。標準出力に PPL 前後・ブレンド行数などの JSON 統計が出ます。
Optuna 探索
python merge_embed_head_optuna.py \
--base Qwen/Qwen2.5-3B-Instruct \
--donor google/gemma-2-2b \
--out ./tuned_merged \
--dtype bf16 \
--donor-on-cpu \
--trials 48 \
--ppl-cap 2.5 \
--kl-topk 512 --kl-positions tail:8 \
--save-topk 0
-
相対PPL を最小化しつつ、対称KL を最大化(実装では
-KL_symを最小化)。 -
./tuned_merged/stats/*.json(各試行の小さな統計)、pareto.json/best.json。 - 既定ではベスト設定で再マージし
./tuned_merged/best_model/に 1回だけ保存。
チューニング項目
| パラメータ | 役割 | レンジの目安 |
|---|---|---|
target_emb_frac |
Stage-Aで埋め込みを何割ブレンドするか。大きいほど効果もリスクも増。 | 0.50–0.80 |
alpha_cap_emb / alpha_floor_emb
|
埋め込みの最大/最小ブレンド率。cap を上げると攻め。 | cap: 0.80–0.95 / floor: 0.02–0.10 |
min_tau_emb |
“似ているとみなす”コサイン閾値。高いほど厳選(安全寄り)。 | 0.50–0.60 |
target_head_rows |
Stage-Bでlm_headの行を何個触るか(tie 無し時)。 | 3k–20k |
alpha_cap_head / alpha_floor_head
|
ヘッドのブレンド強度。cap を上げるほど攻め。 | cap: 0.60–0.95 / floor: 0.10–0.25 |
min_tau_head |
ヘッド側のコサイン閾値。 | 0.60–0.72 |
use_soft_clip |
ノルム外れ値の滑らかな抑制(通常は ON が安定)。 | True/False |
ppl-cap |
ガードの厳しさ。低いと保守的・高いと攻め。 | 2.0–3.0 |
- タスク・モデルにより最適域は変わります。
- ヘッド tying(入出力埋め込みが同一重み)なモデルでは Stage-B は自動スキップ。埋め込みに集中。
- 語彙の重なりが小さい場合は、
--max-sharedを広げつつtarget_emb_fracは控えめから入るのが安全。
コツ
- プロンプト集合が“雑”:探索が不安定になりがち。本番を想定した代表性を持たせる。
-
alpha系を上げすぎ:PPL 爆発→ヘッド巻き戻し連発。cap を下げる/min_tau_*を上げる。 -
ヘッドを触り過ぎ:
target_head_rowsを段階的に上げる。PPL が安定したら攻めていく。 -
VRAM 足りない:
--donor-on-cpuを使う。SVD/射影は CPU f32 前提でも十分動く。
評価
# 例のプロンプト
cat > prompts.txt << 'EOF'
日本語で、Transformerの自己注意を高校生にもわかるように説明して。ポイントを3つに箇条書きして。
Summarize the benefits of attention mechanisms in two concise sentences.
EOF
# 比較実行
python eval_fixed_generation_compare.py \
--base Qwen/Qwen2.5-3B-Instruct \
--merged ./tuned_merged/best_model \
--tokenizer Qwen/Qwen2.5-3B-Instruct \
--dtype bf16 \
--prompts prompts.txt \
--out ./compare_out \
--seed 42 \
--max-new-tokens 256
-
compare_out/results.jsonlとresults.mdを生成(Base/Merged の回答・経過時間を確認)。 - 生成品質は プロンプト分布に依存。探索前後・プロンプト差替えで 結果が変わり得る点に注意。
仕組み
-
共有語彙を抽出(特殊トークン除外)。
-
平均中心化 → Procrustes(回転)→ Umeyama(スケール)で Donor 埋め込みを Base 空間へ射影。
-
スケール整合&ソフトクリップでノルムの暴れを抑制。
-
コサインに基づく選抜ブレンド:
- Stage-A(埋め込み):上位
target_emb_fracを ロジスティック日程(tau/gamma/alpha_floor/cap)で混合。 - PPL ガード:超過なら Stage-B をスキップ。
- Stage-B(lm_head 任意):tie なし時のみ実施、超過なら 巻き戻し。
- Stage-A(埋め込み):上位
-
Optunaでハイパラを探索(PPL↓ & KL↑ を両立)。
まとめ
異なるアーキでも、共有語彙を手掛かりに埋め込み空間を整合 → 行単位で慎重にブレンドし、PPL/KL ガードで安全性を担保できます。