
Running GSPO/GRPO on Qwen3-235B-A22B-Thinking-2507 with Axolotl on multiple nodes

Posted at 2025-10-28

Introduction

This article is a practical, end-to-end guide to running Axolotl across multiple nodes and applying GRPO / GSPO to Qwen3-235B-A22B-Thinking-2507.
Generation and the training loop required for online reinforcement learning (RL) are handled by Axolotl×TRL; distribution uses SLURM + DDP.
First, terminology. GRPO (Group Relative Policy Optimization) is an RL method that samples several completions per prompt and derives advantages from the relative rewards within each group.
GSPO (Group Sequence Policy Optimization) is a refinement that moves the importance ratio and clipping to the sequence level. In both theory and implementation, it is easiest to treat it as "GRPO generalized to sequence granularity".
This article builds on GRPO and shows, end to end, how to switch to GSPO with nothing but a config change.
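To make the two ideas above concrete, here is a rough sketch (not TRL's actual implementation; the GSPO ratio follows the commonly cited length-normalized formulation, and the reward numbers are invented):

```python
import math
from statistics import mean, pstdev

def group_relative_advantages(rewards, eps=1e-6):
    """GRPO: normalize rewards within one prompt's group of generations."""
    mu, sigma = mean(rewards), pstdev(rewards)
    return [(r - mu) / (sigma + eps) for r in rewards]

def sequence_importance_ratio(new_logps, old_logps):
    """GSPO-style sequence-level ratio: exp of the length-normalized
    sum of per-token log-prob differences (one ratio per sequence)."""
    diffs = [n - o for n, o in zip(new_logps, old_logps)]
    return math.exp(mean(diffs))

# 4 generations for one prompt; one got the answer right
rewards = [1.0, 0.1, 0.1, 0.0]
advs = group_relative_advantages(rewards)
print([round(a, 2) for a in advs])  # the correct completion gets the largest advantage

# token-level (GRPO) keeps one ratio per token; sequence-level (GSPO) collapses to one
print(sequence_importance_ratio([-1.0, -2.0], [-1.1, -2.1]))
```

In GRPO the clipped PPO-style objective is applied per token; GSPO applies clipping to this single per-sequence ratio instead.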

1. 環境構築

Create the following conda environment.

# ==== 0) Remove any existing env ====
conda deactivate 2>/dev/null || true
rm -rf "/home/Competition2025/P10/P10U001/.conda/envs/axo-fsdp"
conda env remove -n axo-fsdp -y || true
conda clean -afy || true
pip cache purge || true

# ==== 1) Conda environment ====
conda create -y -n axo-fsdp python=3.11 cmake ninja pkg-config pip git
eval "$($(conda info --base)/bin/conda shell.bash hook)"
conda activate axo-fsdp

# ==== 2) Pin PyTorch 2.7.1 + cu128 ====
pip install -U --extra-index-url https://download.pytorch.org/whl/cu128 \
  "torch==2.7.1+cu128" "torchvision==0.22.1+cu128" "torchaudio==2.7.1+cu128"

# ==== 3) Install the CUDA Toolkit (including NVCC) via conda ====
#  - v12.8, matching the Torch build (cu128)
#  - includes the required headers/libraries and nvcc
conda install -y -c nvidia -c conda-forge \
  "cuda-toolkit=12.8" "cuda-nvcc=12.8"

# Recommended: set these explicitly (helps the DeepSpeed build find CUDA)
export CUDA_HOME="$CONDA_PREFIX"
export CUDA_PATH="$CONDA_PREFIX"
export PATH="$CONDA_PREFIX/bin:$PATH"
export LD_LIBRARY_PATH="$CONDA_PREFIX/lib:$CONDA_PREFIX/lib64:${LD_LIBRARY_PATH:-}"

# (Optional) set the GPU arch if known (e.g. 90a for H100)
# export TORCH_CUDA_ARCH_LIST="90;90a"

# ==== 4) Axolotl (extras pull in flash-attn / deepspeed / vllm at once) ====
pip install -U packaging setuptools wheel ninja

set -e
if ! pip install -U "axolotl[flash-attn,deepspeed,vllm]==0.12.2"; then
  echo "[WARN] extras install failed; falling back to the minimal install."
  pip install -U "axolotl==0.12.2"
  # Try the speedups/DeepSpeed individually (skip if no wheel / build fails)
  pip install -U "flash-attn>=2.6.0" || true
  pip install -U "deepspeed==0.14.4" || true
fi
set +e

# ==== 5) Only the minimal additions that Axolotl does not install ====
# 5-1) For 4bit/QLoRA
pip install -U "bitsandbytes==0.47.0" "sentencepiece>=0.2.0" "einops>=0.7.0" \
               "protobuf>=4.25,<6" "pynvml>=11.5.0" "psutil>=5.9" "safetensors>=0.4.2"

# 5-2) External vLLM API integration (OpenAI-compatible client)
pip install -U "openai>=1.40.0" "httpx>=0.27.0" "aiohttp>=3.9" "orjson>=3.10" "uvloop>=0.19"

# 5-3) Logging/visualization (optional)
pip install -U "wandb>=0.17.0" "tensorboard>=2.16"

pip uninstall -y torchvision

pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@0ee9ee8"

# Consistency check
pip check || true

# ==== 6) Sanity check ====
echo "[nvcc]"; nvcc --version || true
python - <<'PY'
import torch, importlib.metadata as md, sys, os
def v(n):
    try: print(f"{n}: {md.version(n)}")
    except Exception: print(f"{n}: not installed")
print("python:", sys.version.split()[0])
print("CUDA runtime (torch):", torch.version.cuda, "| available:", torch.cuda.is_available(), "| nGPU:", torch.cuda.device_count())
print("CUDA_HOME:", os.getenv("CUDA_HOME"))
for p in ["axolotl","transformers","accelerate","datasets","tokenizers","safetensors","huggingface_hub",
          "peft","trl","bitsandbytes","flash_attn","deepspeed","vllm","openai","wandb"]:
    v(p)
PY

# bitsandbytes CUDA detection (optional)
python -m bitsandbytes | sed -n '1,80p'

2. GRPO/GSPO

2.1. yaml

Changing importance_sampling_level from token to sequence turns GRPO into GSPO.
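For reference, the switch is a one-key change in the trl block of the config (a fragment only, not a full config; everything else stays as in the GRPO setup below):

```yaml
trl:
  # GRPO (default): token-level importance ratio
  # importance_sampling_level: token
  # GSPO: move the ratio and clipping to the sequence level
  importance_sampling_level: sequence
```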

VLLM_HOST=10.1.2.3   # ← 127.0.0.1 is fine for a single node; use the real IP for multi-node!

# NOTE: unquoted heredoc so that ${VLLM_HOST} above is substituted into the YAML
cat > /home/Competition2025/P10/P10U001/work/axolotl/axolotl_qwen3_235b_fsdp.yaml <<EOF
base_model: /home/Competition2025/P10/shareP10/llm_project/models/Qwen/Qwen3-235B-A22B-Thinking-2507_clone_20251004-133708
model_type: AutoModelForCausalLM
tokenizer_type: AutoTokenizer

hub_model_id:
hub_strategy:
push_dataset_to_hub:
hf_use_auth_token: true

# Liger Kernel settings
plugins:
  - axolotl.integrations.liger.LigerPlugin
# liger_cross_entropy: true
liger_rope: true
liger_rms_norm: true
liger_glu_activation: true
liger_fused_linear_cross_entropy: false  # avoid hooks on the generation path
# cut_cross_entropy: true
liger_layer_norm: true

load_in_8bit: false
load_in_4bit: true
strict: false

# disable chat_template (avoid unwanted automatic formatting)
# chat_template: tokenizer_default
chat_template:

# === Minimal fixes (make the token/padding policy explicit) ===
add_bos_token: false            # stop automatic BOS insertion (avoid doubling with the template)
tokenizer_padding_side: left    # left padding is the safe choice for batched generation
tokenizer_pad_token: eos        # align pad_token with eos
# === end of minimal fixes ===

# ===== RL (GRPO mode) settings =====
rl: grpo
trl:
  use_vllm: false
  # vLLM server connection (TRL side)
  vllm_server_host: ${VLLM_HOST}
  vllm_server_port: 8000
  vllm_server_timeout: 300

  # GRPO: token-level importance ratio (token is also the default)
  importance_sampling_level: token

  # Core hyperparameters for the generate & optimize loop (tune as needed)
  num_generations: 4
  generation_batch_size: 4
  num_iterations: 4
  epsilon: 0.2
  epsilon_high: 2.0
  temperature: 0.7
  top_p: 0.95
  # completion length fixed at 256
  max_completion_length: 256

  # Pin the tokenization / generate behavior before generation (the key minimal fix)
  tokenizer_kwargs:
    padding: true
    truncation: true
    add_special_tokens: false   # ← prevent doubled BOS/EOS
  generation_kwargs:
    use_cache: true
    pad_token_id: null          # ← taken from tokenizer / config
    eos_token_id: null          # ← taken from tokenizer / config
    bos_token_id: null          # ← consistent with the no-BOS policy

  # Reward functions to use (function names in the file)
  reward_funcs:
    - "rewards.medmcqa_reward_func"
  reward_weights: [1.0]

  log_completions: true
# ===== end of RL settings =====

# ===== vLLM (external server) =====
vllm:
  host: ${VLLM_HOST}
  port: 8000
  # tensor_parallel_size: 8
  # gpu_memory_utilization: 0.85
  # dtype: auto
# ===== end of vLLM settings =====

# RL data input (references the transform function)
datasets:
  - path: oNo-1/medmcqa_100samples
    split: "train[:30%]"
    type: rewards.medmcqa_transform

dataset_processes: 128
shuffle_merged_datasets: true
dataset_prepared_path: /home/Competition2025/P10/P10U001/data/oNo-1/medmcqa_100samples
# val_set_size: 0.005
output_dir: /home/Competition2025/P10/P10U001/models/oNo-1/medmcqa_100samples_GRPO

sequence_len: 16384
sample_packing: false
eval_sample_packing: false
pad_to_sequence_len: false      # ← false recommended for RL (avoid wasted padding)

adapter: qlora
lora_model_dir:
lora_r: 128
lora_alpha: 256
lora_dropout: 0.05
# lora_target_linear:
lora_target_modules:
  - self_attn.q_proj
  - self_attn.k_proj
  - self_attn.v_proj
  - self_attn.o_proj
  # - down_proj
  # - gate_proj
  # - up_proj
lora_modules_to_save:
  - embed_tokens
  # - lm_head
lora_fan_in_fan_out:

wandb_project: Qwen3-235B-A22B-Thinking-2507
wandb_entity: llm-competition-2025-wandb-01-
wandb_watch:
wandb_name:
wandb_log_model:

gradient_accumulation_steps: 1
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_torch_4bit
lr_scheduler: cosine
cosine_min_lr_ratio: 0.01
learning_rate: 1e-5
max_grad_norm: 1.0

train_on_inputs: false
group_by_length: false
bf16: true
fp16: false
tf32: false
bfloat16: true

# gradient_checkpointing: true
# gradient_checkpointing_kwargs:
#   use_reentrant: true
early_stopping_patience:
auto_resume_from_checkpoints: true
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: false
attn_implementation: eager

# Also set these on the HF model config: avoid SDPA, and default use_cache to True
hf_config:
  attn_implementation: eager
  use_cache: true

save_strategy: steps
save_steps: 1000
save_total_limit: 2
save_only_model: true
# eval_strategy: steps
# eval_steps: 100
# eval_batch_size: 8

warmup_ratio: 0.03
debug:
# deepspeed: /home/Competition2025/P10/P10U001/work/axolotl/deepspeed_configs/zero3_bf16.json
weight_decay: 0.01

fsdp_version: 1
fsdp:
  - full_shard
  - auto_wrap
fsdp_config:
  fsdp_limit_all_gathers: true
  fsdp_sync_module_states: true
  fsdp_offload_params: false
  fsdp_use_orig_params: false
  fsdp_cpu_ram_efficient_loading: true
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_transformer_layer_cls_to_wrap: Qwen3MoeDecoderLayer
  fsdp_state_dict_type: FULL_STATE_DICT
  fsdp_sharding_strategy: FULL_SHARD
  fsdp_activation_checkpointing: true

qlora_sharded_model_loading: true

ddp_timeout: 180000000
lora_on_cpu: false
EOF

2.2. sbatch

# Create axolotl_train_qwen3.sh
cat <<'EOF' > /home/Competition2025/P10/P10U001/work/axolotl/axolotl_train_qwen3.sh
#!/bin/bash

#SBATCH --job-name=axolotl_train
#SBATCH --partition=P10
#SBATCH --nodes=4
#SBATCH --gpus-per-node=8
#SBATCH --gpus-per-task=8
#SBATCH --ntasks-per-node=1
#SBATCH --output=/home/Competition2025/P10/P10U001/slurm/%x_%j.log
#SBATCH --error=/home/Competition2025/P10/P10U001/slurm/%x_%j.log
#SBATCH --wait-all-nodes=1
#SBATCH --mem=0
#SBATCH --cpus-per-task=128

set -euxo pipefail

export TORCH_NCCL_ASYNC_ERROR_HANDLING=1
export NCCL_DEBUG=INFO
export NCCL_P2P_LEVEL=NVL

export NCCL_IB_DISABLE=0
export NCCL_SOCKET_IFNAME="enp25s0np0"
export NCCL_BUFFSIZE=2097152

# Hugging Face settings
export HF_HOME=$HF_HOME
export HF_TOKEN=$HF_TOKEN

# export NCCL_BLOCKING_WAIT=1
# export TORCH_NCCL_BLOCKING_WAIT=1
export AXOLOTL_NCCL_TIMEOUT=7200

# W&B settings
export WANDB_API_KEY=$WANDB_API_KEY
export WANDB_ENTITY="aratako-lm"
#export WANDB_PROJECT="axolotl_training"

# Set environment variables so every task can find the master node
export MASTER_ADDR=$(scontrol show hostnames $SLURM_NODELIST | head -n 1)
export MASTER_PORT=$(( 29000 + SLURM_JOB_ID % 1000 ))

# Path to the config file (local or URL)
export CONFIG_PATH=/home/Competition2025/P10/P10U001/work/axolotl/axolotl_qwen3_235b_fsdp.yaml
export PYTHONPATH="/home/Competition2025/P10/P10U001/work/axolotl:${PYTHONPATH:-}"

export PYTORCH_SDP_DISABLE_FLASH_ATTENTION=1
export PYTORCH_SDP_DISABLE_MEM_EFFICIENT=1
export PYTORCH_SDP_FORCE_FMA=1

export TRANSFORMERS_NO_ADVISORY_WARNINGS=1

# conda in the P10 environment
set +u
source ~/.bashrc || true
# or: source /opt/conda/etc/profile.d/conda.sh
conda activate axo-fsdp
set -u

cd /home/Competition2025/P10/P10U001/work/axolotl

# srun --nodes=${SLURM_JOB_NUM_NODES} --ntasks-per-node=1 \
# bash -c '
# torchrun \
# --nnodes '"${SLURM_JOB_NUM_NODES}"' \
# --nproc_per_node '"${SLURM_GPUS_PER_NODE:-8}"' \
# --node_rank $SLURM_NODEID \
# --rdzv_backend c10d \
# --rdzv_id '"${SLURM_JOB_ID}"' \
# --rdzv_endpoint '"${MASTER_ADDR}:${MASTER_PORT}"' \
# -m axolotl.cli.train '"${CONFIG_PATH}"' \
# --deepspeed /home/Competition2025/P10/P10U001/work/axolotl/deepspeed_configs/zero3_bf16.json
# '

echo "Host: $(hostname)"
echo "Current user limits:"
ulimit -a
echo "Shared memory size:"
df -h /dev/shm

ulimit -v unlimited
ulimit -m unlimited

ulimit -a

srun -l --export=ALL \
bash -c '
set -eux
torchrun \
--nnodes=${SLURM_JOB_NUM_NODES} \
--nproc_per_node=${SLURM_GPUS_PER_NODE:-8} \
--node_rank=${SLURM_NODEID} \
--rdzv_backend=c10d \
--rdzv_id=${SLURM_JOB_ID} \
--rdzv_endpoint=${MASTER_ADDR}:${MASTER_PORT} \
-m axolotl.cli.train ${CONFIG_PATH}
'
EOF

2.3. Reward function

cat > /home/Competition2025/P10/P10U001/work/axolotl/rewards.py <<'PY'
from typing import Any, Iterable, List, Tuple, Dict, Callable, Optional
import re
import logging
import os  # for PID/rank and env toggles

logger = logging.getLogger(__name__)

# How many leading items to log in detail (overridable via env var)
_MAX_DBG_PAIRS = int(os.getenv("REWARD_DBG_PAIRS", "3"))

class MedMCQAMetrics:
    """Tracks MedMCQA evaluation metrics"""
    def __init__(self):
        self.correct = 0
        self.valid_format = 0
        self.total = 0

    def update(self, rewards: List[float]):
        for r in rewards:
            if r == 1.0:
                self.correct += 1
                self.valid_format += 1
            elif r > 0.0:
                self.valid_format += 1
            self.total += 1

    def get_stats(self) -> Dict[str, float]:
        if self.total == 0:
            return {"accuracy": 0.0, "valid_format_rate": 0.0}
        return {
            "accuracy": self.correct / self.total,
            "valid_format_rate": self.valid_format / self.total
        }

# Global metrics instance
metrics = MedMCQAMetrics()


# ===== Utilities: robust answer-letter extraction =====

# Strip typical <think> tags
THINK_TAG_RE = re.compile(r"</?\s*think[^>]*>", re.IGNORECASE)

STRICT_PATTERNS = [
    # "Answer: X" at the start of a line
    re.compile(r"(?im)^\s*answer\s*[:\-]\s*\(?\s*([A-D])\s*\)?(?=[\s\.\)\]]|$)"),
    # "Final Answer: X" at the start of a line
    re.compile(r"(?im)^\s*final\s*answer\s*[:\-]\s*\(?\s*([A-D])\s*\)?(?=[\s\.\)\]]|$)"),
    # Japanese variants
    re.compile(r"(?im)^\s*(?:答え|最終解答)\s*[::\-]\s*\(?\s*([A-D])\s*\)?(?=[\s\.\)\]]|$)"),
    # XML-style tags
    re.compile(r"(?is)<\s*final\s*>\s*(?:answer\s*[:\-]\s*)?\(?\s*([A-D])\s*\)?\s*</\s*final\s*>"),
    re.compile(r"(?is)<\s*answer\s*>\s*\(?\s*([A-D])\s*\)?\s*</\s*answer\s*>"),
]

FALLBACK_PATTERNS = [
    # lines beginning with "B." or "C)" etc.
    re.compile(r"(?im)^\s*([A-D])\s*[\.\)](?=\s|$)"),
    # "The correct answer is B" etc.
    re.compile(r"(?i)\b(correct|final)?\s*answer\s*(is|:)?\s*\(?\s*([A-D])\s*\)?\b"),
]


def _to_str(x: Any) -> str:
    """Get a string regardless of the completion's type"""
    if isinstance(x, str):
        return x
    if isinstance(x, dict):
        for k in ("text", "output_text", "response", "content"):
            if k in x and isinstance(x[k], str):
                return x[k]
        # "choices"-style payload
        ch = x.get("choices")
        if isinstance(ch, list) and ch and isinstance(ch[0], dict):
            for k in ("text", "message", "delta"):
                if k in ch[0]:
                    v = ch[0][k]
                    if isinstance(v, dict):
                        for kk in ("content", "text"):
                            if kk in v and isinstance(v[kk], str):
                                return v[kk]
                    elif isinstance(v, str):
                        return v
    return str(x)


def extract_choice_letter(text: str) -> Optional[str]:
    """
    Extract the answer letter. Also handles a bare "C" etc. at the very start.
    """
    # Remove noise tags
    t = THINK_TAG_RE.sub("", text).strip()

    # First check for a lone letter at the start (the most common case)
    first_line = t.split('\n')[0].strip()
    if len(first_line) == 1 and first_line.upper() in ("A", "B", "C", "D"):
        return first_line.upper()

    # Start like "C." or "D)"
    if len(first_line) <= 3:
        match = re.match(r'^([A-D])[\s\.\)\]\}]?$', first_line, re.IGNORECASE)
        if match:
            return match.group(1).upper()

    # Strict patterns (matching the recommended format)
    for p in STRICT_PATTERNS:
        m = p.search(t)
        if m:
            g = m.group(1) if m.lastindex and m.lastindex >= 1 else m.group(0)
            ch = (g or "").strip().upper()
            if ch in ("A", "B", "C", "D"):
                return ch

    # Fallbacks
    for p in FALLBACK_PATTERNS:
        m = p.search(t)
        if m:
            for gi in range(1, (m.lastindex or 0) + 1):
                g = m.group(gi)
                if g and g.strip().upper() in ("A", "B", "C", "D"):
                    return g.strip().upper()

    return None


def _collect_ground_truths_from_kwargs(**kwargs) -> List[str]:
    """
    Collect ground_truth / answer from examples/batch/samples/data/rows
    and similar containers that may arrive in kwargs (logging-heavy version)
    """
    gts: List[str] = []

    # 1) Collect top-level scalars/arrays
    for sk in ("ground_truth", "answer"):
        sv = kwargs.get(sk)
        if isinstance(sv, str) and sv:
            gts.append(sv)
        elif isinstance(sv, list):
            for x in sv:
                if isinstance(x, str) and x:
                    gts.append(x)

    candidate_keys = [
        "examples", "batch", "batches", "samples", "data", "rows",
        "original_examples", "original_batch", "original_samples",
        "eval_rows", "eval_batch",
    ]
    present = [k for k in kwargs.keys() if isinstance(kwargs.get(k), (list, dict))]
    if present:
        logger.warning(f"[reward] kwargs keys present (container-like): {present}")

    found_count: Dict[str, int] = {}
    def _acc(key: str):
        found_count[key] = found_count.get(key, 0) + 1

    for k in candidate_keys:
        v = kwargs.get(k)
        if isinstance(v, list):
            for ex in v:
                if isinstance(ex, dict):
                    gt = ex.get("ground_truth", ex.get("answer", ""))
                    if isinstance(gt, str) and gt:
                        gts.append(gt); _acc(k)
        elif isinstance(v, dict):
            for vv in v.values():
                if isinstance(vv, list):
                    for ex in vv:
                        if isinstance(ex, dict):
                            gt = ex.get("ground_truth", ex.get("answer", ""))
                            if isinstance(gt, str) and gt:
                                gts.append(gt); _acc(k)

    if found_count:
        logger.warning(f"[DBG] GT sources: {found_count}")
    if gts:
        preview = ", ".join(repr(s) for s in gts[:_MAX_DBG_PAIRS])
        logger.warning(f"[DBG] GT samples (first {_MAX_DBG_PAIRS}): {preview}")

    return gts


def medmcqa_reward_func(
    completions: Iterable[Any],
    ground_truths: Optional[List[str]] = None,
    prompts: Optional[List[str]] = None,
    examples: Optional[List[Dict[str, Any]]] = None,  # added: the full original examples
    **kwargs: Any
) -> List[float]:
    """
    Reward function for MedMCQA
    - Correct option chosen: 1.0
    - Wrong answer but valid format: 0.1
    - Invalid answer format: 0.0
    """
    # ===== Added: ENTRY log =====
    logger.warning(
        "[DBG-ENTRY] pid=%s examples=%s kwargs_keys=%s",
        os.getpid(),
        "None" if examples is None else f"len={len(examples)}",
        list(kwargs.keys())
    )

    rewards: List[float] = []
    completions_list = list(completions)

    # If ground_truths is missing, try to extract it from examples/kwargs (defensive)
    if not ground_truths:
        collected: List[str] = []
        if examples:
            for ex in examples:
                if isinstance(ex, dict):
                    gt = ex.get("ground_truth", ex.get("answer", ""))
                    if isinstance(gt, str) and gt:
                        collected.append(gt)
        collected_kwargs = _collect_ground_truths_from_kwargs(**kwargs)
        if collected_kwargs:
            collected.extend(collected_kwargs)
        if collected:
            ground_truths = collected

    # Debug logging
    if ground_truths:
        logger.info(f"Ground truths available: {len(ground_truths)} items")
        if len(ground_truths) > 0:
            logger.info(f"First ground truth: '{ground_truths[0]}'")
    else:
        logger.warning("No ground truths available!")

    # ===== Added: log after GT reconstruction =====
    logger.warning(
        "[DBG-GT-BUILD] len(gts)=%s first_gt=%s",
        (len(ground_truths) if ground_truths else 0),
        (ground_truths[0] if ground_truths else None),
    )

    # Align the lengths of completions and gts (err on the safe side)
    if ground_truths:
        if len(ground_truths) < len(completions_list):
            logger.warning(f"[reward] gts shorter than completions: {len(ground_truths)} < {len(completions_list)}; padding with ''")
            ground_truths = ground_truths + [""] * (len(completions_list) - len(ground_truths))
        elif len(ground_truths) > len(completions_list):
            logger.warning(f"[reward] gts longer than completions: trimming {len(ground_truths)} -> {len(completions_list)}")
            ground_truths = ground_truths[:len(completions_list)]

    # --- Just before the for loop (emit batch-level info once) ---
    rank = 0
    try:
        import torch.distributed as dist  # type: ignore
        if dist.is_available() and dist.is_initialized():
            rank = dist.get_rank()
    except Exception:
        pass

    # ===== Added: PRE-LOOP log (requested format) =====
    logger.warning(
        "[DBG-PRE-LOOP] pid=%s len(completions_list)=%s len(ground_truths)=%s examples=%s keys=%s first_gt=%s",
        os.getpid(),
        len(completions_list),
        len(ground_truths) if ground_truths else 0,
        "None" if examples is None else f"len={len(examples)}",
        list(kwargs.keys()),
        ground_truths[0] if ground_truths else None,
    )

    # Keep the existing summary log too
    logger.warning(
        "[DBG] rank=%s pid=%s len(completions_list)=%s len(ground_truths)=%s first_gt=%s",
        rank, os.getpid(),
        len(completions_list),
        (len(ground_truths) if ground_truths else 0),
        (ground_truths[0] if ground_truths else None),
    )
    if examples:
        logger.warning("[DBG] examples_len=%d first_example_keys=%s",
                       len(examples), list(examples[0].keys()) if isinstance(examples[0], dict) else type(examples[0]).__name__)
    if prompts is not None:
        try:
            logger.warning("[DBG] prompts_len=%d", len(prompts))
        except Exception:
            logger.warning("[DBG] prompts_type=%s", type(prompts).__name__)

    # --- Per-completion loop (check each index individually) ---
    for i, completion in enumerate(completions_list):
        logger.warning("[DBG] idx=%d / n=%d", i, len(completions_list))

        completion_str = _to_str(completion).strip()

        if not completion_str:
            rewards.append(0.0)
            continue

        predicted_answer = extract_choice_letter(completion_str)

        # No valid answer found
        if not predicted_answer:
            rewards.append(0.0)
            logger.debug(f"[format invalid] completion {i}: {completion_str[:120]}...")
            if i < _MAX_DBG_PAIRS:
                logger.warning("[DBG] format_invalid sample[%d]: first_line=%r", i, completion_str.splitlines()[0][:120] if completion_str else "")
            continue

        # Correctness check
        if ground_truths and i < len(ground_truths):
            correct_answer = (ground_truths[i] or "").strip().upper()
            if i < _MAX_DBG_PAIRS:
                logger.warning("[DBG] pair[%d]: pred=%s gold=%s", i, predicted_answer, correct_answer)
            if correct_answer in ("A", "B", "C", "D"):
                if predicted_answer == correct_answer:
                    rewards.append(1.0)
                    logger.info(f"[CORRECT!] pred={predicted_answer}, gold={correct_answer} @ {i}")
                else:
                    rewards.append(0.1)
                    logger.debug(f"[wrong] pred={predicted_answer}, gold={correct_answer} @ {i}")
            else:
                # If gold is invalid, score format only
                rewards.append(0.1)
                logger.warning(f"[gold invalid] index {i}, gold='{ground_truths[i]}'")
        else:
            # If gold is missing, score format only
            rewards.append(0.1)
            logger.warning(f"[gold missing] completion {i}")

    # Update metrics
    metrics.update(rewards)

    # Periodically log metrics
    if metrics.total % 100 == 0 and metrics.total > 0:
        stats = metrics.get_stats()
        logger.info(
            f"MedMCQA Metrics - Total: {metrics.total}, "
            f"Accuracy: {stats['accuracy']:.3f}, "
            f"Valid Format Rate: {stats['valid_format_rate']:.3f}"
        )

    return rewards


def medmcqa_transform(cfg: Any, *args: Any, **kwargs: Any) -> Tuple[
    Callable[[Dict[str, Any]], Dict[str, Any]],
    Dict[str, Any]
]:
    """
    Transform function for the MedMCQA dataset
    Minimal-fix points:
      - return text only (do not return input_ids etc.)
      - "if needed", run a lightweight tokenize for a length check and trim the left side when over budget (left padding assumed)
    """
    def _maybe_shorten_prompt(prompt: str, tokenizer: Any) -> str:
        # sequence_len / max_completion_length from the YAML, overridable via env vars
        seq_len = int(os.getenv("AXO_SEQUENCE_LEN", "16384"))
        max_comp = int(os.getenv("AXO_MAX_COMPLETION", "256"))
        max_input = max(8, seq_len - max_comp)
        if tokenizer is None:
            return prompt
        try:
            toks = tokenizer(
                prompt,
                add_special_tokens=False,
                padding=False,
                truncation=False,
                return_attention_mask=False,
            )
            n = len(toks.get("input_ids", []))
            if n > max_input:
                keep_ratio = max_input / max(1, n)
                est_chars = max(64, int(len(prompt) * keep_ratio))
                trimmed = prompt[-est_chars:].lstrip()
                logger.warning(f"[XFORM-TRIM] input_tokens={n} -> {max_input} (chars {len(prompt)} -> {len(trimmed)})")
                return trimmed
        except Exception as e:
            logger.warning(f"[XFORM-TRIM] tokenizer check failed: {e}")
        return prompt

    def transform_fn(example: Dict[str, Any], tokenizer: Any = None) -> Dict[str, Any]:
        # Build the question text
        question = example.get("question", "").strip()

        # Fetch and format the options (skip missing ones)
        options = []
        for opt_key, opt_label in [("opa", "A"), ("opb", "B"), ("opc", "C"), ("opd", "D")]:
            opt_text = example.get(opt_key, "").strip()
            if opt_text:
                options.append(f"{opt_label}. {opt_text}")

        # Build the prompt (make the answer format explicit)
        prompt = (
            "You are a medical expert. Please answer the following medical question by selecting the correct option.\n"
            f"Question: {question}\n"
            "Options:\n"
            f"{chr(10).join(options)}\n"
            "Please provide your answer as a single letter (A, B, C, or D) on the first line, followed by your reasoning.\n"
        )

        # Adjust the length if needed (drop the left side; keep the tail = question/options/instruction)
        prompt = _maybe_shorten_prompt(prompt, tokenizer)

        # Get the correct answer (cop: correct option / 0->A, 1->B, 2->C, 3->D)
        cop_value = example.get("cop", -1)
        logger.warning(f"[DBG-XFORM] cop_value={cop_value!r} (type={type(cop_value).__name__})")
        try:
            cop_int = int(cop_value)
        except Exception:
            cop_int = -1

        if 0 <= cop_int <= 3:
            correct_answer = chr(65 + cop_int)  # 0->A, 1->B, 2->C, 3->D
        else:
            correct_answer = ""
            logger.warning(f"Invalid/unknown cop value (expected 0-3): {cop_value!r}")

        # Subject and topic info (for debugging/filtering)
        metadata = {
            "subject": example.get("subject_name", ""),
            "topic": example.get("topic_name", ""),
            "choice_type": example.get("choice_type", "")
        }

        return {
            "prompt": prompt,
            "answer": correct_answer,         # the gold answer used by the trainer
            "ground_truth": correct_answer,   # duplicated so the reward function can pick it up
            "metadata": metadata
        }

    # Remove the original MedMCQA columns (consumed by the transform)
    remove_columns = [
        "question", "opa", "opb", "opc", "opd", "cop",
        "choice_type", "exp", "subject_name", "topic_name"
    ]

    return transform_fn, {"remove_columns": remove_columns}


# Debug helper: a simple constant-reward function
def const_reward_func(completions: Iterable[Any], value: float = 0.5, **kwargs: Any) -> List[float]:
    """Constant reward function for debugging"""
    return [float(value) for _ in completions]


# Keep the original function as well (for compatibility)
def completion_transform(cfg: Any, *args: Any, **kwargs: Any) -> Tuple[
    Callable[[Dict[str, Any]], Dict[str, Any]], Dict[str, Any]
]:
    """Generic completion transform (kept for compatibility)"""
    def transform_fn(example: Dict[str, Any], tokenizer: Any = None) -> Dict[str, Any]:
        prompt_text = (
            example.get("input")
            or example.get("question")
            or example.get("prompt")
            or ""
        )
        item: Dict[str, Any] = {"prompt": prompt_text}
        if example.get("output") is not None:
            item["answer"] = example["output"]
        elif example.get("response") is not None:
            item["answer"] = example["response"]
        return item

    return transform_fn, {"remove_columns": ["messages", "input", "output"]}
PY
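The scoring rule above (1.0 for the correct letter, 0.1 for a well-formed but wrong answer, 0.0 for an unparseable completion) can be smoke-tested in isolation. The snippet below is a simplified re-implementation of that rule with a much-reduced letter extractor; it is a sketch for sanity-checking the scheme, not the full extract_choice_letter from rewards.py:

```python
import re

def extract_letter(text):
    """Reduced version of rewards.extract_choice_letter: accept a leading
    single letter A-D, else an 'Answer: X' line; otherwise None."""
    first = text.strip().split("\n")[0].strip()
    m = re.match(r"^([A-D])[\.\)\]]?$", first, re.IGNORECASE)
    if m:
        return m.group(1).upper()
    m = re.search(r"(?im)^\s*answer\s*[:\-]\s*\(?([A-D])\)?", text)
    return m.group(1).upper() if m else None

def score(completion, gold):
    """1.0 = correct letter, 0.1 = valid format but wrong, 0.0 = no valid letter."""
    pred = extract_letter(completion)
    if pred is None:
        return 0.0
    return 1.0 if pred == gold else 0.1

print(score("B\nBecause ...", "B"))        # correct answer
print(score("Answer: C\nreasoning", "B"))  # well-formed but wrong
print(score("I am not sure.", "B"))        # invalid format
```

If the three prints do not come out as 1.0 / 0.1 / 0.0, the extraction patterns likely need the same tightening in rewards.py.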

2.4. Running the job

# 1) Define the required environment variables in this shell (example)
export HF_HOME="$HOME/.cache/huggingface"         # any location is fine
export HF_TOKEN="hf_xxxxxxxx"                     # your HF token
export WANDB_API_KEY="xxxxxxxxxxxxxxxxxxxxxxxx"   # your W&B API key

JOBID=$(sbatch --parsable --export=ALL,HF_HOME,HF_TOKEN,WANDB_API_KEY,PYTHONPATH \
  /home/Competition2025/P10/P10U001/work/axolotl/axolotl_train_qwen3.sh)
echo "JOBID=$JOBID"

# Wait for the log file, then follow it
LOG="/home/Competition2025/P10/P10U001/slurm/axolotl_train_${JOBID}.log"
until [ -f "$LOG" ] || [ -f "slurm-${JOBID}.out" ]; do sleep 1; done
tail -n +1 -F "${LOG:-slurm-${JOBID}.out}"

Summary

With the procedure above, you can tune the reward function and, where needed, migrate hyperparameters curriculum-style, running an RL loop that fits your cluster's constraints without friction.

This project is conducted as part of the foundation-model development project within "Safety verification and demonstration toward the social implementation of a Japanese medical LLM" by the New Energy and Industrial Technology Development Organization (NEDO).
