How to apply HARI to Qwen3-235B-A22B-Thinking-2507

Posted at 2025-10-27

Introduction

This article describes how to apply HARI to Qwen3-235B-A22B-Thinking-2507.

As background: in the preliminary round, our strategy was to take a base model with a high HLE score (deepseek-ai/DeepSeek-R1-0528), lightly SFT it while suppressing forgetting as much as possible, and submit that (1st place in the preliminaries).
Concretely, we ran SFT on 10 MedMCQA samples for 1 epoch with KL regularization (λ=0.9). We assumed nothing would really change from the base model, yet the score rose above the base in both our internal evaluation and the preliminary evaluation; at that point we dismissed it as a lucky coincidence.

In the finals, Qwen/Qwen3-235B-A22B-Thinking-2507 became available, and applying KV self-distillation (α=0.9) with the same goal raised the score in the same way. By contrast, trying plain SFT or the "last4 layers only" variant with all other parameters unchanged lowered the score. At this point we knew that kv=0.9 mattered, but we were still calling it "mystery tuning" and had not yet decided to make it the core of our strategy.

After that, SFT and DFT with more than 10,000 samples also lowered the score, and RL was ruled out as too costly, so we ended up proceeding with the "mystery tuning" that somehow kept raising scores. During verification we found that 500 samples for 1 epoch is particularly effective, and that training roughly 30 models on diverse data raised the score without exception. In the end we obtained an improvement of about 2.4%. Since "mystery tuning" no longer felt like a fitting name, we named the method HARI, in the sense of boosting the model's latent ability the way acupuncture (hari) needles do.
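
Conceptually, HARI is just a small SFT whose loss carries a heavily weighted self-distillation penalty toward the frozen base model. The following is a minimal sketch of that objective (illustrative only; hari_loss and the KV tensor lists are hypothetical names, the actual trainer is shown in section 2.3):

import torch
import torch.nn.functional as F

# Minimal sketch of the HARI objective (hypothetical helper, not the actual trainer).
# ce_loss is the ordinary SFT cross-entropy; student_kv / teacher_kv hold per-layer
# k_proj / v_proj outputs, with the teacher being the frozen base model.
def hari_loss(ce_loss: torch.Tensor,
              student_kv: list[torch.Tensor],
              teacher_kv: list[torch.Tensor],
              alpha: float = 0.9) -> torch.Tensor:
    # self-distillation term: MSE between student and base-model KV activations
    distill = sum(F.mse_loss(s, t) for s, t in zip(student_kv, teacher_kv)) / len(student_kv)
    # alpha = 0.9 is the HARI setting (kv_sd_alpha in the script below)
    return ce_loss + alpha * distill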

The following sections give the concrete HARI procedure and scripts for Qwen3-235B-A22B-Thinking-2507, based on the background above.

1. Environment setup

Create the following conda environment.
Conda version: Miniconda 24.7.1
Python: 3.11

cat > setup_qwen_h100_one_shot.sh <<'SH'
#!/usr/bin/env bash
set -eo pipefail   # no -u (avoids the /etc/bashrc issue)

ENV_NAME=${ENV_NAME:-qwen_sft}
PY_VER=${PY_VER:-3.10}
QWEN_DIR="${HOME}/Qwen"

# 0) Load conda (a safe route that does not source ~/.bashrc)
eval "$($(conda info --base)/bin/conda shell.bash hook)"

# 1) Remove any existing env with the same name
if conda env list | awk '{print $1}' | grep -qx "${ENV_NAME}"; then
  echo "[INFO] Removing existing env: ${ENV_NAME}"
  conda deactivate || true
  conda remove -y -n "${ENV_NAME}" --all
fi

# 2) Create and activate the env
echo "[INFO] Creating env: ${ENV_NAME} (python=${PY_VER})"
conda create -y -n "${ENV_NAME}" python="${PY_VER}"
conda activate "${ENV_NAME}"

# 3) PyTorch (CUDA 12.4, H100 OK)
pip install --index-url https://download.pytorch.org/whl/cu124 \
  torch torchvision torchaudio

# 4) Install the CUDA Toolkit via conda (makes sure nvcc is present)
conda install -y -c nvidia cuda-toolkit=12.4

# 5) Pin CUDA_HOME to this env (persisted via activate/deactivate hooks)
export CUDA_HOME="$CONDA_PREFIX"
mkdir -p "$CONDA_PREFIX/etc/conda/activate.d" "$CONDA_PREFIX/etc/conda/deactivate.d"
cat > "$CONDA_PREFIX/etc/conda/activate.d/cuda_home.sh" <<'EOS'
export CUDA_HOME="$CONDA_PREFIX"
export PATH="$CUDA_HOME/bin:$PATH"
EOS
cat > "$CONDA_PREFIX/etc/conda/deactivate.d/cuda_home.sh" <<'EOS'
unset CUDA_HOME
EOS

echo "[INFO] CUDA_HOME set to $CUDA_HOME"
if [ -x "$CUDA_HOME/bin/nvcc" ]; then
  "$CUDA_HOME/bin/nvcc" --version || true
else
  echo "[ERROR] nvcc not found in $CUDA_HOME/bin" >&2
  exit 1
fi

# 6) Dependencies (including build requirements)
pip install -U pip setuptools wheel packaging ninja

# 7) Build and install DeepSpeed (with CUDA)
pip install deepspeed==0.17.5 --no-build-isolation

# 8) Other required libraries
pip install -U transformers datasets peft accelerate trl hf_transfer
pip install "bitsandbytes>=0.43.3"

# Flash-Attention: works if the build environment is set up; comment this line out if you do not need it
pip install flash-attn --no-build-isolation

# 9) Official Qwen repository (needed for finetune.py)
if [ -d "${QWEN_DIR}/.git" ]; then
  echo "[INFO] Updating ${QWEN_DIR}"
  git -C "${QWEN_DIR}" pull --ff-only || true
else
  echo "[INFO] Cloning Qwen repo to ${QWEN_DIR}"
  git clone https://github.com/QwenLM/Qwen.git "${QWEN_DIR}"
fi

# 10) Verification: check that torch / DeepSpeed / CUDA are visible
python - <<'PY'
import os, torch, deepspeed, subprocess
print("Torch:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())
print("CUDA device count:", torch.cuda.device_count())
print("DeepSpeed:", deepspeed.__version__)
print("CUDA_HOME:", os.environ.get("CUDA_HOME"))
nvcc = os.path.join(os.environ.get("CUDA_HOME",""), "bin", "nvcc")
print("Has nvcc:", os.path.exists(nvcc))
try:
    if os.path.exists(nvcc):
        subprocess.run([nvcc, "--version"], check=False)
except Exception as e:
    print("nvcc check error:", e)
PY

echo "[OK] Environment '${ENV_NAME}' is ready."
echo "Next: conda activate ${ENV_NAME}"
SH

bash setup_qwen_h100_one_shot.sh
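
Note that the script defaults to PY_VER=3.10 even though Python 3.11 is listed above. The script reads ENV_NAME and PY_VER from the environment, so if you want the env to match that note you can override them when launching:

PY_VER=3.11 ENV_NAME=qwen_sft bash setup_qwen_h100_one_shot.sh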

2. HARI

2.1. Reserve a node

srun --partition P10 --nodes=1 --gpus-per-node=8 --cpus-per-task=240 --time=48:00:00 --pty bash -i

2.2. Activate the conda environment

conda activate qwen_sft

2.3. Script

cat > finetune_hf_sft.py <<'PY'
# finetune_hf_sft.py
# - Optional CoT training (answer in any format; <think> attached optionally)
# - KV self-distillation (KV / hidden states, layer selection, token striding); teacher pass possible with a single model load
# - Multiple HF datasets / amount selection (repeatable --hf_source with name[:split[:amount]])
# - Dataset name & count appended to the output folder name / run settings saved (run_config.json, cmd.txt)
# - Supports gated HF models/datasets (--hf_token / --hf_data_token / env vars)
# - Prints the save path on success
# - Compatible with the Transformers Trainer (handles num_items_in_batch)
# - Qwen3-style ChatML (no system prompt is prepended)

from dataclasses import dataclass, field, asdict
import json, os, re, random, sys
from typing import Dict, Optional, List, Any, Tuple
import copy

import torch
import torch.nn.functional as F
from torch.utils.data import Dataset
from contextlib import nullcontext

import transformers
from transformers import Trainer, BitsAndBytesConfig, HfArgumentParser
from transformers.integrations import is_deepspeed_zero3_enabled
from transformers.trainer_pt_utils import LabelSmoother

from accelerate.utils import DistributedType
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

from datasets import load_dataset

# ---------------- Consts / utils ----------------
IGNORE_TOKEN_ID = LabelSmoother.ignore_index
local_rank = None
def rank0_print(*a):
    if local_rank in (0, None):
        print(*a)

# ---------------- Args ----------------
@dataclass
class ModelArguments:
    model_name_or_path: Optional[str] = field(default="Qwen/Qwen-7B")
    hf_token: Optional[str] = field(default=None, metadata={"help":"HF auth token for models (optional). If omitted, env or cached login is used."})

@dataclass
class DataArguments:
    data_path: str = field(default=None, metadata={"help":"Local json/jsonl path (optional)."})
    eval_data_path: str = field(default=None)
    lazy_preprocess: bool = False
    # HF datasets:
    hf_dataset: Optional[str] = field(default=None, metadata={"help":"Legacy single dataset name (use --hf_source instead)."})
    hf_split: str = field(default="train")
    hf_eval_split: Optional[str] = field(default=None)
    hf_source: List[str] = field(default_factory=list, metadata={"help":"Repeatable: name[:split[:amount]] e.g. oNo-1/MedMCQA:train:all"})
    hf_data_token: Optional[str] = field(default=None, metadata={"help":"HF auth token for datasets (optional)"})
    # formatting / filtering
    use_cot: bool = field(default=True)
    cot_required: bool = field(default=True)
    answer_letters: str = field(default="ABCDE")  # kept for compatibility (unused)
    sample_shuffle_seed: int = 42
    append_data_tag: bool = True  # append dataset tag to output_dir

@dataclass
class TrainingArguments(transformers.TrainingArguments):
    cache_dir: Optional[str] = field(default=None)
    optim: str = field(default="adamw_torch")
    model_max_length: int = field(default=16384)
    use_lora: bool = False
    interactive: bool = False

@dataclass
class LoraArguments:
    lora_r: int = 64
    lora_alpha: int = 16
    lora_dropout: float = 0.05
    lora_target_modules: List[str] = field(default_factory=lambda: ["q_proj","k_proj","v_proj","o_proj","gate_proj","up_proj","down_proj"])
    lora_weight_path: str = ""
    lora_bias: str = "none"
    q_lora: bool = False
    # bitsandbytes
    bnb_4bit_quant_type: str = "nf4"
    bnb_4bit_compute_dtype: str = "bfloat16"
    bnb_4bit_use_double_quant: bool = True

@dataclass
class DistillArguments:
    kv_sd: bool = False
    kv_sd_alpha: float = 0.1
    kv_sd_token_stride: int = 4
    kv_sd_layers: str = "all"          # "all" | "every2" | "last4"
    kv_sd_quantize_teacher: bool = True
    # Single load (share the student as the teacher: forward with LoRA temporarily disabled)
    kv_sd_share_student: bool = True
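    # Usage sketch (illustrative): HfArgumentParser exposes these fields as CLI flags, so
    # the HARI run in section 2.4 turns this block on with
    #   --kv_sd --kv_sd_alpha 0.9
    # and leaves kv_sd_layers="all", kv_sd_token_stride=4 and kv_sd_share_student=True at
    # their defaults (single model load; adapters are disabled for the teacher pass).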

# ---------------- HF auth helpers ----------------
def _get_hf_token(model_args: ModelArguments, data_args: DataArguments) -> Optional[str]:
    tok = data_args.hf_data_token or model_args.hf_token
    if not tok:
        tok = os.environ.get("HUGGING_FACE_HUB_TOKEN") or os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACEHUB_API_TOKEN")
    return tok

def _hub_kwargs(hf_token: Optional[str]) -> dict:
    if hf_token:
        # NOTE: use_auth_token is removed in Transformers v5; pass only token.
        return {"token": hf_token}
    return {}

# ---------------- IO ----------------
def _read_json_or_jsonl(path):
    if path and path.endswith(".jsonl"):
        rows = []
        with open(path, "r", encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if line:
                    rows.append(json.loads(line))
        return rows
    elif path:
        return json.load(open(path, "r", encoding="utf-8"))
    return None

# ---------------- SFT format ----------------
def _detect_and_to_conversations(ex: Dict[str, Any]) -> List[Dict[str, str]]:
    # Use an existing conversation format as-is if present
    if "conversations" in ex and isinstance(ex["conversations"], list):
        return ex["conversations"]
    if "messages" in ex and isinstance(ex["messages"], list):
        conv = []
        for m in ex["messages"]:
            r = m.get("role"); c = m.get("content", "")
            if r == "user":
                conv.append({"from":"user","value":c})
            elif r == "assistant":
                conv.append({"from":"assistant","value":c})
        return conv
    if "prompt" in ex and "response" in ex:
        return [{"from":"user","value":ex["prompt"]},{"from":"assistant","value":ex["response"]}]
    # Otherwise build directly from question/answer/cot/options
    if "question" in ex or "answer" in ex or "cot" in ex or "options" in ex:
        q   = (ex.get("question") or "").strip()
        ans = (ex.get("answer")   or "").strip()  # any format is fine
        cot = (ex.get("cot")      or "").strip()
        user = q
        if "options" in ex and ex["options"]:
            raw = ex["options"]
            if isinstance(raw, list):
                opts_txt = []
                for o in raw:
                    opts_txt.append(o.get("text", o) if isinstance(o, dict) else str(o))
                user = (q + ("\n" if q and opts_txt else "") + "\n".join(opts_txt)).strip()
        # ---- Fix: CoT first, answer afterwards ----
        assistant = (((f"<think>\n{cot}\n</think>\n\n") if cot else "") + (ans or "")).strip()
        return [{"from":"user","value":user},{"from":"assistant","value":assistant}]
    raise KeyError("Unsupported sample format; need conversations/messages or prompt/response or question/answer/cot/options.")

def preprocess(sources, tokenizer: transformers.PreTrainedTokenizer, max_len: int, system_message=None)->Dict:
    """
    Qwen-3流のChatML整形:
    - systemは前置しない
    - <|im_start|>role \\n content <|im_end|> を user/assistant 交互に並べる
    """
    im_start = getattr(tokenizer, "im_start_id", None) or tokenizer.convert_tokens_to_ids("<|im_start|>")
    im_end   = getattr(tokenizer, "im_end_id", None)   or tokenizer.convert_tokens_to_ids("<|im_end|>")
    roles = {"user":"<|im_start|>user","assistant":"<|im_start|>assistant"}
    nl = tokenizer("\n").input_ids

    input_ids, targets = [], []
    for raw in sources:
        source = _detect_and_to_conversations(raw)
        if not source or source[0].get("from") != "user":
            source = [s for s in source if s.get("from") in ("user","assistant")]
            if source and source[0].get("from") != "user":
                source = source[1:]

        ids, tgt = [], []  # no system prompt is inserted at all

        for sent in source:
            role_tok = roles[sent["from"]]
            _ids = tokenizer(role_tok).input_ids + nl + tokenizer(sent["value"]).input_ids + [im_end] + nl
            ids += _ids
            if role_tok == "<|im_start|>user":
                _tgt = [im_start] + [IGNORE_TOKEN_ID]*(len(_ids)-3) + [im_end] + nl
            else:
                pref = tokenizer(role_tok).input_ids
                _tgt = [im_start] + [IGNORE_TOKEN_ID]*len(pref) + _ids[len(pref)+1:-2] + [im_end] + nl
            tgt += _tgt

        pad_id = tokenizer.pad_token_id
        ids = (ids + [pad_id]*(max_len - len(ids)))[:max_len]
        tgt = (tgt + [IGNORE_TOKEN_ID]*(max_len - len(tgt)))[:max_len]
        input_ids.append(ids); targets.append(tgt)

    input_ids = torch.tensor(input_ids, dtype=torch.int)
    targets   = torch.tensor(targets,   dtype=torch.int)
    return dict(input_ids=input_ids, labels=targets, attention_mask=input_ids.ne(tokenizer.pad_token_id))
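
# Schematic of what preprocess produces for one (user, assistant) pair:
#   <|im_start|>user\n{question}<|im_end|>\n
#   <|im_start|>assistant\n<think>\n{cot}\n</think>\n\n{answer}<|im_end|>\n
# In the labels, the user-side content and role tokens are replaced with
# IGNORE_TOKEN_ID, so only the assistant content (plus the ChatML delimiters)
# contributes to the cross-entropy loss.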

# ---------------- Datasets ----------------
class SupervisedDataset(Dataset):
    def __init__(self, rows, tokenizer, max_len):
        rank0_print("Formatting inputs...")
        # keep raw for later saving (does not affect training)
        self.raw = rows
        d = preprocess(rows, tokenizer, max_len)
        self.input_ids=d["input_ids"]; self.labels=d["labels"]; self.attention_mask=d["attention_mask"]
    def __len__(self): return len(self.input_ids)
    def __getitem__(self, i):
        return dict(input_ids=self.input_ids[i], labels=self.labels[i], attention_mask=self.attention_mask[i])

class LazySupervisedDataset(Dataset):
    def __init__(self, rows, tokenizer, max_len):
        self.tokenizer=tokenizer; self.max_len=max_len; self.raw=rows; self.cached={}
        rank0_print("Formatting inputs...Skip in lazy mode")
    def __len__(self): return len(self.raw)
    def __getitem__(self, i):
        if i in self.cached: return self.cached[i]
        d = preprocess([self.raw[i]], self.tokenizer, self.max_len)
        out = dict(input_ids=d["input_ids"][0], labels=d["labels"][0], attention_mask=d["attention_mask"][0])
        self.cached[i]=out; return out

def _dataset_module_from_rows(tokenizer, data_args, max_len, train_rows, eval_rows=None):
    cls = LazySupervisedDataset if data_args.lazy_preprocess else SupervisedDataset
    return dict(train_dataset=cls(train_rows, tokenizer, max_len),
                eval_dataset=cls(eval_rows, tokenizer, max_len) if eval_rows else None)

def _dataset_module_from_files(tokenizer, data_args, max_len):
    cls = LazySupervisedDataset if data_args.lazy_preprocess else SupervisedDataset
    rank0_print("Loading data from file(s)...")
    train_rows = _read_json_or_jsonl(data_args.data_path)
    eval_rows  = _read_json_or_jsonl(data_args.eval_data_path) if data_args.eval_data_path else None
    return dict(train_dataset=cls(train_rows, tokenizer, max_len),
                eval_dataset=cls(eval_rows, tokenizer, max_len) if eval_rows else None)

# ---------------- HF -> conversations ----------------
def to_conversations_from_hf_row(ex: Dict[str, Any], letters: str="ABCDE", use_cot: bool=True, cot_required: bool=True):
    """
    可能なら 'conversations'/'messages'/'prompt&response' をそのまま使う。
    それ以外は question/answer/cot/options から“固定プロンプトなし”で素直に組み立てる。
    answer は任意形式で許容(A〜Eなどの強制はしない)。
    """
    # Default handling (use as-is)
    if "conversations" in ex and isinstance(ex["conversations"], list):
        return {"conversations": ex["conversations"]}
    if "messages" in ex and isinstance(ex["messages"], list):
        conv = []
        for m in ex["messages"]:
            r = m.get("role"); c = m.get("content", "")
            if r == "user":
                conv.append({"from":"user","value":c})
            elif r == "assistant":
                conv.append({"from":"assistant","value":c})
        return {"conversations": conv}
    if "prompt" in ex and "response" in ex:
        return {"conversations":[
            {"from":"user","value": (ex.get("prompt") or "").strip()},
            {"from":"assistant","value": (ex.get("response") or "").strip()},
        ]}

    # question/answer/cot/options
    q   = (ex.get("question") or "").strip()
    ans = (ex.get("answer")   or "").strip()  # any format is fine
    cot = (ex.get("cot")      or "").strip()

    if cot_required and not cot:
        return None

    user = q
    if "options" in ex and ex["options"]:
        raw = ex["options"]
        if isinstance(raw, list):
            opts_txt = []
            for o in raw:
                opts_txt.append(o.get("text", o) if isinstance(o, dict) else str(o))
            user = (q + ("\n" if q and opts_txt else "") + "\n".join(opts_txt)).strip()

    # ---- Fix: CoT first, answer afterwards (when use_cot is True) ----
    assistant = ans
    if use_cot and cot:
        assistant = (f"<think>\n{cot}\n</think>\n\n{ans}".strip()) if ans else f"<think>\n{cot}\n</think>"
    elif (not use_cot) and cot and (not ans):
        # Also allow the case where only the CoT serves as the training signal
        assistant = cot

    return {"conversations":[{"from":"user","value":user},{"from":"assistant","value":assistant}]}

def load_and_convert_hf(dataset_name: str, split: str, letters: str, use_cot: bool, cot_required: bool, hf_token: Optional[str])->List[Dict[str, Any]]:
    rank0_print(f"Loading HF dataset: {dataset_name} [{split}]")
    ds = load_dataset(dataset_name, split=split, **_hub_kwargs(hf_token))
    out = []; skip = 0
    for i, ex in enumerate(ds):
        item = to_conversations_from_hf_row(ex, letters=letters, use_cot=use_cot, cot_required=cot_required)
        if item is None:
            skip += 1
            continue
        # Attach non-intrusive metadata columns for reproducibility
        meta_item = dict(item)  # shallow copy; 'conversations' stays the same
        meta_item["_source"] = {"hf_dataset": dataset_name, "split": split, "index": i}
        # Keep original row minimally for traceability (does not affect training)
        meta_item["_orig"] = ex
        out.append(meta_item)
    rank0_print(f"[HF] converted {len(out)} samples (skipped={skip})")
    return out

# --------- multi-source parsing & sampling ----------
def _parse_hf_source(spec: str) -> Tuple[str, str, Optional[str]]:
    parts = spec.split(":")
    name = parts[0].strip()
    split = parts[1].strip() if len(parts) >= 2 and parts[1].strip() else "train"
    amount = parts[2].strip() if len(parts) >= 3 and parts[2].strip() else None
    return name, split, amount

def _parse_amount(amount_str: Optional[str], total: int) -> int:
    if amount_str is None or amount_str.lower() == "all":
        return total
    try:
        if "." in amount_str or amount_str.startswith("0"):
            frac = float(amount_str)
            if frac <= 0: return 0
            if frac <= 1.0: return max(1, int(total * frac))
        n = int(float(amount_str))
        return max(0, min(total, n))
    except Exception:
        return total
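
# Examples (illustrative):
#   _parse_hf_source("oNo-1/MedMCQA:train:10") -> ("oNo-1/MedMCQA", "train", "10")
#   _parse_hf_source("oNo-1/MedMCQA")          -> ("oNo-1/MedMCQA", "train", None)
#   _parse_amount("10", 1000)  -> 10    (absolute count)
#   _parse_amount("0.1", 1000) -> 100   (fraction of the available rows)
#   _parse_amount("all", 1000) -> 1000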

def _sanitize_name_for_path(name: str) -> str:
    return re.sub(r"[^A-Za-z0-9_.-]+", "_", name.replace("/", "__"))

def _sample_rows(rows: List[Dict[str, Any]], n: int, seed: int) -> List[Dict[str, Any]]:
    if n >= len(rows): return rows
    rnd = random.Random(seed)
    idx = list(range(len(rows))); rnd.shuffle(idx); idx = idx[:n]
    return [rows[i] for i in idx]

# --------- LoRA utils & save ----------
def _to_dtype(name: str):
    n = (name or "").lower()
    return {
        "bfloat16": torch.bfloat16, "bf16": torch.bfloat16,
        "float16": torch.float16, "fp16": torch.float16, "half": torch.float16,
        "float32": torch.float32, "fp32": torch.float32
    }.get(n, torch.bfloat16)

def _auto_targets(model_name: str, current: List[str])->List[str]:
    m = (model_name or "").lower()
    if any(k in m for k in ["qwen","llama","yi","mistral","mixtral","qwen2","qwen3","phi-3","glm"]):
        return ["q_proj","k_proj","v_proj","o_proj","gate_proj","up_proj","down_proj"]
    return current or ["q_proj","k_proj","v_proj","o_proj"]

def maybe_zero_3(param):
    from deepspeed import zero
    from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
    if hasattr(param, "ds_id"):
        assert param.ds_status == ZeroParamStatus.NOT_AVAILABLE
        with zero.GatheredParameters([param]):
            param = param.data.detach().cpu().clone()
    else:
        param = param.detach().cpu().clone()
    return param

def get_peft_state_maybe_zero_3(named_params, bias):
    if bias == "none":
        to_return = {k: t for k, t in named_params if "lora_" in k}
    elif bias == "all":
        to_return = {k: t for k, t in named_params if "lora_" in k or "bias" in k}
    elif bias == "lora_only":
        to_return = {}
        maybe_lora_bias = {}
        lora_bias_names = set()
        for k, t in named_params:
            if "lora_" in k:
                to_return[k] = t
                lora_bias_names.add(k.split("lora_")[0] + "bias")
            elif "bias" in k:
                maybe_lora_bias[k] = t
        for k, t in maybe_lora_bias.items():
            if k in lora_bias_names:
                to_return[k] = t
    else:
        raise NotImplementedError
    to_return = {k: maybe_zero_3(v) for k, v in to_return.items()}
    return to_return

def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, output_dir: str, bias="none"):
    if is_deepspeed_zero3_enabled():
        state_dict = trainer.model_wrapped._zero3_consolidated_16bit_state_dict()
    else:
        if getattr(trainer.args, "use_lora", False):
            state_dict = get_peft_state_maybe_zero_3(trainer.model.named_parameters(), bias)
        else:
            state_dict = trainer.model.state_dict()
    if trainer.args.should_save and (trainer.args.local_rank in (0, -1)):
        trainer._save(output_dir, state_dict=state_dict)

# ---------------- KV self-distill utils ----------------
def _collect_kv_outputs(model, inputs, token_stride=1):
    kv = {"k": [], "v": []}
    handles = []
    def make_hook(kind):
        def _hook(_m, _inp, out):
            t = out
            if token_stride > 1:
                t = t[:, ::token_stride, :]
            kv[kind].append(t)
        return _hook
    for name, module in model.named_modules():
        if name.endswith("k_proj") and hasattr(module, "register_forward_hook"):
            handles.append(module.register_forward_hook(make_hook("k")))
        elif name.endswith("v_proj") and hasattr(module, "register_forward_hook"):
            handles.append(module.register_forward_hook(make_hook("v")))
    outputs = model(**inputs)
    for h in handles: h.remove()
    found = len(kv["k"]) > 0 and len(kv["v"]) > 0
    return outputs, kv, found
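
# _collect_kv_outputs runs one forward pass while forward hooks capture the per-layer
# k_proj / v_proj outputs (optionally keeping every `token_stride`-th token to save
# memory). The same helper serves both the student pass and the teacher pass
# (adapter-disabled shared model, or a separately loaded teacher), so the two KV lists
# line up layer by layer for the MSE computed in _kv_mse below.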

def _hidden_states_distill(student_out, teacher_out, layers_mode="last4", token_stride=1, ref_device=None):
    if not (hasattr(student_out, "hidden_states") and hasattr(teacher_out, "hidden_states")
            and student_out.hidden_states and teacher_out.hidden_states):
        dev = ref_device or student_out.logits.device
        return torch.tensor(0.0, device=dev)
    s_hs = student_out.hidden_states; t_hs = teacher_out.hidden_states
    if layers_mode == "last4": s_sel = s_hs[-4:]; t_sel = t_hs[-4:]
    elif layers_mode == "every2": s_sel = s_hs[::2]; t_sel = t_hs[::2]
    else: s_sel = s_hs; t_sel = t_hs
    dev = ref_device or student_out.logits.device
    loss = None; n = 0
    for s, t in zip(s_sel, t_sel):
        if token_stride > 1: s = s[:, ::token_stride, :]; t = t[:, ::token_stride, :]
        d = min(s.size(-1), t.size(-1))
        s = s.to(dev); t = t.to(dev)
        mse = F.mse_loss(s[..., :d], t[..., :d])
        loss = mse if loss is None else (loss + mse)
        n += 1
    if n == 0:
        return torch.tensor(0.0, device=dev)
    return loss / n

def _kv_mse(student_kv, teacher_kv, layers_mode="all", ref_device=None):
    def _select(seq, mode):
        if mode == "every2": return seq[::2]
        if mode == "last4":  return seq[-4:]
        return seq
    dev = ref_device or (student_kv["k"][0].device if student_kv["k"] else student_kv["v"][0].device)
    loss = None; n = 0
    s_k = _select(student_kv["k"], layers_mode); s_v = _select(student_kv["v"], layers_mode)
    t_k = _select(teacher_kv["k"], layers_mode); t_v = _select(teacher_kv["v"], layers_mode)
    for sk, tk in zip(s_k, t_k):
        d = min(sk.size(-1), tk.size(-1))
        sk = sk.to(dev); tk = tk.to(dev)
        mse = F.mse_loss(sk[..., :d], tk[..., :d])
        loss = mse if loss is None else (loss + mse)
        n += 1
    for sv, tv in zip(s_v, t_v):
        d = min(sv.size(-1), tv.size(-1))
        sv = sv.to(dev); tv = tv.to(dev)
        mse = F.mse_loss(sv[..., :d], tv[..., :d])
        loss = mse if loss is None else (loss + mse)
        n += 1
    if n == 0:
        return torch.tensor(0.0, device=dev)
    return loss / n

class KVSDTrainer(transformers.Trainer):
    def __init__(self, *args, teacher_model=None, kv_sd_alpha=0.1, kv_sd_layers="all", kv_sd_token_stride=1, kv_sd_share_student=True, **kwargs):
        super().__init__(*args, **kwargs)
        self.teacher_model = teacher_model
        self.kv_sd_alpha = kv_sd_alpha
        self.kv_sd_layers = kv_sd_layers
        self.kv_sd_token_stride = kv_sd_token_stride
        self.kv_sd_share_student = kv_sd_share_student
        if self.teacher_model is not None:
            self.teacher_model.eval()
            for p in self.teacher_model.parameters():
                p.requires_grad = False

    def _teacher_forward_shared(self, model, inputs):
        # Take the "teacher" pass with LoRA/adapters temporarily disabled (single model load)
        try:
            ctx = model.disable_adapter()
        except AttributeError:
            ctx = nullcontext()
        with torch.no_grad():
            with ctx:
                t_outputs, t_kv, t_has_kv = _collect_kv_outputs(
                    model, {**inputs, "output_hidden_states": True},
                    token_stride=self.kv_sd_token_stride
                )
        return t_outputs, t_kv, t_has_kv

    # Handle num_items_in_batch, which Transformers >= 4.44 passes in
    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None, **kwargs):
        # 1) Teacher pass
        if self.teacher_model is not None:
            with torch.no_grad():
                t_outputs, t_kv, t_has_kv = _collect_kv_outputs(
                    self.teacher_model, {**inputs, "output_hidden_states": True},
                    token_stride=self.kv_sd_token_stride
                )
        elif self.kv_sd_share_student:
            t_outputs, t_kv, t_has_kv = self._teacher_forward_shared(model, {**inputs})
        else:
            t_outputs, t_kv, t_has_kv = None, None, False

        # 2) Student (training) pass
        outputs, s_kv, s_has_kv = _collect_kv_outputs(
            model, {**inputs, "output_hidden_states": True},
            token_stride=self.kv_sd_token_stride
        )
        ce_loss = outputs.loss
        ref_dev = ce_loss.device  # gather all loss terms on this device

        # 3) Distillation loss
        distill_loss = torch.tensor(0.0, device=ref_dev)
        if t_outputs is not None:
            if s_has_kv and t_has_kv:
                distill_loss = _kv_mse(s_kv, t_kv, layers_mode=self.kv_sd_layers, ref_device=ref_dev)
            else:
                mode = "last4" if self.kv_sd_layers=="last4" else ("every2" if self.kv_sd_layers=="every2" else "all")
                distill_loss = _hidden_states_distill(outputs, t_outputs, layers_mode=mode, token_stride=self.kv_sd_token_stride, ref_device=ref_dev)

        loss = ce_loss + (self.kv_sd_alpha * distill_loss)
        return (loss, outputs) if return_outputs else loss
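
    # Overall objective: loss = CE + kv_sd_alpha * distill_loss. With the HARI setting
    # (--kv_sd_alpha 0.9) the KV / hidden-state self-distillation term is weighted at 0.9,
    # acting as a forgetting-suppression penalty toward the frozen base model on top of
    # the usual SFT cross-entropy.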

# ---------------- build model/tokenizer ----------------
def make_model_and_tokenizer(model_args, training_args, lora_args, hf_token):
    device_map = None if (int(os.environ.get("WORLD_SIZE", 1)) != 1 or is_deepspeed_zero3_enabled()) else "auto"
    config = transformers.AutoConfig.from_pretrained(
        model_args.model_name_or_path, cache_dir=training_args.cache_dir, trust_remote_code=True, **_hub_kwargs(hf_token)
    )
    config.use_cache = False
    bnb_config = None
    if training_args.use_lora and lora_args.q_lora:
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type=lora_args.bnb_4bit_quant_type,
            bnb_4bit_compute_dtype=_to_dtype(lora_args.bnb_4bit_compute_dtype),
            bnb_4bit_use_double_quant=lora_args.bnb_4bit_use_double_quant,
        )
    model = transformers.AutoModelForCausalLM.from_pretrained(
        model_args.model_name_or_path,
        config=config,
        cache_dir=training_args.cache_dir,
        device_map=device_map,
        trust_remote_code=True,
        quantization_config=bnb_config,
        low_cpu_mem_usage=not is_deepspeed_zero3_enabled(),
        **_hub_kwargs(hf_token),
    )
    tokenizer = transformers.AutoTokenizer.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=training_args.cache_dir,
        model_max_length=training_args.model_max_length,
        padding_side="right",
        use_fast=False,
        trust_remote_code=True,
        **_hub_kwargs(hf_token),
    )
    if getattr(tokenizer, "pad_token_id", None) is None:
        if hasattr(tokenizer, "eod_id") and tokenizer.eod_id is not None:
            tokenizer.pad_token_id = tokenizer.eod_id
        else:
            tokenizer.add_special_tokens({"pad_token":"<|pad|>"})
            tokenizer.pad_token_id = tokenizer.convert_tokens_to_ids("<|pad|>")
            model.resize_token_embeddings(len(tokenizer))
    return model, tokenizer, device_map

# ---------------- helper: save used rows ----------------
def _save_used_rows_jsonl(out_dir: str, filename: str, rows: List[Dict[str, Any]], local_path: Optional[str] = None):
    try:
        if not rows:
            return
        path = os.path.join(out_dir, filename)
        with open(path, "w", encoding="utf-8") as f:
            for r in rows:
                obj = r
                # For local data, add non-intrusive source/meta without mutating training objects
                if local_path is not None:
                    obj = copy.deepcopy(r)
                    if "_source" not in obj:
                        obj["_source"] = {"local_path": local_path}
                    if "_orig" not in obj:
                        obj["_orig"] = r
                f.write(json.dumps(obj, ensure_ascii=False) + "\n")
        rank0_print(f"[OK] Wrote: {path}")
    except Exception as e:
        rank0_print(f"[warn] failed to write {filename}: {e}")

# ---------------- train ----------------
def train():
    global local_rank
    parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments, LoraArguments, DistillArguments))
    (model_args, data_args, training_args, lora_args, distill_args) = parser.parse_args_into_dataclasses()

    # HF token
    hf_token = _get_hf_token(model_args, data_args)
    if hf_token and not os.environ.get("HUGGING_FACE_HUB_TOKEN"):
        os.environ["HUGGING_FACE_HUB_TOKEN"] = hf_token

    # deepspeed single-GPU hint
    if getattr(training_args, "deepspeed", None) and int(os.environ.get("WORLD_SIZE", 1)) == 1:
        training_args.distributed_state.distributed_type = DistributedType.DEEPSPEED

    local_rank = training_args.local_rank

    # model/tokenizer
    model, tokenizer, device_map = make_model_and_tokenizer(model_args, training_args, lora_args, hf_token)

    # LoRA
    if training_args.use_lora:
        # "a,b,c" を list へ
        if isinstance(lora_args.lora_target_modules, list) and len(lora_args.lora_target_modules)==1 \
           and isinstance(lora_args.lora_target_modules[0], str) and "," in lora_args.lora_target_modules[0]:
            lora_args.lora_target_modules = [s.strip() for s in lora_args.lora_target_modules[0].split(",") if s.strip()]
        if not lora_args.lora_target_modules:
            lora_args.lora_target_modules = _auto_targets(model_args.model_name_or_path, lora_args.lora_target_modules)
        lora_config = LoraConfig(
            r=lora_args.lora_r, lora_alpha=lora_args.lora_alpha, target_modules=lora_args.lora_target_modules,
            lora_dropout=lora_args.lora_dropout, bias=lora_args.lora_bias, task_type="CAUSAL_LM", modules_to_save=None,
        )
        if lora_args.q_lora:
            model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=training_args.gradient_checkpointing)
        model = get_peft_model(model, lora_config)
        model.print_trainable_parameters()
        if training_args.gradient_checkpointing:
            model.enable_input_require_grads()

    # data module
    used_sources_meta = None
    data_tag = None
    all_rows = None
    eval_rows = None
    if data_args.data_path:
        data_module = _dataset_module_from_files(tokenizer, data_args, training_args.model_max_length)
        # For saving: get raw rows from dataset object
        try:
            all_rows = getattr(data_module["train_dataset"], "raw", None)
            eval_rows = getattr(data_module.get("eval_dataset", None), "raw", None) if data_module.get("eval_dataset", None) else None
        except Exception:
            all_rows = None; eval_rows = None
    else:
        sources = list(data_args.hf_source)
        if (not sources) and data_args.hf_dataset:
            sources = [f"{data_args.hf_dataset}:{data_args.hf_split}:all"]
        if not sources:
            raise ValueError("Please specify either --data_path or at least one --hf_source (or legacy --hf_dataset).")

        all_rows = []
        tag_parts: List[str] = []
        used_sources_meta = []

        for spec in sources:
            name, split, amount_str = _parse_hf_source(spec)
            rows_full = load_and_convert_hf(
                dataset_name=name, split=split,
                letters=data_args.answer_letters, use_cot=data_args.use_cot, cot_required=data_args.cot_required,
                hf_token=hf_token
            )
            total_avail = len(rows_full)
            use_n = _parse_amount(amount_str, total_avail)
            rows_used = _sample_rows(rows_full, use_n, seed=data_args.sample_shuffle_seed)
            all_rows.extend(rows_used)

            tag_parts.append(f"{_sanitize_name_for_path(name)}-{len(rows_used)}")
            used_sources_meta.append({
                "name": name, "split": split,
                "amount_spec": amount_str if amount_str is not None else "all",
                "total_available": total_avail,
                "used_count": len(rows_used),
                "seed": data_args.sample_shuffle_seed
            })

        data_tag = "+".join(tag_parts) if tag_parts else None
        if data_args.append_data_tag and data_tag:
            base_out = training_args.output_dir or "./output_qwen"
            training_args.output_dir = os.path.join(base_out, data_tag)
        os.makedirs(training_args.output_dir or "./output_qwen", exist_ok=True)

        if data_args.hf_eval_split:
            name0, _, _ = _parse_hf_source(sources[0])
            eval_rows = load_and_convert_hf(
                dataset_name=name0, split=data_args.hf_eval_split,
                letters=data_args.answer_letters, use_cot=data_args.use_cot, cot_required=data_args.cot_required,
                hf_token=hf_token
            )
        data_module = _dataset_module_from_rows(tokenizer, data_args, training_args.model_max_length, all_rows, eval_rows)

    # teacher (loaded separately only when not sharing the student)
    teacher_model = None
    if distill_args.kv_sd and not distill_args.kv_sd_share_student:
        teacher_bnb = None
        if distill_args.kv_sd_quantize_teacher:
            teacher_bnb = BitsAndBytesConfig(
                load_in_4bit=True, bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=torch.bfloat16, bnb_4bit_use_double_quant=True
            )
        teacher_model = transformers.AutoModelForCausalLM.from_pretrained(
            model_args.model_name_or_path,
            config=transformers.AutoConfig.from_pretrained(model_args.model_name_or_path, trust_remote_code=True, **_hub_kwargs(hf_token)),
            cache_dir=training_args.cache_dir,
            device_map=(None if (int(os.environ.get("WORLD_SIZE",1))!=1 or is_deepspeed_zero3_enabled()) else "auto"),
            trust_remote_code=True,
            quantization_config=teacher_bnb,
            low_cpu_mem_usage=True,
            **_hub_kwargs(hf_token),
        )
        teacher_model.eval()
        for p in teacher_model.parameters(): p.requires_grad = False

    # save run config snapshot
    out_dir = training_args.output_dir or "./output_qwen"
    os.makedirs(out_dir, exist_ok=True)
    try:
        run_cfg = {
            "argv": sys.argv,
            "model_args": asdict(model_args),
            "data_args": asdict(data_args),
            "training_args": {k: getattr(training_args, k) for k in vars(training_args) if not k.startswith("_")},
            "lora_args": asdict(lora_args),
            "distill_args": asdict(distill_args),
            "resolved": {
                "output_dir": out_dir,
                "used_sources": used_sources_meta,
                "data_tag": data_tag,
                "hf_token_provided": bool(_get_hf_token(model_args, data_args)),
            },
        }
        with open(os.path.join(out_dir, "run_config.json"), "w", encoding="utf-8") as f:
            json.dump(run_cfg, f, ensure_ascii=False, indent=2, default=str)
        with open(os.path.join(out_dir, "cmd.txt"), "w", encoding="utf-8") as f:
            f.write(" ".join(sys.argv) + "\n")
    except Exception as e:
        rank0_print(f"[warn] failed to write run_config.json: {e}")

    # Save the exact used rows (train/eval) into output_dir with metadata columns
    try:
        if all_rows is None:
            # fall back to dataset.raw (works for local/json and lazy mode)
            all_rows = getattr(data_module["train_dataset"], "raw", None)
        if all_rows:
            # if local file, annotate with source path at write time
            _save_used_rows_jsonl(out_dir, "used_train_rows.jsonl", all_rows, local_path=data_args.data_path if data_args.data_path else None)
        if eval_rows:
            _save_used_rows_jsonl(out_dir, "used_eval_rows.jsonl", eval_rows, local_path=data_args.eval_data_path if data_args.eval_data_path else None)
    except Exception as e:
        rank0_print(f"[warn] failed to persist used dataset rows: {e}")

    # Trainer
    trainer_cls = KVSDTrainer if distill_args.kv_sd else Trainer
    trainer = trainer_cls(
        model=model, tokenizer=tokenizer, args=training_args, **data_module,
        **({"teacher_model": teacher_model,
            "kv_sd_alpha": distill_args.kv_sd_alpha,
            "kv_sd_layers": distill_args.kv_sd_layers,
            "kv_sd_token_stride": distill_args.kv_sd_token_stride,
            "kv_sd_share_student": distill_args.kv_sd_share_student} if distill_args.kv_sd else {})
    )

    # Train & Save
    trainer.train()
    trainer.save_state()
    safe_save_model_for_hf_trainer(trainer=trainer, output_dir=training_args.output_dir, bias=lora_args.lora_bias)
    rank0_print(f"[OK] Saved to: {training_args.output_dir}")

if __name__ == "__main__":
    train()
PY

2.4. Run

The key point is kv_sd_alpha 0.9. This is HARI.

export HUGGING_FACE_HUB_TOKEN=

python -u ~/finetune_hf_sft.py \
  --model_name_or_path /nvme34/Qwen/Qwen3-235B-A22B-Thinking-2507 \
  --hf_source oNo-1/MedMCQA:train:10 \
  --output_dir ./runs/qwen235b_medmcqa_lora_attnonly \
  --bf16 True --gradient_checkpointing True \
  --use_lora True --q_lora True \
  --lora_target_modules "q_proj,k_proj,v_proj,o_proj" \
  --lora_r 8 --lora_alpha 16 --lora_dropout 0.1 \
  --bnb_4bit_quant_type nf4 --bnb_4bit_compute_dtype bfloat16 --bnb_4bit_use_double_quant True \
  --optim adamw_bnb_8bit \
  --model_max_length 10000 \
  --per_device_train_batch_size 1 --gradient_accumulation_steps 16 \
  --num_train_epochs 1 \
  --learning_rate 1e-4 \
  --logging_steps 20 --save_strategy "no" \
  --do_eval False \
  --dataloader_num_workers 6 --group_by_length True \
  --lazy_preprocess True \
  --kv_sd --kv_sd_alpha 0.9
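
Because append_data_tag defaults to True, the script appends a dataset tag to --output_dir. With the command above, the artifacts (adapter weights, run_config.json, cmd.txt, used_train_rows.jsonl) should end up under a path like ./runs/qwen235b_medmcqa_lora_attnonly/oNo-1__MedMCQA-10 (the trailing number is the count of rows actually used after the CoT filter).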

Summary

  • The gist of HARI: layer forgetting suppression on top of a small-scale SFT; a lightweight recipe that aims for a stable upside on HLE without disturbing the base model's behavior.

    • Preliminaries: KL regularization (λ=0.9)
    • Finals: KV self-distillation (α=0.9)
  • Finding from the preliminaries: applying 10 MedMCQA samples × 1 epoch + KL=0.9 to DeepSeek-R1-0528 beat the base score in both our internal and the preliminary evaluation.

  • Finals (Qwen3-235B-A22B-Thinking-2507): applying KV self-distillation (α=0.9) for the same purpose raised the score in the same way.
    As controls, plain SFT and the last4-only variant both dropped the score, so α=0.9 is the effective factor.

  • Scaling checks: 500 samples × 1 epoch tends to be particularly effective. Across roughly 30 models and diverse data the score rose without exception, for a final gain of about +2.4%. Not a coincidence, but reproducible.

"Small-scale SFT × forgetting suppression (preliminaries = KL 0.9 / finals = KV 0.9)" yields a stable, low-cost gain even on large models including Qwen3-235B. That is the core of HARI.

This project is conducted as part of the foundation model development project under "Safety Verification and Demonstration toward the Social Implementation of a Japanese Medical-Specialized LLM" by the New Energy and Industrial Technology Development Organization (NEDO).
