0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

GRPO (Group Relative Policy Optimization) をMS-SWIFTを使って実装する

Last updated at Posted at 2025-10-31

※下記のコードはチームメンバーのある方(後日公開予定)と協力し作成したものになります。

概要

SWIFTフレームワークを使用してQLoRA+GRPO(Group Relative Policy Optimization)による強化学習を実行する手順をまとめています。ByteDance社のverlではMoEモデルにおけるLoRA対応がされていないようでしたので、Qwen3-235BでのGSPOを1ノードで動かすために検証しました。ただし、本コードは動作検証のためLlama-32.B-1B-Instructを使用しています。

参考文献)

前提条件

  • H100 1ノード(8台)

1. 計算ノードの確保

srun --partition ××× --nodes=1 --gpus-per-node=8 --cpus-per-task=240 --time=02:00:00 --job-name="GSPO_training" --pty bash -i

2. 環境構築

2.1 Conda環境の作成

# Conda環境作成
conda create -n swift_grpo python=3.10 -y
conda activate swift_grpo

# 必要なパッケージのインストール
pip install ms-swift[llm]==3.8.1
pip install transformers==4.56.2
pip install torch==2.3.0 --index-url https://download.pytorch.org/whl/cu121
pip install accelerate==1.10.1
pip install peft==0.17.1
pip install bitsandbytes==0.47.0
pip install trl==0.20.0
pip install math_verify==0.5.2
pip install pyarrow
pip install pandas

2.2 作業ディレクトリの作成


# 個人の実験ディレクトリを作成
mkdir -p ~/swift_grpo_workspace/{data,outputs,logs}
cd ~/swift_grpo_workspace

3. 報酬関数の実装

報酬スコアは下記のように設定した。

  • 正解なら 1.0
  • 不正解だが数字を抽出できたら 0.1
  • 数字を抽出できなかったら 0.5
  • データセットの正解が見つからなかったら 0.0
# math_reward_function.py を最終解決版で上書き
cat > math_reward_function.py << 'EOF'
import re
from typing import List, Dict, Any

class CorrectnessRewardFunction:
    def __init__(self, dataset: List[Dict[str, Any]]):
        self.prompt_answer_map = self._create_prompt_answer_map(dataset)
        print(f"\n--- INFO: (Final Forgiving Version) prompt_answer_map created. Count: {len(self.prompt_answer_map)} ---\n")

    def _create_prompt_answer_map(self, dataset: List[Dict[str, Any]]) -> Dict[str, str]:
        prompt_map = {}
        for item in dataset:
            if 'messages' in item and isinstance(item['messages'], list):
                user_messages = [msg['content'] for msg in item['messages'] if msg.get('role') == 'user']
                answer = item.get('extra_info', {}).get('answer', '')
                if user_messages and answer:
                    prompt_text = user_messages[-1]
                    prompt_map[prompt_text] = answer
        return prompt_map

    def __call__(self, completions: List[str], **kwargs) -> List[float]:
        prompts = []
        if 'messages' in kwargs:
            for message_list in kwargs['messages']:
                user_content = [msg['content'] for msg in message_list if msg.get('role') == 'user']
                if user_content:
                    prompts.append(user_content[-1])

        rewards = []
        if not prompts or len(prompts) != len(completions):
            return [0.0] * len(completions)

        for i, completion in enumerate(completions):
            current_prompt = prompts[i]
            gt_full_string = self.prompt_answer_map.get(current_prompt)

            if gt_full_string is None:
                rewards.append(0.0)
                continue

            generated_answer_str = self._extract_final_answer(completion)
            ground_truth_answer_str = self._extract_final_answer(gt_full_string)
            reward = self._calculate_reward(generated_answer_str, ground_truth_answer_str)
            rewards.append(reward)

        return rewards

    def _extract_final_answer(self, text: str) -> str:
        if not isinstance(text, str):
            return None
        
        sanitized_text = str(text).replace('$', '').replace(',', '')
        
        # 1. まず「#### <数字>」の正しい形式を探す
        match = re.search(r'####\s*([+-]?\d+\.?\d*)', sanitized_text)
        if match and match.group(1):
            return match.group(1)

        # 2. 正しい形式が見つからない場合、フォールバックとして文章中の「最後の数字」を探す
        all_numbers = re.findall(r'([+-]?\d+\.?\d+)', sanitized_text)
        if all_numbers:
            return all_numbers[-1] # 最後に見つかった数字を返す

        # 3. それでも見つからなければ None を返す
        return None

    def _calculate_reward(self, generated_answer: str, ground_truth_answer: str) -> float:
        if generated_answer is None: return -0.5
        if ground_truth_answer is None: return 0.0
        try:
            gen_num = float(generated_answer)
            gt_num = float(ground_truth_answer)
            if abs(gen_num - gt_num) < 1e-4: return 1.0
            else: return 0.1
        except (ValueError, TypeError): return -0.5

def get_reward_fn(dataset_path: str):
    from datasets import load_dataset
    train_dataset = load_dataset('json', data_files=dataset_path, split='train')
    return CorrectnessRewardFunction(list(train_dataset))
EOF

4. データセットの前処理

データは練習用としてGSM8Kを用いています。

データセットの変換

cat > convert_to_jsonl.py << 'EOF'
import pandas as pd
import os

parquet_path = '/home/llm_project/data/raw/gsm8k/train.parquet'
jsonl_path = 'data/train.jsonl'

print(f"Reading Parquet file from: {parquet_path}")

try:
    df = pd.read_parquet(parquet_path)
    print("Successfully read Parquet file.")
    
    if 'prompt' in df.columns:
        df.rename(columns={'prompt': 'messages'}, inplace=True)
        print("Renamed 'prompt' column to 'messages'.")
    else:
        print("Warning: 'prompt' column not found. No columns were renamed.")

    os.makedirs(os.path.dirname(jsonl_path), exist_ok=True)

    print(f"Saving new structured data to: {jsonl_path}")
    df.to_json(jsonl_path, orient='records', lines=True, force_ascii=False)

    print("Conversion complete! ")
    print(f"New dataset is available at: {os.path.abspath(jsonl_path)}")

except Exception as e:
    print(f"An error occurred: {e}")

EOF

実行

python convert_to_jsonl.py

データセットのチェック

cat > check_dataset.py << 'EOF'
from datasets import load_dataset
import json

jsonl_path = 'data/train.jsonl'

print(f"Checking dataset format for: {jsonl_path}")

try:
    # データセットを読み込む
    dataset = load_dataset('json', data_files=jsonl_path)
    
    # 最初の1件を取得して表示
    first_example = dataset['train'][0]
    
    print("\n✅ Dataset loaded successfully!")
    print("\n--- First Example ---")
    # 見やすいように整形して表示
    print(json.dumps(first_example, indent=2, ensure_ascii=False))
    
    # 重要なキーの存在チェック
    if 'messages' in first_example and isinstance(first_example['messages'], list):
        print("\n✅ 'messages' key exists and is a list.")
        message = first_example['messages'][0]
        if 'role' in message and 'content' in message:
            print("✅ 'role' and 'content' keys exist in the first message.")
        else:
            print("\n❌ ERROR: 'role' or 'content' key is missing in the message.")
    else:
        print("\n❌ ERROR: 'messages' key is missing or is not a list.")

except Exception as e:
    print(f"\n❌ An error occurred during dataset check: {e}")

EOF

実行

python check_dataset.py

5. GRPO実行

モデルは検証用にLlama-32.B-1B-Instructを使用しています。

# 実行スクリプトを本番用に更新
cat > run_grpo_with_llama.sh << 'EOF'
#!/bin/bash
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
export PYTHONPATH="${PYTHONPATH}:$(pwd)"

export MODEL_PATH="/home/llm_project/models/base/Llama-3.2-1B-Instruct"
export OUTPUT_DIR="outputs/grpo_llama_final_run_$(date +%Y%m%d_%H%M%S)"
export ABSOLUTE_DATA_PATH="$HOME/swift_grpo_workspace/data/train.jsonl"

mkdir -p "$OUTPUT_DIR"

python - << 'PYTHON_EOF'
import os
from swift.llm import RLHFArguments, rlhf_main
from math_reward_function import get_reward_fn

dataset_path = os.environ.get('ABSOLUTE_DATA_PATH')
output_dir = os.environ.get('OUTPUT_DIR')
model_path = os.environ.get('MODEL_PATH')

print("Initializing final reward function...")
reward_function = get_reward_fn(dataset_path)

args = RLHFArguments(
    rlhf_type='grpo',
    model=model_path,
    dataset=dataset_path,
    output_dir=output_dir,
    train_type='lora',
    lora_rank=8,
    lora_alpha=16,
    fp16=True,
    bf16=False,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=2,
    max_steps=100,
    num_generations=2,
    generation_batch_size=2,
    learning_rate=1e-5,
    max_length=1024,
    max_completion_length=512,
    logging_steps=1,
    save_steps=20,
    save_total_limit=3,
    save_only_model=True,
    reward_funcs=[reward_function],
    reward_weights=[1.0],
    gradient_checkpointing=True,
)

print(f"Starting GRPO training with model: {model_path}")
print(f"Dataset: {dataset_path}")
rlhf_main(args)
PYTHON_EOF
EOF

# スクリプトに実行権限を付与
chmod +x run_grpo_with_reward.sh

実行

# SLURMジョブとして実行(2時間)
sbatch -J grpo_qwen235b -p P10 -t 2:00:00 --gres=gpu:8 run_grpo_with_reward.sh

# または直接実行(インタラクティブノードで)
./run_grpo_with_reward.sh

GRPO実行中のlogを見ると、verlの場合と違い、報酬スコアが向上していく様子が見られず、当たったり間違えたりを繰り返していました。そのため、本コードは動作するものの、改良が必要である認識です。

参考)報酬スコアが0になる原因を特定するためのコード

初期に実装した報酬関数では報酬スコアが0になり続けるという問題が起きていました。問題の原因を特定するために、下記のコードを使って入力と生成した出力、スコア付けするために処理した結果などをlogに吐き出させました。結果、報酬関数にテキスト形式ではない形式で渡してしまっていたことが原因と判明し、上記のコードのように修正しました。

verify_reward_function.py:報酬関数のデバック用コード

# verify_reward_function.py を詳細トレース用に更新
cat > verify_reward_function.py << 'EOF'
import re
from typing import List, Dict, Any

class CorrectnessRewardFunction:
    def __init__(self, dataset: List[Dict[str, Any]]):
        self.prompt_answer_map = self._create_prompt_answer_map(dataset)
        print(f"\n--- INFO: (Detailed Trace Verification) Reward function initialized. ---\n")

    def _create_prompt_answer_map(self, dataset: List[Dict[str, Any]]) -> Dict[str, str]:
        prompt_map = {}
        for item in dataset:
            if 'messages' in item and isinstance(item['messages'], list):
                user_messages = [msg['content'] for msg in item['messages'] if msg.get('role') == 'user']
                answer = item.get('extra_info', {}).get('answer', '')
                if user_messages and answer:
                    prompt_text = user_messages[-1]
                    prompt_map[prompt_text] = answer
        return prompt_map

    def __call__(self, completions: List[str], **kwargs) -> List[float]:
        # 1件目のデータで全プロセスをトレースする
        trace_log = ["\n\n====================== REWARD CALCULATION TRACE ======================\n"]

        # Step 1: プロンプト抽出
        prompt = "EXTRACTION FAILED"
        if 'messages' in kwargs:
            message_list = kwargs['messages'][0]
            user_content = [msg['content'] for msg in message_list if msg.get('role') == 'user']
            if user_content:
                prompt = user_content[-1]
        trace_log.append(f"--- [Step 1] Extracted Prompt ---\n'{prompt}'\n")

        # Step 2: モデル生成文の取得
        completion = completions[0]
        trace_log.append(f"--- [Step 2] Model Completion Raw Text ---\n'{completion}'\n")

        # Step 3: 正解データの検索
        ground_truth_full_string = self.prompt_answer_map.get(prompt)
        trace_log.append(f"--- [Step 3] Ground Truth Lookup ---\nSuccessful: {ground_truth_full_string is not None}\nFound: '{ground_truth_full_string}'\n")

        # Step 4a: モデル生成文から答えを抽出
        trace_log.append(f"--- [Step 4a] Extracting Answer from Model Completion ---")
        trace_log.append(f"Calling _extract_final_answer with:\n'{completion}'")
        generated_answer = self._extract_final_answer(completion)
        trace_log.append(f"Result -> '{generated_answer}'\n")

        # Step 4b: 正解データから答えを抽出
        trace_log.append(f"--- [Step 4b] Extracting Answer from Ground Truth ---")
        trace_log.append(f"Calling _extract_final_answer with:\n'{ground_truth_full_string}'")
        ground_truth_answer = self._extract_final_answer(ground_truth_full_string)
        trace_log.append(f"Result -> '{ground_truth_answer}'\n")

        # Step 5: 最終スコアの計算
        trace_log.append(f"--- [Step 5] Calculating Final Score ---")
        trace_log.append(f"Calling _calculate_reward with generated_answer='{generated_answer}' and ground_truth_answer='{ground_truth_answer}'")
        final_reward = self._calculate_reward(generated_answer, ground_truth_answer)
        trace_log.append(f"Result -> {final_reward}\n")

        trace_log.append("=========================== END OF TRACE ===========================\n")
        raise ValueError("".join(trace_log))

    def _extract_final_answer(self, text: str) -> str:
        if not isinstance(text, str): return None
        # $記号とカンマに対応した正規表現
        match = re.search(r'####\s*([+-]?[\d,]*\.?\d+)', str(text).replace('$', ''))
        if match:
            return match.group(1).replace(',', '')
        return None

    def _calculate_reward(self, generated_answer: str, ground_truth_answer: str) -> float:
        if generated_answer is None: return -0.5
        if ground_truth_answer is None: return 0.0
        try:
            gen_num = float(generated_answer)
            gt_num = float(ground_truth_answer)
            if abs(gen_num - gt_num) < 1e-4: return 1.0
            else: return 0.1
        except (ValueError, TypeError): return -0.5

def get_reward_fn(dataset_path: str):
    from datasets import load_dataset
    train_dataset = load_dataset('json', data_files=dataset_path, split='train')
    return CorrectnessRewardFunction(list(train_dataset))
EOF

run_grpo_with_llama.sh:上記のコードを用いてGRPOを実行するコード

cat > run_grpo_with_llama.sh << 'EOF'
#!/bin/bash
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
export PYTHONPATH="${PYTHONPATH}:$(pwd)"

export MODEL_PATH="/home/llm_project/models/base/Llama-3.2-1B-Instruct"
export OUTPUT_DIR="outputs/grpo_llama_detailed_verify_$(date +%Y%m%d_%H%M%S)"
export ABSOLUTE_DATA_PATH="$HOME/swift_grpo_workspace/data/train.jsonl"

mkdir -p "$OUTPUT_DIR"

python - << 'PYTHON_EOF'
import os
from swift.llm import RLHFArguments, rlhf_main
from verify_reward_function import get_reward_fn

dataset_path = os.environ.get('ABSOLUTE_DATA_PATH')
output_dir = os.environ.get('OUTPUT_DIR')
model_path = os.environ.get('MODEL_PATH')

print("Initializing detailed trace verification...")
reward_function = get_reward_fn(dataset_path)

args = RLHFArguments(
    rlhf_type='grpo',
    model=model_path,
    dataset=dataset_path,
    output_dir=output_dir,
    train_type='lora',
    lora_rank=8,
    lora_alpha=16,
    fp16=True,
    bf16=False,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=2,
    max_steps=10,
    num_generations=2,
    generation_batch_size=2,
    learning_rate=1e-5,
    max_length=1024,
    max_completion_length=512,
    logging_steps=1,
    reward_funcs=[reward_function],
    reward_weights=[1.0],
    gradient_checkpointing=True,
)

print(f"Starting GRPO training with model: {model_path}")
print(f"Dataset: {dataset_path}")
rlhf_main(args)
PYTHON_EOF

本プロジェクトは、国立研究開発法人新エネルギー・産業技術総合開発機構(「NEDO」)の「日本語版医療特化型LLMの社会実装に向けた安全性検証・実証」における基盤モデルの開発プロジェクトの一環として行われました。

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?