※下記のコードはチームメンバーのある方(後日公開予定)と協力し作成したものになります。
概要
SWIFTフレームワークを使用してQLoRA+GRPO(Group Relative Policy Optimization)による強化学習を実行する手順をまとめています。ByteDance社のverlではMoEモデルにおけるLoRA対応がされていないようでしたので、Qwen3-235BでのGSPOを1ノードで動かすために検証しました。ただし、本コードは動作検証のためLlama-32.B-1B-Instructを使用しています。
参考文献)
- DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models (arXiv:2402.03300)
- ms-swift GRPO
前提条件
- H100 1ノード(8台)
1. 計算ノードの確保
srun --partition ××× --nodes=1 --gpus-per-node=8 --cpus-per-task=240 --time=02:00:00 --job-name="GSPO_training" --pty bash -i
2. 環境構築
2.1 Conda環境の作成
# Conda環境作成
conda create -n swift_grpo python=3.10 -y
conda activate swift_grpo
# 必要なパッケージのインストール
pip install ms-swift[llm]==3.8.1
pip install transformers==4.56.2
pip install torch==2.3.0 --index-url https://download.pytorch.org/whl/cu121
pip install accelerate==1.10.1
pip install peft==0.17.1
pip install bitsandbytes==0.47.0
pip install trl==0.20.0
pip install math_verify==0.5.2
pip install pyarrow
pip install pandas
2.2 作業ディレクトリの作成
# 個人の実験ディレクトリを作成
mkdir -p ~/swift_grpo_workspace/{data,outputs,logs}
cd ~/swift_grpo_workspace
3. 報酬関数の実装
報酬スコアは下記のように設定した。
- 正解なら
1.0 - 不正解だが数字を抽出できたら
0.1 - 数字を抽出できなかったら
0.5 - データセットの正解が見つからなかったら
0.0
# math_reward_function.py を最終解決版で上書き
cat > math_reward_function.py << 'EOF'
import re
from typing import List, Dict, Any
class CorrectnessRewardFunction:
def __init__(self, dataset: List[Dict[str, Any]]):
self.prompt_answer_map = self._create_prompt_answer_map(dataset)
print(f"\n--- INFO: (Final Forgiving Version) prompt_answer_map created. Count: {len(self.prompt_answer_map)} ---\n")
def _create_prompt_answer_map(self, dataset: List[Dict[str, Any]]) -> Dict[str, str]:
prompt_map = {}
for item in dataset:
if 'messages' in item and isinstance(item['messages'], list):
user_messages = [msg['content'] for msg in item['messages'] if msg.get('role') == 'user']
answer = item.get('extra_info', {}).get('answer', '')
if user_messages and answer:
prompt_text = user_messages[-1]
prompt_map[prompt_text] = answer
return prompt_map
def __call__(self, completions: List[str], **kwargs) -> List[float]:
prompts = []
if 'messages' in kwargs:
for message_list in kwargs['messages']:
user_content = [msg['content'] for msg in message_list if msg.get('role') == 'user']
if user_content:
prompts.append(user_content[-1])
rewards = []
if not prompts or len(prompts) != len(completions):
return [0.0] * len(completions)
for i, completion in enumerate(completions):
current_prompt = prompts[i]
gt_full_string = self.prompt_answer_map.get(current_prompt)
if gt_full_string is None:
rewards.append(0.0)
continue
generated_answer_str = self._extract_final_answer(completion)
ground_truth_answer_str = self._extract_final_answer(gt_full_string)
reward = self._calculate_reward(generated_answer_str, ground_truth_answer_str)
rewards.append(reward)
return rewards
def _extract_final_answer(self, text: str) -> str:
if not isinstance(text, str):
return None
sanitized_text = str(text).replace('$', '').replace(',', '')
# 1. まず「#### <数字>」の正しい形式を探す
match = re.search(r'####\s*([+-]?\d+\.?\d*)', sanitized_text)
if match and match.group(1):
return match.group(1)
# 2. 正しい形式が見つからない場合、フォールバックとして文章中の「最後の数字」を探す
all_numbers = re.findall(r'([+-]?\d+\.?\d+)', sanitized_text)
if all_numbers:
return all_numbers[-1] # 最後に見つかった数字を返す
# 3. それでも見つからなければ None を返す
return None
def _calculate_reward(self, generated_answer: str, ground_truth_answer: str) -> float:
if generated_answer is None: return -0.5
if ground_truth_answer is None: return 0.0
try:
gen_num = float(generated_answer)
gt_num = float(ground_truth_answer)
if abs(gen_num - gt_num) < 1e-4: return 1.0
else: return 0.1
except (ValueError, TypeError): return -0.5
def get_reward_fn(dataset_path: str):
from datasets import load_dataset
train_dataset = load_dataset('json', data_files=dataset_path, split='train')
return CorrectnessRewardFunction(list(train_dataset))
EOF
4. データセットの前処理
データは練習用としてGSM8Kを用いています。
データセットの変換
cat > convert_to_jsonl.py << 'EOF'
import pandas as pd
import os
parquet_path = '/home/llm_project/data/raw/gsm8k/train.parquet'
jsonl_path = 'data/train.jsonl'
print(f"Reading Parquet file from: {parquet_path}")
try:
df = pd.read_parquet(parquet_path)
print("Successfully read Parquet file.")
if 'prompt' in df.columns:
df.rename(columns={'prompt': 'messages'}, inplace=True)
print("Renamed 'prompt' column to 'messages'.")
else:
print("Warning: 'prompt' column not found. No columns were renamed.")
os.makedirs(os.path.dirname(jsonl_path), exist_ok=True)
print(f"Saving new structured data to: {jsonl_path}")
df.to_json(jsonl_path, orient='records', lines=True, force_ascii=False)
print("Conversion complete! ")
print(f"New dataset is available at: {os.path.abspath(jsonl_path)}")
except Exception as e:
print(f"An error occurred: {e}")
EOF
実行
python convert_to_jsonl.py
データセットのチェック
cat > check_dataset.py << 'EOF'
from datasets import load_dataset
import json
jsonl_path = 'data/train.jsonl'
print(f"Checking dataset format for: {jsonl_path}")
try:
# データセットを読み込む
dataset = load_dataset('json', data_files=jsonl_path)
# 最初の1件を取得して表示
first_example = dataset['train'][0]
print("\n✅ Dataset loaded successfully!")
print("\n--- First Example ---")
# 見やすいように整形して表示
print(json.dumps(first_example, indent=2, ensure_ascii=False))
# 重要なキーの存在チェック
if 'messages' in first_example and isinstance(first_example['messages'], list):
print("\n✅ 'messages' key exists and is a list.")
message = first_example['messages'][0]
if 'role' in message and 'content' in message:
print("✅ 'role' and 'content' keys exist in the first message.")
else:
print("\n❌ ERROR: 'role' or 'content' key is missing in the message.")
else:
print("\n❌ ERROR: 'messages' key is missing or is not a list.")
except Exception as e:
print(f"\n❌ An error occurred during dataset check: {e}")
EOF
実行
python check_dataset.py
5. GRPO実行
モデルは検証用にLlama-32.B-1B-Instructを使用しています。
# 実行スクリプトを本番用に更新
cat > run_grpo_with_llama.sh << 'EOF'
#!/bin/bash
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
export PYTHONPATH="${PYTHONPATH}:$(pwd)"
export MODEL_PATH="/home/llm_project/models/base/Llama-3.2-1B-Instruct"
export OUTPUT_DIR="outputs/grpo_llama_final_run_$(date +%Y%m%d_%H%M%S)"
export ABSOLUTE_DATA_PATH="$HOME/swift_grpo_workspace/data/train.jsonl"
mkdir -p "$OUTPUT_DIR"
python - << 'PYTHON_EOF'
import os
from swift.llm import RLHFArguments, rlhf_main
from math_reward_function import get_reward_fn
dataset_path = os.environ.get('ABSOLUTE_DATA_PATH')
output_dir = os.environ.get('OUTPUT_DIR')
model_path = os.environ.get('MODEL_PATH')
print("Initializing final reward function...")
reward_function = get_reward_fn(dataset_path)
args = RLHFArguments(
rlhf_type='grpo',
model=model_path,
dataset=dataset_path,
output_dir=output_dir,
train_type='lora',
lora_rank=8,
lora_alpha=16,
fp16=True,
bf16=False,
per_device_train_batch_size=1,
gradient_accumulation_steps=2,
max_steps=100,
num_generations=2,
generation_batch_size=2,
learning_rate=1e-5,
max_length=1024,
max_completion_length=512,
logging_steps=1,
save_steps=20,
save_total_limit=3,
save_only_model=True,
reward_funcs=[reward_function],
reward_weights=[1.0],
gradient_checkpointing=True,
)
print(f"Starting GRPO training with model: {model_path}")
print(f"Dataset: {dataset_path}")
rlhf_main(args)
PYTHON_EOF
EOF
# スクリプトに実行権限を付与
chmod +x run_grpo_with_reward.sh
実行
# SLURMジョブとして実行(2時間)
sbatch -J grpo_qwen235b -p P10 -t 2:00:00 --gres=gpu:8 run_grpo_with_reward.sh
# または直接実行(インタラクティブノードで)
./run_grpo_with_reward.sh
GRPO実行中のlogを見ると、verlの場合と違い、報酬スコアが向上していく様子が見られず、当たったり間違えたりを繰り返していました。そのため、本コードは動作するものの、改良が必要である認識です。
参考)報酬スコアが0になる原因を特定するためのコード
初期に実装した報酬関数では報酬スコアが0になり続けるという問題が起きていました。問題の原因を特定するために、下記のコードを使って入力と生成した出力、スコア付けするために処理した結果などをlogに吐き出させました。結果、報酬関数にテキスト形式ではない形式で渡してしまっていたことが原因と判明し、上記のコードのように修正しました。
verify_reward_function.py:報酬関数のデバック用コード
# verify_reward_function.py を詳細トレース用に更新
cat > verify_reward_function.py << 'EOF'
import re
from typing import List, Dict, Any
class CorrectnessRewardFunction:
def __init__(self, dataset: List[Dict[str, Any]]):
self.prompt_answer_map = self._create_prompt_answer_map(dataset)
print(f"\n--- INFO: (Detailed Trace Verification) Reward function initialized. ---\n")
def _create_prompt_answer_map(self, dataset: List[Dict[str, Any]]) -> Dict[str, str]:
prompt_map = {}
for item in dataset:
if 'messages' in item and isinstance(item['messages'], list):
user_messages = [msg['content'] for msg in item['messages'] if msg.get('role') == 'user']
answer = item.get('extra_info', {}).get('answer', '')
if user_messages and answer:
prompt_text = user_messages[-1]
prompt_map[prompt_text] = answer
return prompt_map
def __call__(self, completions: List[str], **kwargs) -> List[float]:
# 1件目のデータで全プロセスをトレースする
trace_log = ["\n\n====================== REWARD CALCULATION TRACE ======================\n"]
# Step 1: プロンプト抽出
prompt = "EXTRACTION FAILED"
if 'messages' in kwargs:
message_list = kwargs['messages'][0]
user_content = [msg['content'] for msg in message_list if msg.get('role') == 'user']
if user_content:
prompt = user_content[-1]
trace_log.append(f"--- [Step 1] Extracted Prompt ---\n'{prompt}'\n")
# Step 2: モデル生成文の取得
completion = completions[0]
trace_log.append(f"--- [Step 2] Model Completion Raw Text ---\n'{completion}'\n")
# Step 3: 正解データの検索
ground_truth_full_string = self.prompt_answer_map.get(prompt)
trace_log.append(f"--- [Step 3] Ground Truth Lookup ---\nSuccessful: {ground_truth_full_string is not None}\nFound: '{ground_truth_full_string}'\n")
# Step 4a: モデル生成文から答えを抽出
trace_log.append(f"--- [Step 4a] Extracting Answer from Model Completion ---")
trace_log.append(f"Calling _extract_final_answer with:\n'{completion}'")
generated_answer = self._extract_final_answer(completion)
trace_log.append(f"Result -> '{generated_answer}'\n")
# Step 4b: 正解データから答えを抽出
trace_log.append(f"--- [Step 4b] Extracting Answer from Ground Truth ---")
trace_log.append(f"Calling _extract_final_answer with:\n'{ground_truth_full_string}'")
ground_truth_answer = self._extract_final_answer(ground_truth_full_string)
trace_log.append(f"Result -> '{ground_truth_answer}'\n")
# Step 5: 最終スコアの計算
trace_log.append(f"--- [Step 5] Calculating Final Score ---")
trace_log.append(f"Calling _calculate_reward with generated_answer='{generated_answer}' and ground_truth_answer='{ground_truth_answer}'")
final_reward = self._calculate_reward(generated_answer, ground_truth_answer)
trace_log.append(f"Result -> {final_reward}\n")
trace_log.append("=========================== END OF TRACE ===========================\n")
raise ValueError("".join(trace_log))
def _extract_final_answer(self, text: str) -> str:
if not isinstance(text, str): return None
# $記号とカンマに対応した正規表現
match = re.search(r'####\s*([+-]?[\d,]*\.?\d+)', str(text).replace('$', ''))
if match:
return match.group(1).replace(',', '')
return None
def _calculate_reward(self, generated_answer: str, ground_truth_answer: str) -> float:
if generated_answer is None: return -0.5
if ground_truth_answer is None: return 0.0
try:
gen_num = float(generated_answer)
gt_num = float(ground_truth_answer)
if abs(gen_num - gt_num) < 1e-4: return 1.0
else: return 0.1
except (ValueError, TypeError): return -0.5
def get_reward_fn(dataset_path: str):
from datasets import load_dataset
train_dataset = load_dataset('json', data_files=dataset_path, split='train')
return CorrectnessRewardFunction(list(train_dataset))
EOF
run_grpo_with_llama.sh:上記のコードを用いてGRPOを実行するコード
cat > run_grpo_with_llama.sh << 'EOF'
#!/bin/bash
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
export PYTHONPATH="${PYTHONPATH}:$(pwd)"
export MODEL_PATH="/home/llm_project/models/base/Llama-3.2-1B-Instruct"
export OUTPUT_DIR="outputs/grpo_llama_detailed_verify_$(date +%Y%m%d_%H%M%S)"
export ABSOLUTE_DATA_PATH="$HOME/swift_grpo_workspace/data/train.jsonl"
mkdir -p "$OUTPUT_DIR"
python - << 'PYTHON_EOF'
import os
from swift.llm import RLHFArguments, rlhf_main
from verify_reward_function import get_reward_fn
dataset_path = os.environ.get('ABSOLUTE_DATA_PATH')
output_dir = os.environ.get('OUTPUT_DIR')
model_path = os.environ.get('MODEL_PATH')
print("Initializing detailed trace verification...")
reward_function = get_reward_fn(dataset_path)
args = RLHFArguments(
rlhf_type='grpo',
model=model_path,
dataset=dataset_path,
output_dir=output_dir,
train_type='lora',
lora_rank=8,
lora_alpha=16,
fp16=True,
bf16=False,
per_device_train_batch_size=1,
gradient_accumulation_steps=2,
max_steps=10,
num_generations=2,
generation_batch_size=2,
learning_rate=1e-5,
max_length=1024,
max_completion_length=512,
logging_steps=1,
reward_funcs=[reward_function],
reward_weights=[1.0],
gradient_checkpointing=True,
)
print(f"Starting GRPO training with model: {model_path}")
print(f"Dataset: {dataset_path}")
rlhf_main(args)
PYTHON_EOF
本プロジェクトは、国立研究開発法人新エネルギー・産業技術総合開発機構(「NEDO」)の「日本語版医療特化型LLMの社会実装に向けた安全性検証・実証」における基盤モデルの開発プロジェクトの一環として行われました。