Introduction
This article describes how to apply HARI to Qwen3-235B-A22B-Thinking-2507.
As background, our strategy in the preliminary round was to take a base model with a high HLE score (deepseek-ai/DeepSeek-R1-0528) and submit it after a light SFT designed to avoid forgetting as much as possible (this took 1st place in the preliminaries).
Concretely, we ran SFT on 10 MedMCQA samples for 1 epoch with KL regularization (λ=0.9). We assumed the model would be essentially unchanged from the base, yet the score rose above the base in both our internal evaluation and the official preliminary evaluation; at that point we wrote it off as a lucky coincidence.
When Qwen/Qwen3-235B-A22B-Thinking-2507 was allowed in the finals, we applied KV self-distillation (α=0.9) with the same intent and the score rose again. In contrast, plain SFT and a last4-only variant, with all other parameters unchanged, lowered the score. This showed that kv=0.9 mattered, but at the time we were still calling it "mystery tuning" and had not decided to make it the centerpiece of our strategy.
After that, SFT and DFT on more than 10,000 samples only lowered the score, and RL was too expensive, so we ended up going with the "mystery tuning" that somehow kept raising scores. During validation we found that 500 samples for 1 epoch is particularly effective, and that across roughly 30 models trained on diverse data the score rose without exception, for a final gain of about 2.4%. Since "mystery tuning" no longer seemed like the right name, we named the method HARI, after the Japanese word for acupuncture (鍼, hari), in the sense of raising the model's latent ability the way acupuncture does.
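At its core, HARI is just a training objective: the ordinary SFT cross-entropy loss plus a heavily weighted (α=0.9) self-distillation term that pulls the fine-tuned model's K/V representations back toward those of the frozen base model. The sketch below is only an illustration of that objective; the names hari_loss, student_kv, and teacher_kv are illustrative, and the actual implementation is the KVSDTrainer in the script of Section 2.3.

import torch
import torch.nn.functional as F

def hari_loss(ce_loss, student_kv, teacher_kv, alpha=0.9):
    # ce_loss: standard SFT cross-entropy on the answer tokens
    # student_kv / teacher_kv: per-layer K/V projection outputs of the
    # fine-tuned model and of the frozen base model on the same batch
    distill = torch.stack([
        F.mse_loss(s, t.detach()) for s, t in zip(student_kv, teacher_kv)
    ]).mean()
    return ce_loss + alpha * distill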
The sections below give the concrete HARI procedure and scripts for Qwen3-235B-A22B-Thinking-2507, following the above.
1. Environment setup
Create the conda environment below.
Conda version: Miniconda 24.7.1
Python: 3.10 (the script's default PY_VER)
cat > setup_qwen_h100_one_shot.sh <<'SH'
#!/usr/bin/env bash
set -eo pipefail # no -u (avoids the /etc/bashrc issue)
ENV_NAME=${ENV_NAME:-qwen_sft}
PY_VER=${PY_VER:-3.10}
QWEN_DIR="${HOME}/Qwen"
# 0) Load conda (a safe route that does not source ~/.bashrc)
eval "$($(conda info --base)/bin/conda shell.bash hook)"
# 1) Remove any existing env with the same name
if conda env list | awk '{print $1}' | grep -qx "${ENV_NAME}"; then
echo "[INFO] Removing existing env: ${ENV_NAME}"
conda deactivate || true
conda remove -y -n "${ENV_NAME}" --all
fi
# 2) Create and activate the new env
echo "[INFO] Creating env: ${ENV_NAME} (python=${PY_VER})"
conda create -y -n "${ENV_NAME}" python="${PY_VER}"
conda activate "${ENV_NAME}"
# 3) PyTorch (CUDA 12.4, H100 OK)
pip install --index-url https://download.pytorch.org/whl/cu124 \
torch torchvision torchaudio
# 4) Install the CUDA Toolkit via conda (makes sure nvcc is available)
conda install -y -c nvidia cuda-toolkit=12.4
# 5) Pin CUDA_HOME to this env (persisted via activate.d/deactivate.d)
export CUDA_HOME="$CONDA_PREFIX"
mkdir -p "$CONDA_PREFIX/etc/conda/activate.d" "$CONDA_PREFIX/etc/conda/deactivate.d"
cat > "$CONDA_PREFIX/etc/conda/activate.d/cuda_home.sh" <<'EOS'
export CUDA_HOME="$CONDA_PREFIX"
export PATH="$CUDA_HOME/bin:$PATH"
EOS
cat > "$CONDA_PREFIX/etc/conda/deactivate.d/cuda_home.sh" <<'EOS'
unset CUDA_HOME
EOS
echo "[INFO] CUDA_HOME set to $CUDA_HOME"
if [ -x "$CUDA_HOME/bin/nvcc" ]; then
"$CUDA_HOME/bin/nvcc" --version || true
else
echo "[ERROR] nvcc not found in $CUDA_HOME/bin" >&2
exit 1
fi
# 6) Dependencies (including build requirements)
pip install -U pip setuptools wheel packaging ninja
# 7) Build and install DeepSpeed (with CUDA)
pip install deepspeed==0.17.5 --no-build-isolation
# 8) Other required libraries
pip install -U transformers datasets peft accelerate trl hf_transfer
pip install "bitsandbytes>=0.43.3"
# Flash-Attention: install it if your build environment is ready; comment out the next line if you do not need it
pip install flash-attn --no-build-isolation
# 9) Official Qwen repository (needed for finetune.py)
if [ -d "${QWEN_DIR}/.git" ]; then
echo "[INFO] Updating ${QWEN_DIR}"
git -C "${QWEN_DIR}" pull --ff-only || true
else
echo "[INFO] Cloning Qwen repo to ${QWEN_DIR}"
git clone https://github.com/QwenLM/Qwen.git "${QWEN_DIR}"
fi
# 10) Verify that torch / DeepSpeed / CUDA are visible
python - <<'PY'
import os, torch, deepspeed, subprocess
print("Torch:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())
print("CUDA device count:", torch.cuda.device_count())
print("DeepSpeed:", deepspeed.__version__)
print("CUDA_HOME:", os.environ.get("CUDA_HOME"))
nvcc = os.path.join(os.environ.get("CUDA_HOME",""), "bin", "nvcc")
print("Has nvcc:", os.path.exists(nvcc))
try:
if os.path.exists(nvcc):
subprocess.run([nvcc, "--version"], check=False)
except Exception as e:
print("nvcc check error:", e)
PY
echo "[OK] Environment '${ENV_NAME}' is ready."
echo "Next: conda activate ${ENV_NAME}"
SH
bash setup_qwen_h100_one_shot.sh
2. HARI
2.1. Allocate a node
srun --partition P10 --nodes=1 --gpus-per-node=8 --cpus-per-task=240 --time=48:00:00 --pty bash -i
2.2. Activate the conda environment
conda activate qwen_sft
2.3. Training script
cat > finetune_hf_sft.py <<'PY'
# finetune_hf_sft.py
# - Optional CoT supervision (the answer may be in any format; <think> is wrapped around the CoT when present)
# - KV self-distillation (KV or hidden states, layer selection, token striding); the teacher pass can be taken from a single model load
# - Multiple HF datasets with amount specs (repeat --hf_source name[:split[:amount]])
# - Output dir is tagged with dataset name & count; run config is saved (run_config.json, cmd.txt)
# - Supports gated HF models/datasets (--hf_token / --hf_data_token / env vars)
# - Prints the save path on success
# - Compatible with the Transformers Trainer (handles num_items_in_batch)
# - Qwen3-style ChatML (no system prompt is prepended)
from dataclasses import dataclass, field, asdict
import json, os, re, random, sys
from typing import Dict, Optional, List, Any, Tuple
import copy
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset
from contextlib import nullcontext
import transformers
from transformers import Trainer, BitsAndBytesConfig, HfArgumentParser
from transformers.integrations import is_deepspeed_zero3_enabled
from transformers.trainer_pt_utils import LabelSmoother
from accelerate.utils import DistributedType
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from datasets import load_dataset
# ---------------- Consts / utils ----------------
IGNORE_TOKEN_ID = LabelSmoother.ignore_index
local_rank = None
def rank0_print(*a):
if local_rank in (0, None):
print(*a)
# ---------------- Args ----------------
@dataclass
class ModelArguments:
model_name_or_path: Optional[str] = field(default="Qwen/Qwen-7B")
hf_token: Optional[str] = field(default=None, metadata={"help":"HF auth token for models (optional). If omitted, env or cached login is used."})
@dataclass
class DataArguments:
data_path: str = field(default=None, metadata={"help":"Local json/jsonl path (optional)."})
eval_data_path: str = field(default=None)
lazy_preprocess: bool = False
# HF datasets:
hf_dataset: Optional[str] = field(default=None, metadata={"help":"Legacy single dataset name (use --hf_source instead)."})
hf_split: str = field(default="train")
hf_eval_split: Optional[str] = field(default=None)
hf_source: List[str] = field(default_factory=list, metadata={"help":"Repeatable: name[:split[:amount]] e.g. oNo-1/MedMCQA:train:all"})
hf_data_token: Optional[str] = field(default=None, metadata={"help":"HF auth token for datasets (optional)"})
# formatting / filtering
use_cot: bool = field(default=True)
cot_required: bool = field(default=True)
    answer_letters: str = field(default="ABCDE")  # kept for backward compatibility (unused)
sample_shuffle_seed: int = 42
append_data_tag: bool = True # append dataset tag to output_dir
@dataclass
class TrainingArguments(transformers.TrainingArguments):
cache_dir: Optional[str] = field(default=None)
optim: str = field(default="adamw_torch")
model_max_length: int = field(default=16384)
use_lora: bool = False
interactive: bool = False
@dataclass
class LoraArguments:
lora_r: int = 64
lora_alpha: int = 16
lora_dropout: float = 0.05
lora_target_modules: List[str] = field(default_factory=lambda: ["q_proj","k_proj","v_proj","o_proj","gate_proj","up_proj","down_proj"])
lora_weight_path: str = ""
lora_bias: str = "none"
q_lora: bool = False
# bitsandbytes
bnb_4bit_quant_type: str = "nf4"
bnb_4bit_compute_dtype: str = "bfloat16"
bnb_4bit_use_double_quant: bool = True
@dataclass
class DistillArguments:
kv_sd: bool = False
kv_sd_alpha: float = 0.1
kv_sd_token_stride: int = 4
kv_sd_layers: str = "all" # "all" | "every2" | "last4"
kv_sd_quantize_teacher: bool = True
    # Single model load (share the student as the teacher: forward with LoRA temporarily disabled)
kv_sd_share_student: bool = True
# ---------------- HF auth helpers ----------------
def _get_hf_token(model_args: ModelArguments, data_args: DataArguments) -> Optional[str]:
tok = data_args.hf_data_token or model_args.hf_token
if not tok:
tok = os.environ.get("HUGGING_FACE_HUB_TOKEN") or os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACEHUB_API_TOKEN")
return tok
def _hub_kwargs(hf_token: Optional[str]) -> dict:
if hf_token:
        # NOTE: use_auth_token was removed in Transformers v5; pass token only.
return {"token": hf_token}
return {}
# ---------------- IO ----------------
def _read_json_or_jsonl(path):
if path and path.endswith(".jsonl"):
rows = []
with open(path, "r", encoding="utf-8") as f:
for line in f:
line = line.strip()
if line:
rows.append(json.loads(line))
return rows
elif path:
return json.load(open(path, "r", encoding="utf-8"))
return None
# ---------------- SFT format ----------------
def _detect_and_to_conversations(ex: Dict[str, Any]) -> List[Dict[str, str]]:
    # If the sample already has a conversation format, use it as-is
if "conversations" in ex and isinstance(ex["conversations"], list):
return ex["conversations"]
if "messages" in ex and isinstance(ex["messages"], list):
conv = []
for m in ex["messages"]:
r = m.get("role"); c = m.get("content", "")
if r == "user":
conv.append({"from":"user","value":c})
elif r == "assistant":
conv.append({"from":"assistant","value":c})
return conv
if "prompt" in ex and "response" in ex:
return [{"from":"user","value":ex["prompt"]},{"from":"assistant","value":ex["response"]}]
    # Otherwise build directly from question/answer/cot/options
if "question" in ex or "answer" in ex or "cot" in ex or "options" in ex:
q = (ex.get("question") or "").strip()
        ans = (ex.get("answer") or "").strip()  # any format is accepted
cot = (ex.get("cot") or "").strip()
user = q
if "options" in ex and ex["options"]:
raw = ex["options"]
if isinstance(raw, list):
opts_txt = []
for o in raw:
opts_txt.append(o.get("text", o) if isinstance(o, dict) else str(o))
user = (q + ("\n" if q and opts_txt else "") + "\n".join(opts_txt)).strip()
        # ---- Fix: put the CoT first and the answer after it ----
assistant = (((f"<think>\n{cot}\n</think>\n\n") if cot else "") + (ans or "")).strip()
return [{"from":"user","value":user},{"from":"assistant","value":assistant}]
raise KeyError("Unsupported sample format; need conversations/messages or prompt/response or question/answer/cot/options.")
def preprocess(sources, tokenizer: transformers.PreTrainedTokenizer, max_len: int, system_message=None)->Dict:
"""
Qwen-3流のChatML整形:
- systemは前置しない
- <|im_start|>role \\n content <|im_end|> を user/assistant 交互に並べる
"""
im_start = getattr(tokenizer, "im_start_id", None) or tokenizer.convert_tokens_to_ids("<|im_start|>")
im_end = getattr(tokenizer, "im_end_id", None) or tokenizer.convert_tokens_to_ids("<|im_end|>")
roles = {"user":"<|im_start|>user","assistant":"<|im_start|>assistant"}
nl = tokenizer("\n").input_ids
input_ids, targets = [], []
for raw in sources:
source = _detect_and_to_conversations(raw)
if not source or source[0].get("from") != "user":
source = [s for s in source if s.get("from") in ("user","assistant")]
if source and source[0].get("from") != "user":
source = source[1:]
        ids, tgt = [], []  # no system prompt is ever inserted
for sent in source:
role_tok = roles[sent["from"]]
_ids = tokenizer(role_tok).input_ids + nl + tokenizer(sent["value"]).input_ids + [im_end] + nl
ids += _ids
if role_tok == "<|im_start|>user":
_tgt = [im_start] + [IGNORE_TOKEN_ID]*(len(_ids)-3) + [im_end] + nl
else:
pref = tokenizer(role_tok).input_ids
_tgt = [im_start] + [IGNORE_TOKEN_ID]*len(pref) + _ids[len(pref)+1:-2] + [im_end] + nl
tgt += _tgt
pad_id = tokenizer.pad_token_id
ids = (ids + [pad_id]*(max_len - len(ids)))[:max_len]
tgt = (tgt + [IGNORE_TOKEN_ID]*(max_len - len(tgt)))[:max_len]
input_ids.append(ids); targets.append(tgt)
input_ids = torch.tensor(input_ids, dtype=torch.int)
targets = torch.tensor(targets, dtype=torch.int)
return dict(input_ids=input_ids, labels=targets, attention_mask=input_ids.ne(tokenizer.pad_token_id))
# ---------------- Datasets ----------------
class SupervisedDataset(Dataset):
def __init__(self, rows, tokenizer, max_len):
rank0_print("Formatting inputs...")
# keep raw for later saving (does not affect training)
self.raw = rows
d = preprocess(rows, tokenizer, max_len)
self.input_ids=d["input_ids"]; self.labels=d["labels"]; self.attention_mask=d["attention_mask"]
def __len__(self): return len(self.input_ids)
def __getitem__(self, i):
return dict(input_ids=self.input_ids[i], labels=self.labels[i], attention_mask=self.attention_mask[i])
class LazySupervisedDataset(Dataset):
def __init__(self, rows, tokenizer, max_len):
self.tokenizer=tokenizer; self.max_len=max_len; self.raw=rows; self.cached={}
rank0_print("Formatting inputs...Skip in lazy mode")
def __len__(self): return len(self.raw)
def __getitem__(self, i):
if i in self.cached: return self.cached[i]
d = preprocess([self.raw[i]], self.tokenizer, self.max_len)
out = dict(input_ids=d["input_ids"][0], labels=d["labels"][0], attention_mask=d["attention_mask"][0])
self.cached[i]=out; return out
def _dataset_module_from_rows(tokenizer, data_args, max_len, train_rows, eval_rows=None):
cls = LazySupervisedDataset if data_args.lazy_preprocess else SupervisedDataset
return dict(train_dataset=cls(train_rows, tokenizer, max_len),
eval_dataset=cls(eval_rows, tokenizer, max_len) if eval_rows else None)
def _dataset_module_from_files(tokenizer, data_args, max_len):
cls = LazySupervisedDataset if data_args.lazy_preprocess else SupervisedDataset
rank0_print("Loading data from file(s)...")
train_rows = _read_json_or_jsonl(data_args.data_path)
eval_rows = _read_json_or_jsonl(data_args.eval_data_path) if data_args.eval_data_path else None
return dict(train_dataset=cls(train_rows, tokenizer, max_len),
eval_dataset=cls(eval_rows, tokenizer, max_len) if eval_rows else None)
# ---------------- HF -> conversations ----------------
def to_conversations_from_hf_row(ex: Dict[str, Any], letters: str="ABCDE", use_cot: bool=True, cot_required: bool=True):
"""
可能なら 'conversations'/'messages'/'prompt&response' をそのまま使う。
それ以外は question/answer/cot/options から“固定プロンプトなし”で素直に組み立てる。
answer は任意形式で許容(A〜Eなどの強制はしない)。
"""
    # Default handling (use the data as it is)
if "conversations" in ex and isinstance(ex["conversations"], list):
return {"conversations": ex["conversations"]}
if "messages" in ex and isinstance(ex["messages"], list):
conv = []
for m in ex["messages"]:
r = m.get("role"); c = m.get("content", "")
if r == "user":
conv.append({"from":"user","value":c})
elif r == "assistant":
conv.append({"from":"assistant","value":c})
return {"conversations": conv}
if "prompt" in ex and "response" in ex:
return {"conversations":[
{"from":"user","value": (ex.get("prompt") or "").strip()},
{"from":"assistant","value": (ex.get("response") or "").strip()},
]}
# question/answer/cot/options
q = (ex.get("question") or "").strip()
    ans = (ex.get("answer") or "").strip()  # any format is accepted
cot = (ex.get("cot") or "").strip()
if cot_required and not cot:
return None
user = q
if "options" in ex and ex["options"]:
raw = ex["options"]
if isinstance(raw, list):
opts_txt = []
for o in raw:
opts_txt.append(o.get("text", o) if isinstance(o, dict) else str(o))
user = (q + ("\n" if q and opts_txt else "") + "\n".join(opts_txt)).strip()
    # ---- Fix: CoT first, answer after (when use_cot is True) ----
assistant = ans
if use_cot and cot:
assistant = (f"<think>\n{cot}\n</think>\n\n{ans}".strip()) if ans else f"<think>\n{cot}\n</think>"
elif (not use_cot) and cot and (not ans):
        # Also allow the case where only the CoT serves as the training target
assistant = cot
return {"conversations":[{"from":"user","value":user},{"from":"assistant","value":assistant}]}
def load_and_convert_hf(dataset_name: str, split: str, letters: str, use_cot: bool, cot_required: bool, hf_token: Optional[str])->List[Dict[str, Any]]:
rank0_print(f"Loading HF dataset: {dataset_name} [{split}]")
ds = load_dataset(dataset_name, split=split, **_hub_kwargs(hf_token))
out = []; skip = 0
for i, ex in enumerate(ds):
item = to_conversations_from_hf_row(ex, letters=letters, use_cot=use_cot, cot_required=cot_required)
if item is None:
skip += 1
continue
# Attach non-intrusive metadata columns for reproducibility
meta_item = dict(item) # shallow copy; 'conversations' stays the same
meta_item["_source"] = {"hf_dataset": dataset_name, "split": split, "index": i}
# Keep original row minimally for traceability (does not affect training)
meta_item["_orig"] = ex
out.append(meta_item)
rank0_print(f"[HF] converted {len(out)} samples (skipped={skip})")
return out
# --------- multi-source parsing & sampling ----------
def _parse_hf_source(spec: str) -> Tuple[str, str, Optional[str]]:
parts = spec.split(":")
name = parts[0].strip()
split = parts[1].strip() if len(parts) >= 2 and parts[1].strip() else "train"
amount = parts[2].strip() if len(parts) >= 3 and parts[2].strip() else None
return name, split, amount
def _parse_amount(amount_str: Optional[str], total: int) -> int:
if amount_str is None or amount_str.lower() == "all":
return total
try:
if "." in amount_str or amount_str.startswith("0"):
frac = float(amount_str)
if frac <= 0: return 0
if frac <= 1.0: return max(1, int(total * frac))
n = int(float(amount_str))
return max(0, min(total, n))
except Exception:
return total
def _sanitize_name_for_path(name: str) -> str:
return re.sub(r"[^A-Za-z0-9_.-]+", "_", name.replace("/", "__"))
def _sample_rows(rows: List[Dict[str, Any]], n: int, seed: int) -> List[Dict[str, Any]]:
if n >= len(rows): return rows
rnd = random.Random(seed)
idx = list(range(len(rows))); rnd.shuffle(idx); idx = idx[:n]
return [rows[i] for i in idx]
# --------- LoRA utils & save ----------
def _to_dtype(name: str):
n = (name or "").lower()
return {
"bfloat16": torch.bfloat16, "bf16": torch.bfloat16,
"float16": torch.float16, "fp16": torch.float16, "half": torch.float16,
"float32": torch.float32, "fp32": torch.float32
}.get(n, torch.bfloat16)
def _auto_targets(model_name: str, current: List[str])->List[str]:
m = (model_name or "").lower()
if any(k in m for k in ["qwen","llama","yi","mistral","mixtral","qwen2","qwen3","phi-3","glm"]):
return ["q_proj","k_proj","v_proj","o_proj","gate_proj","up_proj","down_proj"]
return current or ["q_proj","k_proj","v_proj","o_proj"]
def maybe_zero_3(param):
from deepspeed import zero
from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
if hasattr(param, "ds_id"):
assert param.ds_status == ZeroParamStatus.NOT_AVAILABLE
with zero.GatheredParameters([param]):
param = param.data.detach().cpu().clone()
else:
param = param.detach().cpu().clone()
return param
def get_peft_state_maybe_zero_3(named_params, bias):
if bias == "none":
to_return = {k: t for k, t in named_params if "lora_" in k}
elif bias == "all":
to_return = {k: t for k, t in named_params if "lora_" in k or "bias" in k}
elif bias == "lora_only":
to_return = {}
maybe_lora_bias = {}
lora_bias_names = set()
for k, t in named_params:
if "lora_" in k:
to_return[k] = t
lora_bias_names.add(k.split("lora_")[0] + "bias")
elif "bias" in k:
maybe_lora_bias[k] = t
for k, t in maybe_lora_bias.items():
if k in lora_bias_names:
to_return[k] = t
else:
raise NotImplementedError
to_return = {k: maybe_zero_3(v) for k, v in to_return.items()}
return to_return
def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, output_dir: str, bias="none"):
if is_deepspeed_zero3_enabled():
state_dict = trainer.model_wrapped._zero3_consolidated_16bit_state_dict()
else:
if getattr(trainer.args, "use_lora", False):
state_dict = get_peft_state_maybe_zero_3(trainer.model.named_parameters(), bias)
else:
state_dict = trainer.model.state_dict()
if trainer.args.should_save and (trainer.args.local_rank in (0, -1)):
trainer._save(output_dir, state_dict=state_dict)
# ---------------- KV self-distill utils ----------------
def _collect_kv_outputs(model, inputs, token_stride=1):
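    # Register forward hooks on every module whose name ends in "k_proj" or
    # "v_proj", run a single forward pass, and collect each projection's output
    # (optionally subsampled along the sequence axis by token_stride). These
    # tensors are what the distillation loss compares between the student pass
    # and the teacher pass.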
kv = {"k": [], "v": []}
handles = []
def make_hook(kind):
def _hook(_m, _inp, out):
t = out
if token_stride > 1:
t = t[:, ::token_stride, :]
kv[kind].append(t)
return _hook
for name, module in model.named_modules():
if name.endswith("k_proj") and hasattr(module, "register_forward_hook"):
handles.append(module.register_forward_hook(make_hook("k")))
elif name.endswith("v_proj") and hasattr(module, "register_forward_hook"):
handles.append(module.register_forward_hook(make_hook("v")))
outputs = model(**inputs)
for h in handles: h.remove()
found = len(kv["k"]) > 0 and len(kv["v"]) > 0
return outputs, kv, found
def _hidden_states_distill(student_out, teacher_out, layers_mode="last4", token_stride=1, ref_device=None):
if not (hasattr(student_out, "hidden_states") and hasattr(teacher_out, "hidden_states")
and student_out.hidden_states and teacher_out.hidden_states):
dev = ref_device or student_out.logits.device
return torch.tensor(0.0, device=dev)
s_hs = student_out.hidden_states; t_hs = teacher_out.hidden_states
if layers_mode == "last4": s_sel = s_hs[-4:]; t_sel = t_hs[-4:]
elif layers_mode == "every2": s_sel = s_hs[::2]; t_sel = t_hs[::2]
else: s_sel = s_hs; t_sel = t_hs
dev = ref_device or student_out.logits.device
loss = None; n = 0
for s, t in zip(s_sel, t_sel):
if token_stride > 1: s = s[:, ::token_stride, :]; t = t[:, ::token_stride, :]
d = min(s.size(-1), t.size(-1))
s = s.to(dev); t = t.to(dev)
mse = F.mse_loss(s[..., :d], t[..., :d])
loss = mse if loss is None else (loss + mse)
n += 1
if n == 0:
return torch.tensor(0.0, device=dev)
return loss / n
def _kv_mse(student_kv, teacher_kv, layers_mode="all", ref_device=None):
def _select(seq, mode):
if mode == "every2": return seq[::2]
if mode == "last4": return seq[-4:]
return seq
dev = ref_device or (student_kv["k"][0].device if student_kv["k"] else student_kv["v"][0].device)
loss = None; n = 0
s_k = _select(student_kv["k"], layers_mode); s_v = _select(student_kv["v"], layers_mode)
t_k = _select(teacher_kv["k"], layers_mode); t_v = _select(teacher_kv["v"], layers_mode)
for sk, tk in zip(s_k, t_k):
d = min(sk.size(-1), tk.size(-1))
sk = sk.to(dev); tk = tk.to(dev)
mse = F.mse_loss(sk[..., :d], tk[..., :d])
loss = mse if loss is None else (loss + mse)
n += 1
for sv, tv in zip(s_v, t_v):
d = min(sv.size(-1), tv.size(-1))
sv = sv.to(dev); tv = tv.to(dev)
mse = F.mse_loss(sv[..., :d], tv[..., :d])
loss = mse if loss is None else (loss + mse)
n += 1
if n == 0:
return torch.tensor(0.0, device=dev)
return loss / n
class KVSDTrainer(transformers.Trainer):
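    # Trainer with the HARI objective: total loss = CE loss + kv_sd_alpha * MSE
    # between the student's and the teacher's K/V projection outputs (falling
    # back to a hidden-state MSE when no k_proj/v_proj hooks match). With
    # kv_sd_share_student=True the teacher pass reuses the same weights with
    # the LoRA adapters disabled, so no second model load is required.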
def __init__(self, *args, teacher_model=None, kv_sd_alpha=0.1, kv_sd_layers="all", kv_sd_token_stride=1, kv_sd_share_student=True, **kwargs):
super().__init__(*args, **kwargs)
self.teacher_model = teacher_model
self.kv_sd_alpha = kv_sd_alpha
self.kv_sd_layers = kv_sd_layers
self.kv_sd_token_stride = kv_sd_token_stride
self.kv_sd_share_student = kv_sd_share_student
if self.teacher_model is not None:
self.teacher_model.eval()
for p in self.teacher_model.parameters():
p.requires_grad = False
def _teacher_forward_shared(self, model, inputs):
        # Take the "teacher" pass with the LoRA adapters temporarily disabled (single model load)
try:
ctx = model.disable_adapter()
except AttributeError:
ctx = nullcontext()
with torch.no_grad():
with ctx:
t_outputs, t_kv, t_has_kv = _collect_kv_outputs(
model, {**inputs, "output_hidden_states": True},
token_stride=self.kv_sd_token_stride
)
return t_outputs, t_kv, t_has_kv
    # Accept num_items_in_batch, which Transformers >= 4.44 passes to compute_loss
def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None, **kwargs):
        # 1) Teacher pass
if self.teacher_model is not None:
with torch.no_grad():
t_outputs, t_kv, t_has_kv = _collect_kv_outputs(
self.teacher_model, {**inputs, "output_hidden_states": True},
token_stride=self.kv_sd_token_stride
)
elif self.kv_sd_share_student:
t_outputs, t_kv, t_has_kv = self._teacher_forward_shared(model, {**inputs})
else:
t_outputs, t_kv, t_has_kv = None, None, False
        # 2) Student (training) pass
outputs, s_kv, s_has_kv = _collect_kv_outputs(
model, {**inputs, "output_hidden_states": True},
token_stride=self.kv_sd_token_stride
)
ce_loss = outputs.loss
        ref_dev = ce_loss.device  # collect all loss terms on this device
        # 3) Distillation loss
distill_loss = torch.tensor(0.0, device=ref_dev)
if t_outputs is not None:
if s_has_kv and t_has_kv:
distill_loss = _kv_mse(s_kv, t_kv, layers_mode=self.kv_sd_layers, ref_device=ref_dev)
else:
mode = "last4" if self.kv_sd_layers=="last4" else ("every2" if self.kv_sd_layers=="every2" else "all")
distill_loss = _hidden_states_distill(outputs, t_outputs, layers_mode=mode, token_stride=self.kv_sd_token_stride, ref_device=ref_dev)
loss = ce_loss + (self.kv_sd_alpha * distill_loss)
return (loss, outputs) if return_outputs else loss
# ---------------- build model/tokenizer ----------------
def make_model_and_tokenizer(model_args, training_args, lora_args, hf_token):
device_map = None if (int(os.environ.get("WORLD_SIZE", 1)) != 1 or is_deepspeed_zero3_enabled()) else "auto"
config = transformers.AutoConfig.from_pretrained(
model_args.model_name_or_path, cache_dir=training_args.cache_dir, trust_remote_code=True, **_hub_kwargs(hf_token)
)
config.use_cache = False
bnb_config = None
if training_args.use_lora and lora_args.q_lora:
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type=lora_args.bnb_4bit_quant_type,
bnb_4bit_compute_dtype=_to_dtype(lora_args.bnb_4bit_compute_dtype),
bnb_4bit_use_double_quant=lora_args.bnb_4bit_use_double_quant,
)
model = transformers.AutoModelForCausalLM.from_pretrained(
model_args.model_name_or_path,
config=config,
cache_dir=training_args.cache_dir,
device_map=device_map,
trust_remote_code=True,
quantization_config=bnb_config,
low_cpu_mem_usage=not is_deepspeed_zero3_enabled(),
**_hub_kwargs(hf_token),
)
tokenizer = transformers.AutoTokenizer.from_pretrained(
model_args.model_name_or_path,
cache_dir=training_args.cache_dir,
model_max_length=training_args.model_max_length,
padding_side="right",
use_fast=False,
trust_remote_code=True,
**_hub_kwargs(hf_token),
)
if getattr(tokenizer, "pad_token_id", None) is None:
if hasattr(tokenizer, "eod_id") and tokenizer.eod_id is not None:
tokenizer.pad_token_id = tokenizer.eod_id
else:
tokenizer.add_special_tokens({"pad_token":"<|pad|>"})
tokenizer.pad_token_id = tokenizer.convert_tokens_to_ids("<|pad|>")
model.resize_token_embeddings(len(tokenizer))
return model, tokenizer, device_map
# ---------------- helper: save used rows ----------------
def _save_used_rows_jsonl(out_dir: str, filename: str, rows: List[Dict[str, Any]], local_path: Optional[str] = None):
try:
if not rows:
return
path = os.path.join(out_dir, filename)
with open(path, "w", encoding="utf-8") as f:
for r in rows:
obj = r
# For local data, add non-intrusive source/meta without mutating training objects
if local_path is not None:
obj = copy.deepcopy(r)
if "_source" not in obj:
obj["_source"] = {"local_path": local_path}
if "_orig" not in obj:
obj["_orig"] = r
f.write(json.dumps(obj, ensure_ascii=False) + "\n")
rank0_print(f"[OK] Wrote: {path}")
except Exception as e:
rank0_print(f"[warn] failed to write {filename}: {e}")
# ---------------- train ----------------
def train():
global local_rank
parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments, LoraArguments, DistillArguments))
(model_args, data_args, training_args, lora_args, distill_args) = parser.parse_args_into_dataclasses()
# HF token
hf_token = _get_hf_token(model_args, data_args)
if hf_token and not os.environ.get("HUGGING_FACE_HUB_TOKEN"):
os.environ["HUGGING_FACE_HUB_TOKEN"] = hf_token
# deepspeed single-GPU hint
if getattr(training_args, "deepspeed", None) and int(os.environ.get("WORLD_SIZE", 1)) == 1:
training_args.distributed_state.distributed_type = DistributedType.DEEPSPEED
local_rank = training_args.local_rank
# model/tokenizer
model, tokenizer, device_map = make_model_and_tokenizer(model_args, training_args, lora_args, hf_token)
# LoRA
if training_args.use_lora:
# "a,b,c" を list へ
if isinstance(lora_args.lora_target_modules, list) and len(lora_args.lora_target_modules)==1 \
and isinstance(lora_args.lora_target_modules[0], str) and "," in lora_args.lora_target_modules[0]:
lora_args.lora_target_modules = [s.strip() for s in lora_args.lora_target_modules[0].split(",") if s.strip()]
if not lora_args.lora_target_modules:
lora_args.lora_target_modules = _auto_targets(model_args.model_name_or_path, lora_args.lora_target_modules)
lora_config = LoraConfig(
r=lora_args.lora_r, lora_alpha=lora_args.lora_alpha, target_modules=lora_args.lora_target_modules,
lora_dropout=lora_args.lora_dropout, bias=lora_args.lora_bias, task_type="CAUSAL_LM", modules_to_save=None,
)
if lora_args.q_lora:
model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=training_args.gradient_checkpointing)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()
if training_args.gradient_checkpointing:
model.enable_input_require_grads()
# data module
used_sources_meta = None
data_tag = None
all_rows = None
eval_rows = None
if data_args.data_path:
data_module = _dataset_module_from_files(tokenizer, data_args, training_args.model_max_length)
# For saving: get raw rows from dataset object
try:
all_rows = getattr(data_module["train_dataset"], "raw", None)
eval_rows = getattr(data_module.get("eval_dataset", None), "raw", None) if data_module.get("eval_dataset", None) else None
except Exception:
all_rows = None; eval_rows = None
else:
sources = list(data_args.hf_source)
if (not sources) and data_args.hf_dataset:
sources = [f"{data_args.hf_dataset}:{data_args.hf_split}:all"]
if not sources:
raise ValueError("Please specify either --data_path or at least one --hf_source (or legacy --hf_dataset).")
all_rows = []
tag_parts: List[str] = []
used_sources_meta = []
for spec in sources:
name, split, amount_str = _parse_hf_source(spec)
rows_full = load_and_convert_hf(
dataset_name=name, split=split,
letters=data_args.answer_letters, use_cot=data_args.use_cot, cot_required=data_args.cot_required,
hf_token=hf_token
)
total_avail = len(rows_full)
use_n = _parse_amount(amount_str, total_avail)
rows_used = _sample_rows(rows_full, use_n, seed=data_args.sample_shuffle_seed)
all_rows.extend(rows_used)
tag_parts.append(f"{_sanitize_name_for_path(name)}-{len(rows_used)}")
used_sources_meta.append({
"name": name, "split": split,
"amount_spec": amount_str if amount_str is not None else "all",
"total_available": total_avail,
"used_count": len(rows_used),
"seed": data_args.sample_shuffle_seed
})
data_tag = "+".join(tag_parts) if tag_parts else None
if data_args.append_data_tag and data_tag:
base_out = training_args.output_dir or "./output_qwen"
training_args.output_dir = os.path.join(base_out, data_tag)
os.makedirs(training_args.output_dir or "./output_qwen", exist_ok=True)
if data_args.hf_eval_split:
name0, _, _ = _parse_hf_source(sources[0])
eval_rows = load_and_convert_hf(
dataset_name=name0, split=data_args.hf_eval_split,
letters=data_args.answer_letters, use_cot=data_args.use_cot, cot_required=data_args.cot_required,
hf_token=hf_token
)
data_module = _dataset_module_from_rows(tokenizer, data_args, training_args.model_max_length, all_rows, eval_rows)
    # teacher (loaded separately only when the student is not shared)
teacher_model = None
if distill_args.kv_sd and not distill_args.kv_sd_share_student:
teacher_bnb = None
if distill_args.kv_sd_quantize_teacher:
teacher_bnb = BitsAndBytesConfig(
load_in_4bit=True, bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.bfloat16, bnb_4bit_use_double_quant=True
)
teacher_model = transformers.AutoModelForCausalLM.from_pretrained(
model_args.model_name_or_path,
config=transformers.AutoConfig.from_pretrained(model_args.model_name_or_path, trust_remote_code=True, **_hub_kwargs(hf_token)),
cache_dir=training_args.cache_dir,
device_map=(None if (int(os.environ.get("WORLD_SIZE",1))!=1 or is_deepspeed_zero3_enabled()) else "auto"),
trust_remote_code=True,
quantization_config=teacher_bnb,
low_cpu_mem_usage=True,
**_hub_kwargs(hf_token),
)
teacher_model.eval()
for p in teacher_model.parameters(): p.requires_grad = False
# save run config snapshot
out_dir = training_args.output_dir or "./output_qwen"
os.makedirs(out_dir, exist_ok=True)
try:
run_cfg = {
"argv": sys.argv,
"model_args": asdict(model_args),
"data_args": asdict(data_args),
"training_args": {k: getattr(training_args, k) for k in vars(training_args) if not k.startswith("_")},
"lora_args": asdict(lora_args),
"distill_args": asdict(distill_args),
"resolved": {
"output_dir": out_dir,
"used_sources": used_sources_meta,
"data_tag": data_tag,
"hf_token_provided": bool(_get_hf_token(model_args, data_args)),
},
}
with open(os.path.join(out_dir, "run_config.json"), "w", encoding="utf-8") as f:
json.dump(run_cfg, f, ensure_ascii=False, indent=2, default=str)
with open(os.path.join(out_dir, "cmd.txt"), "w", encoding="utf-8") as f:
f.write(" ".join(sys.argv) + "\n")
except Exception as e:
rank0_print(f"[warn] failed to write run_config.json: {e}")
# Save the exact used rows (train/eval) into output_dir with metadata columns
try:
if all_rows is None:
# fall back to dataset.raw (works for local/json and lazy mode)
all_rows = getattr(data_module["train_dataset"], "raw", None)
if all_rows:
# if local file, annotate with source path at write time
_save_used_rows_jsonl(out_dir, "used_train_rows.jsonl", all_rows, local_path=data_args.data_path if data_args.data_path else None)
if eval_rows:
_save_used_rows_jsonl(out_dir, "used_eval_rows.jsonl", eval_rows, local_path=data_args.eval_data_path if data_args.eval_data_path else None)
except Exception as e:
rank0_print(f"[warn] failed to persist used dataset rows: {e}")
# Trainer
trainer_cls = KVSDTrainer if distill_args.kv_sd else Trainer
trainer = trainer_cls(
model=model, tokenizer=tokenizer, args=training_args, **data_module,
**({"teacher_model": teacher_model,
"kv_sd_alpha": distill_args.kv_sd_alpha,
"kv_sd_layers": distill_args.kv_sd_layers,
"kv_sd_token_stride": distill_args.kv_sd_token_stride,
"kv_sd_share_student": distill_args.kv_sd_share_student} if distill_args.kv_sd else {})
)
# Train & Save
trainer.train()
trainer.save_state()
safe_save_model_for_hf_trainer(trainer=trainer, output_dir=training_args.output_dir, bias=lora_args.lora_bias)
rank0_print(f"[OK] Saved to: {training_args.output_dir}")
if __name__ == "__main__":
train()
PY
2.4. Run the training
The key setting is --kv_sd_alpha 0.9. This is what makes the run HARI.
export HUGGING_FACE_HUB_TOKEN=
python -u ~/finetune_hf_sft.py \
--model_name_or_path /nvme34/Qwen/Qwen3-235B-A22B-Thinking-2507 \
--hf_source oNo-1/MedMCQA:train:10 \
--output_dir ./runs/qwen235b_medmcqa_lora_attnonly \
--bf16 True --gradient_checkpointing True \
--use_lora True --q_lora True \
--lora_target_modules "q_proj,k_proj,v_proj,o_proj" \
--lora_r 8 --lora_alpha 16 --lora_dropout 0.1 \
--bnb_4bit_quant_type nf4 --bnb_4bit_compute_dtype bfloat16 --bnb_4bit_use_double_quant True \
--optim adamw_bnb_8bit \
--model_max_length 10000 \
--per_device_train_batch_size 1 --gradient_accumulation_steps 16 \
--num_train_epochs 1 \
--learning_rate 1e-4 \
--logging_steps 20 --save_strategy "no" \
--do_eval False \
--dataloader_num_workers 6 --group_by_length True \
--lazy_preprocess True \
--kv_sd --kv_sd_alpha 0.9
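For evaluation after training, the saved LoRA adapter can be attached to the base model with PEFT. Below is a minimal sketch, assuming the adapter was written to the data-tagged subdirectory that the script creates for the command above (oNo-1__MedMCQA-10); adjust the paths to your own run.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

base = "/nvme34/Qwen/Qwen3-235B-A22B-Thinking-2507"
adapter = "./runs/qwen235b_medmcqa_lora_attnonly/oNo-1__MedMCQA-10"  # output_dir + data tag

tokenizer = AutoTokenizer.from_pretrained(base, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    base, torch_dtype=torch.bfloat16, device_map="auto", trust_remote_code=True
)
model = PeftModel.from_pretrained(model, adapter)  # attach the HARI LoRA adapter
# model = model.merge_and_unload()                 # optionally merge for deployment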
Summary
- The gist of HARI: layer forgetting suppression on top of a small-scale SFT, aiming for a stable uplift on HLE without disturbing the base model's behavior.
  - Preliminaries: KL regularization (λ=0.9)
  - Finals: KV self-distillation (α=0.9)
- Preliminary-round finding: applying 10 MedMCQA samples × 1 epoch with KL=0.9 to DeepSeek-R1-0528 beat the base model in both our internal and the official preliminary evaluation.
- Finals (Qwen3-235B-A22B-Thinking-2507): KV self-distillation (α=0.9), applied with the same goal, improved the score in the same way. As controls, plain SFT and a last4-only variant degraded it, so α=0.9 is the effective factor.
- Scaling checks: 500 samples × 1 epoch tends to work particularly well. Across roughly 30 models and diverse data the score rose without exception, with a final gain of about +2.4%; the effect is reproducible, not a coincidence.
"Small-scale SFT combined with forgetting suppression (preliminaries: KL 0.9 / finals: KV 0.9)" yields a stable, low-cost gain even on large models such as Qwen3-235B. That is the core of HARI.
This project is conducted as part of the foundation-model development effort within the New Energy and Industrial Technology Development Organization (NEDO) program 「日本語版医療特化型LLMの社会実装に向けた安全性検証・実証」 (safety verification and demonstration toward the social implementation of a Japanese medical-domain-specific LLM).