Introduction
This article documents a working procedure for running Axolotl across multiple nodes and applying GRPO / GSPO to Qwen3-235B-A22B-Thinking-2507.
The generation and training loops required for online reinforcement learning (RL) run on Axolotl × TRL; distributed execution is driven by SLURM + torchrun (the config below shards the model with FSDP).
First, terminology. GRPO (Group Relative Policy Optimization) is an RL method that samples several completions per prompt and derives advantages from rewards normalized within that group.
GSPO (Group Sequence Policy Optimization), by contrast, moves the importance ratio and the clipping from the token level to the sequence level. Both in theory and in implementation, it is easiest to treat it as "GRPO generalized to sequence granularity."
Starting from a GRPO setup, this article shows how to switch over to GSPO with nothing but a config diff.
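As a mental model, the two objectives differ only in where the importance ratio lives. The following is a minimal sketch of the formulation (illustrative helper functions of our own, not Axolotl/TRL code; GSPO's ratio uses the length-normalized sequence-level form):

```python
import math

def group_relative_advantages(rewards, eps=1e-8):
    """GRPO: normalize each completion's reward within its prompt's group."""
    mean = sum(rewards) / len(rewards)
    std = math.sqrt(sum((r - mean) ** 2 for r in rewards) / len(rewards))
    return [(r - mean) / (std + eps) for r in rewards]

def token_ratios(logp_new, logp_old):
    """GRPO: one importance ratio per generated token (clipped per token)."""
    return [math.exp(n - o) for n, o in zip(logp_new, logp_old)]

def sequence_ratio(logp_new, logp_old):
    """GSPO: a single scalar ratio per sequence,
    exp(mean over tokens of (logp_new - logp_old)); clipping applies to this scalar."""
    return math.exp((sum(logp_new) - sum(logp_old)) / len(logp_new))

# num_generations=4 completions for one prompt, scored by the reward function
advs = group_relative_advantages([1.0, 0.1, 0.1, 0.0])
print([round(a, 2) for a in advs])  # the best completion gets a positive advantage
```

The group normalization is what lets GRPO work without a value model; GSPO then keeps the same grouping but measures policy drift once per sequence instead of once per token.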
1. Environment setup
Create the following conda environment.
# ==== 0) Clear any existing env ====
conda deactivate 2>/dev/null || true
rm -rf "/home/Competition2025/P10/P10U001/.conda/envs/axo-fsdp"
conda env remove -n axo-fsdp -y || true
conda clean -afy || true
pip cache purge || true
# ==== 1) Conda environment ====
conda create -y -n axo-fsdp python=3.11 cmake ninja pkg-config pip git
eval "$($(conda info --base)/bin/conda shell.bash hook)"
conda activate axo-fsdp
# ==== 2) Pin PyTorch 2.7.1 + cu128 ====
pip install -U --extra-index-url https://download.pytorch.org/whl/cu128 \
"torch==2.7.1+cu128" "torchvision==0.22.1+cu128" "torchaudio==2.7.1+cu128"
# ==== 3) Install the CUDA Toolkit (with NVCC) via conda ====
# - v12.8, to match Torch (=cu128)
# - bundles the required headers/libraries and nvcc
conda install -y -c nvidia -c conda-forge \
"cuda-toolkit=12.8" "cuda-nvcc=12.8"
# Recommended: set the env vars explicitly (helps the DeepSpeed build find CUDA)
export CUDA_HOME="$CONDA_PREFIX"
export CUDA_PATH="$CONDA_PREFIX"
export PATH="$CONDA_PREFIX/bin:$PATH"
export LD_LIBRARY_PATH="$CONDA_PREFIX/lib:$CONDA_PREFIX/lib64:${LD_LIBRARY_PATH:-}"
# (Optional) specify the GPU arch if known (e.g. 90a for H100)
# export TORCH_CUDA_ARCH_LIST="90;90a"
# ==== 4) Axolotl (flash-attn / deepspeed / vllm in one go via extras) ====
pip install -U packaging setuptools wheel ninja
set -e
if ! pip install -U "axolotl[flash-attn,deepspeed,vllm]==0.12.2"; then
echo "[WARN] extras install failed; falling back to the minimal setup."
pip install -U "axolotl==0.12.2"
# try the speedups / DeepSpeed individually (skip if no wheel / build fails)
pip install -U "flash-attn>=2.6.0" || true
pip install -U "deepspeed==0.14.4" || true
fi
set +e
# ==== 5) Only the minimal additions that Axolotl does not install itself ====
# 5-1) for 4bit/QLoRA
pip install -U "bitsandbytes==0.47.0" "sentencepiece>=0.2.0" "einops>=0.7.0" \
"protobuf>=4.25,<6" "pynvml>=11.5.0" "psutil>=5.9" "safetensors>=0.4.2"
# 5-2) vLLM external-API integration (OpenAI-compatible client)
pip install -U "openai>=1.40.0" "httpx>=0.27.0" "aiohttp>=3.9" "orjson>=3.10" "uvloop>=0.19"
# 5-3) logging/visualization (optional)
pip install -U "wandb>=0.17.0" "tensorboard>=2.16"
pip uninstall -y torchvision
pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@0ee9ee8"
# consistency check
pip check || true
# ==== 6) Sanity check ====
echo "[nvcc]"; nvcc --version || true
python - <<'PY'
import torch, importlib.metadata as md, sys, os
def v(n):
    try: print(f"{n}: {md.version(n)}")
    except Exception: print(f"{n}: not installed")
print("python:", sys.version.split()[0])
print("CUDA runtime (torch):", torch.version.cuda, "| available:", torch.cuda.is_available(), "| nGPU:", torch.cuda.device_count())
print("CUDA_HOME:", os.getenv("CUDA_HOME"))
for p in ["axolotl","transformers","accelerate","datasets","tokenizers","safetensors","huggingface_hub",
"peft","trl","bitsandbytes","flash_attn","deepspeed","vllm","openai","wandb"]:
v(p)
PY
# bitsandbytes CUDA detection (optional)
python -m bitsandbytes | sed -n '1,80p'
2. GRPO/GSPO
2.1. YAML
Changing importance_sampling_level from token to sequence turns GRPO into GSPO.
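Concretely, the switch against the GRPO YAML below is a one-key diff inside the trl: block; a minimal illustration (all other keys stay exactly as they are):

```yaml
trl:
  # GRPO (the default): token-level importance ratios and clipping
  # importance_sampling_level: token
  # GSPO: move the ratio and clipping to sequence granularity
  importance_sampling_level: sequence
```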
VLLM_HOST=10.1.2.3 # ← 127.0.0.1 is fine on a single node; use the real IP for multi-node!
cat > /home/Competition2025/P10/P10U001/work/axolotl/axolotl_qwen3_235b_fsdp.yaml <<'EOF'
base_model: /home/Competition2025/P10/shareP10/llm_project/models/Qwen/Qwen3-235B-A22B-Thinking-2507_clone_20251004-133708
model_type: AutoModelForCausalLM
tokenizer_type: AutoTokenizer
hub_model_id:
hub_strategy:
push_dataset_to_hub:
hf_use_auth_token: true
# Liger Kernel settings
plugins:
- axolotl.integrations.liger.LigerPlugin
# liger_cross_entropy: true
liger_rope: true
liger_rms_norm: true
liger_glu_activation: true
liger_fused_linear_cross_entropy: false # avoid hooks on the generation path
# cut_cross_entropy: true
liger_layer_norm: true
load_in_8bit: false
load_in_4bit: true
strict: false
# disable chat_template (avoid unwanted automatic formatting)
# chat_template: tokenizer_default
chat_template:
# === Minimal fixes (make the token/padding policy explicit) ===
add_bos_token: false # stop automatic BOS insertion (avoid doubling up with the template)
tokenizer_padding_side: left # left padding is the safe choice for batched generation
tokenizer_pad_token: eos # align pad_token with eos
# === end of minimal fixes ===
# ===== RL (GRPO mode) settings =====
rl: grpo
trl:
use_vllm: false
  # vLLM server connection (TRL side)
vllm_server_host: ${VLLM_HOST}
vllm_server_port: 8000
vllm_server_timeout: 300
  # GRPO: token-level importance ratio (token is also the default)
importance_sampling_level: token
  # core hyperparameters of the generate-and-optimize loop (tune as needed)
num_generations: 4
generation_batch_size: 4
num_iterations: 4
epsilon: 0.2
epsilon_high: 2.0
temperature: 0.7
top_p: 0.95
  # cap completion length at 256
max_completion_length: 256
  # pin the tokenization / generate behavior before generation (the key minimal fix)
tokenizer_kwargs:
padding: true
truncation: true
    add_special_tokens: false # ← avoid doubling BOS/EOS
generation_kwargs:
use_cache: true
    pad_token_id: null # ← taken from the tokenizer / config
    eos_token_id: null # ← taken from the tokenizer / config
    bos_token_id: null # ← consistent with the no-BOS policy
  # reward functions to use (function names inside the file)
reward_funcs:
- "rewards.medmcqa_reward_func"
reward_weights: [1.0]
log_completions: true
# ===== end of RL settings =====
# ===== vLLM (external server) =====
vllm:
host: ${VLLM_HOST}
port: 8000
# tensor_parallel_size: 8
# gpu_memory_utilization: 0.85
# dtype: auto
# ===== end of vLLM settings =====
# RL data input (references the transform function)
datasets:
- path: oNo-1/medmcqa_100samples
split: "train[:30%]"
type: rewards.medmcqa_transform
dataset_processes: 128
shuffle_merged_datasets: true
dataset_prepared_path: /home/Competition2025/P10/P10U001/data/oNo-1/medmcqa_100samples
# val_set_size: 0.005
output_dir: /home/Competition2025/P10/P10U001/models/oNo-1/medmcqa_100samples_GRPO
sequence_len: 16384
sample_packing: false
eval_sample_packing: false
pad_to_sequence_len: false # ← false recommended for RL (avoid wasted padding)
adapter: qlora
lora_model_dir:
lora_r: 128
lora_alpha: 256
lora_dropout: 0.05
# lora_target_linear:
lora_target_modules:
- self_attn.q_proj
- self_attn.k_proj
- self_attn.v_proj
- self_attn.o_proj
# - down_proj
# - gate_proj
# - up_proj
lora_modules_to_save:
- embed_tokens
# - lm_head
lora_fan_in_fan_out:
wandb_project: Qwen3-235B-A22B-Thinking-2507
wandb_entity: llm-competition-2025-wandb-01-
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 1
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_torch_4bit
lr_scheduler: cosine
cosine_min_lr_ratio: 0.01
learning_rate: 1e-5
max_grad_norm: 1.0
train_on_inputs: false
group_by_length: false
bf16: true
fp16: false
tf32: false
bfloat16: true
# gradient_checkpointing: true
# gradient_checkpointing_kwargs:
# use_reentrant: true
early_stopping_patience:
auto_resume_from_checkpoints: true
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: false
attn_implementation: eager
# also set on the HF model config to avoid SDPA and default use_cache to True
hf_config:
attn_implementation: eager
use_cache: true
save_strategy: steps
save_steps: 1000
save_total_limit: 2
save_only_model: true
# eval_strategy: steps
# eval_steps: 100
# eval_batch_size: 8
warmup_ratio: 0.03
debug:
# deepspeed: /home/Competition2025/P10/P10U001/work/axolotl/deepspeed_configs/zero3_bf16.json
weight_decay: 0.01
fsdp_version: 1
fsdp:
- full_shard
- auto_wrap
fsdp_config:
fsdp_limit_all_gathers: true
fsdp_sync_module_states: true
fsdp_offload_params: false
fsdp_use_orig_params: false
fsdp_cpu_ram_efficient_loading: true
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
fsdp_transformer_layer_cls_to_wrap: Qwen3MoeDecoderLayer
fsdp_state_dict_type: FULL_STATE_DICT
fsdp_sharding_strategy: FULL_SHARD
fsdp_activation_checkpointing: true
qlora_sharded_model_loading: true
ddp_timeout: 180000000
lora_on_cpu: false
EOF
2.2. sbatch
# create axolotl_train_qwen3.sh
cat <<'EOF' > /home/Competition2025/P10/P10U001/work/axolotl/axolotl_train_qwen3.sh
#!/bin/bash
#SBATCH --job-name=axolotl_train
#SBATCH --partition=P10
#SBATCH --nodes=4
#SBATCH --gpus-per-node=8
#SBATCH --gpus-per-task=8
#SBATCH --ntasks-per-node=1
#SBATCH --output=/home/Competition2025/P10/P10U001/slurm/%x_%j.log
#SBATCH --error=/home/Competition2025/P10/P10U001/slurm/%x_%j.log
#SBATCH --wait-all-nodes=1
#SBATCH --mem=0
#SBATCH --cpus-per-task=128
set -euxo pipefail
export TORCH_NCCL_ASYNC_ERROR_HANDLING=1
export NCCL_DEBUG=INFO
export NCCL_P2P_LEVEL=NVL
export NCCL_IB_DISABLE=0
export NCCL_SOCKET_IFNAME="enp25s0np0"
export NCCL_BUFFSIZE=2097152
# Hugging Face settings
export HF_HOME=$HF_HOME
export HF_TOKEN=$HF_TOKEN
# export NCCL_BLOCKING_WAIT=1
# export TORCH_NCCL_BLOCKING_WAIT=1
export AXOLOTL_NCCL_TIMEOUT=7200
# W&B settings
export WANDB_API_KEY=$WANDB_API_KEY
export WANDB_ENTITY="aratako-lm"
#export WANDB_PROJECT="axolotl_training"
# set env vars so every task can locate the master node
export MASTER_ADDR=$(scontrol show hostnames $SLURM_NODELIST | head -n 1)
export MASTER_PORT=$(( 29000 + SLURM_JOB_ID % 1000 ))
# path to the config file (local or URL)
export CONFIG_PATH=/home/Competition2025/P10/P10U001/work/axolotl/axolotl_qwen3_235b_fsdp.yaml
export PYTHONPATH="/home/Competition2025/P10/P10U001/work/axolotl:${PYTHONPATH:-}"
export PYTORCH_SDP_DISABLE_FLASH_ATTENTION=1
export PYTORCH_SDP_DISABLE_MEM_EFFICIENT=1
export PYTORCH_SDP_FORCE_FMA=1
export TRANSFORMERS_NO_ADVISORY_WARNINGS=1
# conda on the P10 environment
set +u
source ~/.bashrc || true
# or: source /opt/conda/etc/profile.d/conda.sh
conda activate axo-fsdp
set -u
cd /home/Competition2025/P10/P10U001/work/axolotl
# srun --nodes=${SLURM_JOB_NUM_NODES} --ntasks-per-node=1 \
# bash -c '
# torchrun \
# --nnodes '"${SLURM_JOB_NUM_NODES}"' \
# --nproc_per_node '"${SLURM_GPUS_PER_NODE:-8}"' \
# --node_rank $SLURM_NODEID \
# --rdzv_backend c10d \
# --rdzv_id '"${SLURM_JOB_ID}"' \
# --rdzv_endpoint '"${MASTER_ADDR}:${MASTER_PORT}"' \
# -m axolotl.cli.train '"${CONFIG_PATH}"' \
# --deepspeed /home/Competition2025/P10/P10U001/work/axolotl/deepspeed_configs/zero3_bf16.json
# '
echo "Host: $(hostname)"
echo "Current user limits:"
ulimit -a
echo "Shared memory size:"
df -h /dev/shm
ulimit -v unlimited
ulimit -m unlimited
ulimit -a
srun -l --export=ALL \
bash -c '
set -eux
torchrun \
--nnodes=${SLURM_JOB_NUM_NODES} \
--nproc_per_node=${SLURM_GPUS_PER_NODE:-8} \
--node_rank=${SLURM_NODEID} \
--rdzv_backend=c10d \
--rdzv_id=${SLURM_JOB_ID} \
--rdzv_endpoint=${MASTER_ADDR}:${MASTER_PORT} \
-m axolotl.cli.train ${CONFIG_PATH}
'
EOF
2.3. Reward function
cat > /home/Competition2025/P10/P10U001/work/axolotl/rewards.py <<'PY'
from typing import Any, Iterable, List, Tuple, Dict, Callable, Optional
import re
import logging
import os # for PID/rank and env toggles
logger = logging.getLogger(__name__)
# how many leading items to log in detail (overridable via env var)
_MAX_DBG_PAIRS = int(os.getenv("REWARD_DBG_PAIRS", "3"))
class MedMCQAMetrics:
    """Tracks MedMCQA evaluation metrics."""
def __init__(self):
self.correct = 0
self.valid_format = 0
self.total = 0
def update(self, rewards: List[float]):
for r in rewards:
if r == 1.0:
self.correct += 1
self.valid_format += 1
elif r > 0.0:
self.valid_format += 1
self.total += 1
def get_stats(self) -> Dict[str, float]:
if self.total == 0:
return {"accuracy": 0.0, "valid_format_rate": 0.0}
return {
"accuracy": self.correct / self.total,
"valid_format_rate": self.valid_format / self.total
}
# global metrics instance
metrics = MedMCQAMetrics()
# ===== utility: robust extraction of the answer letter =====
# strip typical <think> tags
THINK_TAG_RE = re.compile(r"</?\s*think[^>]*>", re.IGNORECASE)
STRICT_PATTERNS = [
    # "Answer: X" at the start of a line
re.compile(r"(?im)^\s*answer\s*[:\-]\s*\(?\s*([A-D])\s*\)?(?=[\s\.\)\]]|$)"),
    # "Final Answer: X" at the start of a line
re.compile(r"(?im)^\s*final\s*answer\s*[:\-]\s*\(?\s*([A-D])\s*\)?(?=[\s\.\)\]]|$)"),
    # Japanese variants
re.compile(r"(?im)^\s*(?:答え|最終解答)\s*[::\-]\s*\(?\s*([A-D])\s*\)?(?=[\s\.\)\]]|$)"),
    # XML-style tags
re.compile(r"(?is)<\s*final\s*>\s*(?:answer\s*[:\-]\s*)?\(?\s*([A-D])\s*\)?\s*</\s*final\s*>"),
re.compile(r"(?is)<\s*answer\s*>\s*\(?\s*([A-D])\s*\)?\s*</\s*answer\s*>"),
]
FALLBACK_PATTERNS = [
    # line starting with "B." / "C)" etc.
re.compile(r"(?im)^\s*([A-D])\s*[\.\)](?=\s|$)"),
    # "The correct answer is B" etc.
re.compile(r"(?i)\b(correct|final)?\s*answer\s*(is|:)?\s*\(?\s*([A-D])\s*\)?\b"),
]
def _to_str(x: Any) -> str:
    """Return a string regardless of the completion's type."""
if isinstance(x, str):
return x
if isinstance(x, dict):
for k in ("text", "output_text", "response", "content"):
if k in x and isinstance(x[k], str):
return x[k]
        # choices-style payload
ch = x.get("choices")
if isinstance(ch, list) and ch and isinstance(ch[0], dict):
for k in ("text", "message", "delta"):
if k in ch[0]:
v = ch[0][k]
if isinstance(v, dict):
for kk in ("content", "text"):
if kk in v and isinstance(v[kk], str):
return v[kk]
elif isinstance(v, str):
return v
return str(x)
def extract_choice_letter(text: str) -> Optional[str]:
    """
    Extract the answer letter. Also handles a bare "C" etc. at the very start.
    """
    # strip noise tags
    t = THINK_TAG_RE.sub("", text).strip()
    # first, check for a standalone letter at the start (the most common case)
first_line = t.split('\n')[0].strip()
if len(first_line) == 1 and first_line.upper() in ("A", "B", "C", "D"):
return first_line.upper()
    # start of the form "C." / "D)" etc.
if len(first_line) <= 3:
match = re.match(r'^([A-D])[\s\.\)\]\}]?$', first_line, re.IGNORECASE)
if match:
return match.group(1).upper()
    # strict patterns (match the recommended format)
for p in STRICT_PATTERNS:
m = p.search(t)
if m:
g = m.group(1) if m.lastindex and m.lastindex >= 1 else m.group(0)
ch = (g or "").strip().upper()
if ch in ("A", "B", "C", "D"):
return ch
    # fallback
for p in FALLBACK_PATTERNS:
m = p.search(t)
if m:
for gi in range(1, (m.lastindex or 0) + 1):
g = m.group(gi)
if g and g.strip().upper() in ("A", "B", "C", "D"):
return g.strip().upper()
return None
def _collect_ground_truths_from_kwargs(**kwargs) -> List[str]:
    """
    Collect ground_truth / answer from the examples/batch/samples/data/rows
    that may arrive via kwargs (logging-enhanced version).
    """
gts: List[str] = []
    # 1) collect top-level scalars/arrays
for sk in ("ground_truth", "answer"):
sv = kwargs.get(sk)
if isinstance(sv, str) and sv:
gts.append(sv)
elif isinstance(sv, list):
for x in sv:
if isinstance(x, str) and x:
gts.append(x)
candidate_keys = [
"examples", "batch", "batches", "samples", "data", "rows",
"original_examples", "original_batch", "original_samples",
"eval_rows", "eval_batch",
]
present = [k for k in kwargs.keys() if isinstance(kwargs.get(k), (list, dict))]
if present:
logger.warning(f"[reward] kwargs keys present (container-like): {present}")
found_count: Dict[str, int] = {}
def _acc(key: str):
found_count[key] = found_count.get(key, 0) + 1
for k in candidate_keys:
v = kwargs.get(k)
if isinstance(v, list):
for ex in v:
if isinstance(ex, dict):
gt = ex.get("ground_truth", ex.get("answer", ""))
if isinstance(gt, str) and gt:
gts.append(gt); _acc(k)
elif isinstance(v, dict):
for vv in v.values():
if isinstance(vv, list):
for ex in vv:
if isinstance(ex, dict):
gt = ex.get("ground_truth", ex.get("answer", ""))
if isinstance(gt, str) and gt:
gts.append(gt); _acc(k)
if found_count:
logger.warning(f"[DBG] GT sources: {found_count}")
if gts:
preview = ", ".join(repr(s) for s in gts[:_MAX_DBG_PAIRS])
logger.warning(f"[DBG] GT samples (first {_MAX_DBG_PAIRS}): {preview}")
return gts
def medmcqa_reward_func(
completions: Iterable[Any],
ground_truths: Optional[List[str]] = None,
prompts: Optional[List[str]] = None,
    examples: Optional[List[Dict[str, Any]]] = None,  # added: the full original examples
**kwargs: Any
) -> List[float]:
    """
    Reward function for MedMCQA:
    - correct option chosen: 1.0
    - wrong answer in a valid format: 0.1
    - invalid answer format: 0.0
    """
    # ===== added: ENTRY log =====
logger.warning(
"[DBG-ENTRY] pid=%s examples=%s kwargs_keys=%s",
os.getpid(),
"None" if examples is None else f"len={len(examples)}",
list(kwargs.keys())
)
rewards: List[float] = []
completions_list = list(completions)
    # if ground_truths is missing, try to recover it from examples/kwargs (defensive)
if not ground_truths:
collected: List[str] = []
if examples:
for ex in examples:
if isinstance(ex, dict):
gt = ex.get("ground_truth", ex.get("answer", ""))
if isinstance(gt, str) and gt:
collected.append(gt)
collected_kwargs = _collect_ground_truths_from_kwargs(**kwargs)
if collected_kwargs:
collected.extend(collected_kwargs)
if collected:
ground_truths = collected
    # debug logging
if ground_truths:
logger.info(f"Ground truths available: {len(ground_truths)} items")
if len(ground_truths) > 0:
logger.info(f"First ground truth: '{ground_truths[0]}'")
else:
logger.warning("No ground truths available!")
    # ===== added: log after GT reconstruction =====
logger.warning(
"[DBG-GT-BUILD] len(gts)=%s first_gt=%s",
(len(ground_truths) if ground_truths else 0),
(ground_truths[0] if ground_truths else None),
)
    # align the lengths of completions and gts (err on the safe side)
if ground_truths:
if len(ground_truths) < len(completions_list):
logger.warning(f"[reward] gts shorter than completions: {len(ground_truths)} < {len(completions_list)}; padding with ''")
ground_truths = ground_truths + [""] * (len(completions_list) - len(ground_truths))
elif len(ground_truths) > len(completions_list):
logger.warning(f"[reward] gts longer than completions: trimming {len(ground_truths)} -> {len(completions_list)}")
ground_truths = ground_truths[:len(completions_list)]
    # --- just before the for loop (emit batch-level info once) ---
rank = 0
try:
import torch.distributed as dist # type: ignore
if dist.is_available() and dist.is_initialized():
rank = dist.get_rank()
except Exception:
pass
    # ===== added: PRE-LOOP log =====
logger.warning(
"[DBG-PRE-LOOP] pid=%s len(completions_list)=%s len(ground_truths)=%s examples=%s keys=%s first_gt=%s",
os.getpid(),
len(completions_list),
len(ground_truths) if ground_truths else 0,
"None" if examples is None else f"len={len(examples)}",
list(kwargs.keys()),
ground_truths[0] if ground_truths else None,
)
    # keep the existing summary log as well
logger.warning(
"[DBG] rank=%s pid=%s len(completions_list)=%s len(ground_truths)=%s first_gt=%s",
rank, os.getpid(),
len(completions_list),
(len(ground_truths) if ground_truths else 0),
(ground_truths[0] if ground_truths else None),
)
if examples:
logger.warning("[DBG] examples_len=%d first_example_keys=%s",
len(examples), list(examples[0].keys()) if isinstance(examples[0], dict) else type(examples[0]).__name__)
if prompts is not None:
try:
logger.warning("[DBG] prompts_len=%d", len(prompts))
except Exception:
logger.warning("[DBG] prompts_type=%s", type(prompts).__name__)
    # --- loop over each completion (check indices individually) ---
for i, completion in enumerate(completions_list):
logger.warning("[DBG] idx=%d / n=%d", i, len(completions_list))
completion_str = _to_str(completion).strip()
if not completion_str:
rewards.append(0.0)
continue
predicted_answer = extract_choice_letter(completion_str)
        # no valid answer found
if not predicted_answer:
rewards.append(0.0)
logger.debug(f"[format invalid] completion {i}: {completion_str[:120]}...")
if i < _MAX_DBG_PAIRS:
logger.warning("[DBG] format_invalid sample[%d]: first_line=%r", i, completion_str.splitlines()[0][:120] if completion_str else "")
continue
        # correctness check
if ground_truths and i < len(ground_truths):
correct_answer = (ground_truths[i] or "").strip().upper()
if i < _MAX_DBG_PAIRS:
logger.warning("[DBG] pair[%d]: pred=%s gold=%s", i, predicted_answer, correct_answer)
if correct_answer in ("A", "B", "C", "D"):
if predicted_answer == correct_answer:
rewards.append(1.0)
logger.info(f"[CORRECT!] pred={predicted_answer}, gold={correct_answer} @ {i}")
else:
rewards.append(0.1)
logger.debug(f"[wrong] pred={predicted_answer}, gold={correct_answer} @ {i}")
else:
                # if gold is invalid, score the format only
rewards.append(0.1)
logger.warning(f"[gold invalid] index {i}, gold='{ground_truths[i]}'")
else:
            # if gold is missing, score the format only
rewards.append(0.1)
logger.warning(f"[gold missing] completion {i}")
    # update metrics
metrics.update(rewards)
    # periodically log the metrics
if metrics.total % 100 == 0 and metrics.total > 0:
stats = metrics.get_stats()
logger.info(
f"MedMCQA Metrics - Total: {metrics.total}, "
f"Accuracy: {stats['accuracy']:.3f}, "
f"Valid Format Rate: {stats['valid_format_rate']:.3f}"
)
return rewards
def medmcqa_transform(cfg: Any, *args: Any, **kwargs: Any) -> Tuple[
Callable[[Dict[str, Any]], Dict[str, Any]],
Dict[str, Any]
]:
    """
    Transform function for the MedMCQA dataset.
    Minimal-fix points:
    - return text only (do not return input_ids etc.)
    - if needed, tokenize lightly to check the length and trim from the left
      (left padding is assumed)
    """
def _maybe_shorten_prompt(prompt: str, tokenizer: Any) -> str:
        # sequence_len / max_completion_length from the YAML can also be overridden via env vars
seq_len = int(os.getenv("AXO_SEQUENCE_LEN", "16384"))
max_comp = int(os.getenv("AXO_MAX_COMPLETION", "256"))
max_input = max(8, seq_len - max_comp)
if tokenizer is None:
return prompt
try:
toks = tokenizer(
prompt,
add_special_tokens=False,
padding=False,
truncation=False,
return_attention_mask=False,
)
n = len(toks.get("input_ids", []))
if n > max_input:
keep_ratio = max_input / max(1, n)
est_chars = max(64, int(len(prompt) * keep_ratio))
trimmed = prompt[-est_chars:].lstrip()
logger.warning(f"[XFORM-TRIM] input_tokens={n} -> {max_input} (chars {len(prompt)} -> {len(trimmed)})")
return trimmed
except Exception as e:
logger.warning(f"[XFORM-TRIM] tokenizer check failed: {e}")
return prompt
def transform_fn(example: Dict[str, Any], tokenizer: Any = None) -> Dict[str, Any]:
        # build the question text
question = example.get("question", "").strip()
        # fetch and format the options (skip any that are missing)
options = []
for opt_key, opt_label in [("opa", "A"), ("opb", "B"), ("opc", "C"), ("opd", "D")]:
opt_text = example.get(opt_key, "").strip()
if opt_text:
options.append(f"{opt_label}. {opt_text}")
        # build the prompt (make the expected answer format explicit)
prompt = (
"You are a medical expert. Please answer the following medical question by selecting the correct option.\n"
f"Question: {question}\n"
"Options:\n"
f"{chr(10).join(options)}\n"
"Please provide your answer as a single letter (A, B, C, or D) on the first line, followed by your reasoning.\n"
)
        # adjust the length if needed (drop the left side so the tail, i.e. question/options/instructions, survives)
prompt = _maybe_shorten_prompt(prompt, tokenizer)
        # get the correct answer (cop: correct option / 0->A, 1->B, 2->C, 3->D)
cop_value = example.get("cop", -1)
logger.warning(f"[DBG-XFORM] cop_value={cop_value!r} (type={type(cop_value).__name__})")
try:
cop_int = int(cop_value)
except Exception:
cop_int = -1
if 0 <= cop_int <= 3:
correct_answer = chr(65 + cop_int) # 0->A, 1->B, 2->C, 3->D
else:
correct_answer = ""
logger.warning(f"Invalid/unknown cop value (expected 0-3): {cop_value!r}")
        # subject and topic info (for debugging/filtering)
metadata = {
"subject": example.get("subject_name", ""),
"topic": example.get("topic_name", ""),
"choice_type": example.get("choice_type", "")
}
return {
"prompt": prompt,
            "answer": correct_answer,        # the gold answer consumed by the trainer
            "ground_truth": correct_answer,  # duplicated so the reward function can also pick it up
"metadata": metadata
}
    # drop MedMCQA's original columns (the ones consumed by the transform)
remove_columns = [
"question", "opa", "opb", "opc", "opd", "cop",
"choice_type", "exp", "subject_name", "topic_name"
]
return transform_fn, {"remove_columns": remove_columns}
# debug: a trivial constant reward function
def const_reward_func(completions: Iterable[Any], value: float = 0.5, **kwargs: Any) -> List[float]:
    """Constant reward function for debugging."""
return [float(value) for _ in completions]
# keep the original function as well (for compatibility)
def completion_transform(cfg: Any, *args: Any, **kwargs: Any) -> Tuple[
Callable[[Dict[str, Any]], Dict[str, Any]], Dict[str, Any]
]:
    """Generic completion transform (kept for compatibility)."""
def transform_fn(example: Dict[str, Any], tokenizer: Any = None) -> Dict[str, Any]:
prompt_text = (
example.get("input")
or example.get("question")
or example.get("prompt")
or ""
)
item: Dict[str, Any] = {"prompt": prompt_text}
if example.get("output") is not None:
item["answer"] = example["output"]
elif example.get("response") is not None:
item["answer"] = example["response"]
return item
return transform_fn, {"remove_columns": ["messages", "input", "output"]}
PY
2.4. Run
# 1) define the required environment variables in this shell (example)
export HF_HOME="$HOME/.cache/huggingface" # any location is fine
export HF_TOKEN="hf_xxxxxxxx" # your HF token
export WANDB_API_KEY="xxxxxxxxxxxxxxxxxxxxxxxx" # your W&B API key
JOBID=$(sbatch --parsable --export=ALL,HF_HOME,HF_TOKEN,WANDB_API_KEY,PYTHONPATH \
/home/Competition2025/P10/P10U001/work/axolotl/axolotl_train_qwen3.sh)
echo "JOBID=$JOBID"
# wait for the log file to appear, then follow it
LOG="/home/Competition2025/P10/P10U001/slurm/axolotl_train_${JOBID}.log"
until [ -f "$LOG" ] || [ -f "slurm-${JOBID}.out" ]; do sleep 1; done
tail -n +1 -F "${LOG:-slurm-${JOBID}.out}"
Summary
With the steps above, tuning the reward function and, where needed, migrating the hyperparameters curriculum-style lets you run an RL loop that fits your cluster's constraints without undue friction.
This project is conducted as part of the foundation-model development effort in the "Safety Verification and Demonstration toward Social Implementation of a Japanese Medical-Specialized LLM" program of the New Energy and Industrial Technology Development Organization (NEDO).