6
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

LLMの強化学習 GRPO(Group Relative Policy Optimization)の高速化

Last updated at Posted at 2025-10-11

はじめに

本記事は松尾研LLM開発コンペ2025に参加し、チームcaminoで実施したLLMの強化学習手法であるGRPO(Group Relative Policy Optimization)の高速化についてまとめたものです。

GRPOの学習に時間がかかりすぎるという課題がありましたが、推論エンジンをvLLMからSGLangへ切り替えることで、学習速度が約2倍に向上しました。
この記事では、この高速化について説明します。

0.GRPOについて

GRPOは大雑把ですが、以下のような流れで強化学習を進める手法です。

複数の候補を生成
1つの入力に対して複数の出力候補を生成し、それぞれに報酬(スコア)を与えます。

報酬を元にモデルを更新
報酬の高い出力は出やすく、低い出力は出にくくなるようにモデルを更新します。

具体例
ある入力に対し、4つの出力候補と以下の報酬が得られたとします。
報酬スコア: [1.0, 0.5, -0.2, -0.5]
この結果に基づき、スコアが高い 1.00.5 の出力は、生成されやすくなり、スコアが低い -0.2-0.5 の出力は、生成されにくくなります。

参考文献 Looking deeper into the GRPO method

1.GRPOに関連するフレームワーク

運営提供のスクリプトを用い、GRPO は verl で実施しました。
verl は SFT・GRPO のマルチノード学習に対応し、GRPO では vLLMSGLang などの推論エンジンを「引数の切り替え」で変更できます。
詳細は verl 公式ドキュメント をご覧ください。あわせて参考資料を以下にまとめます。

参考資料(GRPO/SGLang)

リンク 概要
verl による GRPO verl で GRPO を使う際の設定・引数
SGLang verl で SGLang を使うための設定
trl による GRPO trl での GRPO 実装・コード例
unsloth による GRPO unsloth での GRPO 実装・コード例

今回はverlを使用したため、trl、unslothによるGRPOは実施していません。

2.高速化の必要性について

GRPOの学習の1ステップあたりの学習に約870秒(14分30分)もかかり、100ステップ分の学習時間に換算すると約1日を要し、学習の高速化が必要でした。

学習条件(1ステップ目:約870秒)

GPU環境: NVIDIA H100 80GB × 8
モデル: microsoft/Phi-4-reasoning-plus(14B)
データ: Omni-MATH_difficulty5plus_qa_rl(使用カラム:question, ground_truth)
推論エンジン: vLLM
学習パラメータ:出力長 31,744/1入力あたり生成数(rollout)=1/バッチサイズ(train_batch_size)=16 → 1ステップの総生成数=16

学習時間の内訳

3ステップ学習を行い、3ステップ目の学習時間を比較しています。

図1. 学習時間の内訳(3ステップ目)

image.png

学習時間(878.89秒)の内、推論時間(849.20秒)が多くを占めることがわかります。


表1. 各処理フェーズの時間内訳

項目 時間(秒)   内容
timing_s/step 878.89 1ステップの合計時間
timing_s/gen 849.20 推論時間の合計
timing_s/update_actor 19.31 報酬に基づくパラメータ更新
timing_s/old_log_prob 4.57 旧ポリシーでの確率計算
timing_s/ref 4.62 参照モデルでの確率計算
timing_s/reward 1.18 報酬モデルによる評価
timing_s/reshard 1.04 GPU間の再配置         

図1と同様、学習の大部分が推論時間に費やされていることがわかります。

補足:差分について
上記の学習時間は、Weights & Biases(wandb)のログから取得したデータですが、各項目を単純に合算すると 879.92秒 になります。しかし、timing_s/step の合計は 878.89秒 であり、1.03秒 の差がありました。この差分が生まれる理由は、分かりませんでした。

3.推論エンジンとしてSGLangを試した理由

SGLangには、vLLMよりも何かしら強みがあると思っていた
きっかけは、運営から提供された資料にSGLangが挙げられていたことです。また、チームで共有されたvLLMとSGLangの比較資料から、SGLangは推論時のメモリ効率などで優位性があるのではないかと考えるようになりました。

すぐに試せる環境だった
verlとSGLangをインストールし、学習時の引数を一つ変更するだけで、すぐにSGLangを使用することができました。
(変更箇所: actor_rollout_ref.rollout.name=sglang

挑戦を後押しするチーム文化であった
良さそうな手法ならとりあえずやってみようという雰囲気がチームにあり、学習時間の削減のために、SGLangの導入がスムーズに決まりました。

4. SGLangへの切り替えによる学習速度の改善

推論エンジンをSGLangに切り替えた結果、1ステップあたりの学習時間は半分以下へ短縮されました。vLLM使用時と学習条件を同一に揃えて実施しています。

学習時間の比較

図2. 学習時間の内訳(3ステップ目)

SGLang_vs_vLLM.png

推論時間が849.20秒から382.91秒となり、推論時間が半分以下になっています。


表2. 各処理フェーズの時間内訳

項目 vLLM SGLang 削減時間 (秒)
timing_s/step (合計) 878.89 秒 408.19 秒 -470.70 秒
timing_s/gen (推論) 849.20 秒 382.91 秒 -466.29 秒
timing_s/update_actor (更新) 19.31 秒 16.26 秒 -3.05 秒
timing_s/old_log_prob (確率計算) 4.57 秒 4.19 秒 -0.38 秒
timing_s/ref (参照モデル) 4.62 秒 3.62 秒 -1.00 秒
timing_s/reward (報酬計算) 1.18 秒 1.20 秒 -0.02 秒
timing_s/reshard (再配置) 1.04 秒 2.31 秒 -1.27 秒

学習時間削減の大部分は timing_s/gen(推論)であり、 timing_s/update_actorなど推論以外のモデル更新等の部分は削減の幅が小さい。

補足:差分について
上記の学習時間は、Weights & Biases(wandb)のログから取得したデータですが、各項目を単純に合算すると、vLLMは1.03秒、SGLangは2.3秒の差があり、timing_s/step(合計)と各項目(timing_s/****)の足し算が合わなかったです。この差分が生まれる理由は、分かりませんでした。

5.vLLMとSGLangの推論速度の違いについての考察

vLLMとSGLangはメモリ割り当ての方法が異なっていますが、なぜSGLangがvLLMより推論速度が速いのかについては、明確な理由がわかってないです。

メモリ割り当ての方法

SGLang
空きGPUメモリのうち、モデル重み・KV など静的用途に gpu_memory_utilization で指定した割合を使用。残りの(1-gpu_memory_utilization)のメモリも推論中に使用されます。

vLLM
vLLM インスタンスは合計メモリの gpu_memory_utilization のみを使用。

参考文献 Rollout Generation Tuning

6.verl + SGLangでの環境構構築および学習の実行

この章からは実際にverlでSGLangを使用するための手順を説明していきます。
SGLang用にcondaで環境を構築し、学習を行いました。以下に、その手順を示します。

vLLMをインストールしたconda環境にSGLangを追加しようとしたところ、バージョンの不一致によるエラーに遭遇しました。原因は、SGLangのインストール時に、PyTorchやFlash Attentionなどのライブラリが自動で新しいバージョンに更新され、既存の他のライブラリとの不一致が発生しました。
このような状況を避けるため、verlでvLLMとSGLangを使い分ける際は、それぞれ専用の環境を構築することをおすすめします。

Step 0: 使用するライブラリの確認

verlの公式ドキュメントに従い、以下のライブラリで環境を構築しました。
verl v0.4.1 を使用)

  • PyTorch: 2.6.0+cu124
  • CUDA: 12.4
  • flashinfer-python: 0.2.5+cu124torch2.6
  • SGLang: 0.4.6.post5
  • sgl-kernel: 0.1.4

Step 1: moduleコマンドの実施

moduleコマンドで必要なライブラリをロードします。

module reset
module load nccl/2.22.3
module load hpcx/2.18.1-gcc-cuda12/hpcx-mt
module load miniconda/24.7.1-py311
source /home/appli/miniconda3/24.7.1-py311/etc/profile.d/conda.sh

Step 2: conda環境の作成

新しい conda 環境の作成および有効化を行う。
今回はconda_env_sglang124という名前で作成しています。

export CONDA_PATH="$HOME/conda_env_sglang124"
conda create --prefix "$CONDA_PATH" python=3.11 -y
conda activate "$CONDA_PATH"

Step 3: CUDA Toolkit 12.4 のインストール

CUDA Toolkit 12.4 をインストールします。

conda install -y -c nvidia/label/cuda-12.4.1 cuda-toolkit=12.4.1
conda install -y -c conda-forge cudnn

以下のコマンドを実施し、CUDAが正しくインストールされているか確認します。
release 12.4 と表示されれば問題ありません。

which nvcc && nvcc --version
# => release 12.4 と表示されればOK

Step 4: PyTorch 2.6.0+cu124 のインストール

ここが手順の中で最も重要です。
PyTorch公式の CUDA 12.4 対応パッケージを使い、torch,torchvision,torchaudioのバージョンを指定してインストール します。

python -m pip install --upgrade pip wheel
python -m pip install --index-url https://download.pytorch.org/whl/cu124 \
  torch==2.6.0 torchvision==0.21.0 torchaudio==2.6.0

インストール後、以下のコマンドでバージョンを確認します。
torch: 2.6.0 および torch.version.cuda: 12.4 と表示されれば正常です。

import torch, json
print(json.dumps({
  "torch": torch.__version__,
  "torch.version.cuda": getattr(torch.version, "cuda", None),
  "cuda.is_available": torch.cuda.is_available(),
}, indent=2))
# => torch: 2.6.0, torch.version.cuda: 12.4 が表示されればOK

Step 5: SGLang関連のライブラリのインストール

SGLang を使用するため、verl v0.4.1 と互換性のあるバージョンを指定しインストールします。sgl-kernel は SGLang の GPU カーネル部分を扱うためのモジュールであり、SGLang 本体とバージョンを揃える必要があります。

python -m pip install "sglang[all]==0.4.6.post5"
python -m pip install "sgl-kernel==0.1.4"

次のコードを実行して、インストールが正しく完了しているか確認します。
両方の項目が true と表示されれば正常です。

import importlib, inspect, json
out={}
try:
  from sglang.srt.entrypoints.engine import ServerArgs
  out["ServerArgs_has_mm_attention_backend"] = "mm_attention_backend" in inspect.signature(ServerArgs.__init__).parameters
except Exception as e:
  out["ServerArgs_has_mm_attention_backend"] = f"ERROR:{type(e).__name__}"

try:
  u = importlib.import_module("sglang.srt.utils")
  out["utils_has_maybe_set_triton_cache_manager"] = hasattr(u, "maybe_set_triton_cache_manager")
except Exception as e:
  out["utils_has_maybe_set_triton_cache_manager"] = f"ERROR:{type(e).__name__}"

print(json.dumps(out, indent=2))

#trueが表示されれば正常にインストールされている

Step 6: FlashInfer と FlashAttention のインストール

モデルの推論を高速化するために、GPU最適化ライブラリである FlashInfer と FlashAttention をインストールします。

# flash-infer (任意だが推奨)
python -m pip install "flashinfer-python==0.2.5" || true

# flash-attn (H100向けにアーキテクチャを指定)
export TORCH_CUDA_ARCH_LIST="90"
python -m pip install "flash-attn==2.6.3" --no-build-isolation

それぞれのライブラリが正しくインポートできるかを確認します。
エラーが出ず、True やバージョン番号が表示されれば正常です。

python - <<'PY'
ok=False
try:
  import flashinfer, importlib
  print("flashinfer:", getattr(flashinfer, "__version__", "?"))
  ok=True
except Exception as e:
  print("flashinfer import ERROR:", type(e).__name__, e)
print("flashinfer_ok:", ok)
PY
python - <<'PY'
import importlib
print("flash_attn import:", bool(importlib.util.find_spec("flash_attn")))
PY

Step 7: verl v0.4.1 本体のインストール

pipで公開されている最新版ではなく、SGLangとの互換性が確認されている特定のバージョン(2024年7月15日時点)をGitHubリポジトリから直接インストールします。--no-depsオプションで、verlが意図せず他のライブラリを更新するのを防ぎます。

# まだcloneしてなければ
# git clone git@github.com:volcengine/verl.git

cd verl
# git checkout <7/15時点のコミットハッシュ> などを実行し、バージョンを固定
python -m pip install --no-deps -e .

Step 8: verl の追加依存ライブラリをインストール

--no-depsverlをインストールしたため、不足している依存ライブラリを手動で追加します。

# Ray はVERLの要求バージョンに合わせる
python -m pip install "ray[default]>=2.41.0"

# 残りのライブラリを一括でインストール
# tensordictはバージョンを固定して下さい
python -m pip install accelerate codetiming hydra-core peft pybind11 pylatexenc wandb "tensordict<=0.6.2" torchdata

Step 9: インストールの確認

すべてのインストールが完了したら、以下のスクリプトを実行して主要ライブラリが正しく導入されているかを確認します。
各ライブラリのバージョンが表示され、エラーが出なければ環境は正常に構築されています。

import json, inspect, torch, sglang
import sglang.srt.utils as U
from sglang.srt.entrypoints.engine import ServerArgs
import verl, ray, tensordict

print(json.dumps({
  "torch": torch.__version__,
  "torch.cuda": getattr(torch.version, "cuda", "?"),
  "sglang": getattr(sglang, "__version__", "?"),
  "ray": getattr(ray, "__version__", "?"),
  "tensordict": getattr(tensordict, "__version__", "?"),
  "ServerArgs_has_mm_attention_backend":
      "mm_attention_backend" in inspect.signature(ServerArgs.__init__).parameters,
  "utils_has_maybe_set_triton_cache_manager":
      hasattr(U, "maybe_set_triton_cache_manager"),
}, indent=2))

Step 10: 構築した環境の保存

環境構築が完了したら、バックアップ用に現在のライブラリ構成を保存します。

conda list --explicit > conda-spec.sglang124.txt
python -m pip freeze > pip-req.sglang124.txt

Step 11: 学習の実行

環境構築後、以下のスクリプトと報酬関数を使用し、学習を実行しました。

phi4-grpo_8gpu_sglang-fast.sh(クリックで展開)
#!/bin/bash
#SBATCH --job-name=grpo-test
#SBATCH -p P08
#SBATCH --nodelist=osk-gpu73
#SBATCH --gres=gpu:8
#SBATCH --cpus-per-task=180
#SBATCH --time=20:00:00
#SBATCH --output=/home/Competition2025/P08/P08U010/training/grpo/logs/slurm-%j.out
#SBATCH --error=/home/Competition2025/P08/P08U010/training/grpo/logs/slurm-%j.err


# 現在のモジュール環境をリセットする(読み込まれている全てのモジュールをアンロード)
module reset

# NCCL(NVIDIA Collective Communications Library)バージョン2.22.3を読み込む
module load nccl/2.22.3

# HPC-X(高性能通信ライブラリ)バージョン2.18.1をCUDA 12およびGCCに対応する構成で読み込む
module load hpcx/2.18.1-gcc-cuda12/hpcx-mt

module load miniconda/24.7.1-py311

source /home/appli/miniconda3/24.7.1-py311/etc/profile.d/conda.sh

# condaコマンドが使えることを確認。
which conda && echo "====" && conda --version

#step0 でインストールした conda のディレクトリ
#export CONDA_PATH="~/conda_env"
export CONDA_PATH="~/conda_env_sglang124"

source ~/.bashrc

conda init

conda config --set auto_activate_base false

# 念のため既に有効化されているPython仮想環境がある場合に備えてリセットのために無効化する。
conda deactivate
conda deactivate

# 作成したPython仮想環境を有効化。
conda activate $CONDA_PATH

mkdir -p ~/training/grpo

mkdir -p ~/training/grpo/checkpoints

cd ~/training/grpo

export NCCL_SOCKET_IFNAME=enp25s0np0
export NVTE_FUSED_ATTN=1

export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7

unset ROCR_VISIBLE_DEVICES
ulimit -v unlimited

export WANDB_ENTITY="LLM-Compe-2025-Camino"
export WANDB_PROJECT_NAME="competition_verl_grpo_fast"
export WANDB_RUN_NAME="Phi4_rplus_SGLang_Omini-MATH_reward0813"

export CUDA_LAUNCH_BLOCKING=1
export HYDRA_FULL_ERROR=1

# [ADDED] SGLang推奨: TP初期化のメモリ差チェックを無効化(安定化)
export SGL_DISABLE_TP_MEMORY_INBALANCE_CHECK=True  # :contentReference[oaicite:5]{index=5}

PYTHONUNBUFFERED=1 python -m verl.trainer.main_ppo \
 algorithm.adv_estimator=grpo\
 data.train_files=/home/Competition2025/P08/shareP08/dataset/Omni-MATH_difficulty5plus_qa_rl/train.parquet\
 data.val_files=/home/Competition2025/P08/shareP08/dataset/Omni-MATH_difficulty5plus_qa_rl/test.parquet\
 data.train_batch_size=16\
 data.max_prompt_length=1020\
 data.max_response_length=31744\
 +data.dataloader_num_workers=8\
 reward_model.enable=False\
 custom_reward_function.path=$HOME/deps/verl/verl/utils/reward_score/phi4_reward.py\
 actor_rollout_ref.rollout.name=sglang\
 actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=4\
 actor_rollout_ref.rollout.tensor_model_parallel_size=2\
 actor_rollout_ref.rollout.gpu_memory_utilization=0.7\
 actor_rollout_ref.rollout.enable_chunked_prefill=True\
 actor_rollout_ref.rollout.max_num_batched_tokens=33792\
 actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=4\
 actor_rollout_ref.actor.fsdp_config.param_offload=True\
 actor_rollout_ref.actor.fsdp_config.optimizer_offload=True\
 +actor_rollout_ref.ref.fsdp_config.optimizer_offload=True\
 actor_rollout_ref.ref.fsdp_config.param_offload=True\
 actor_rollout_ref.model.path=/home/Competition2025/P08/shareP08/model/Phi-4-reasoning-plus\
 actor_rollout_ref.model.use_remove_padding=True\
 actor_rollout_ref.model.enable_gradient_checkpointing=True\
 actor_rollout_ref.actor.strategy="fsdp2"\
 actor_rollout_ref.actor.optim.lr=1e-6\
 actor_rollout_ref.actor.optim.lr_warmup_steps=10\
 actor_rollout_ref.actor.ppo_mini_batch_size=64\
 actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4\
 actor_rollout_ref.actor.use_dynamic_bsz=True\
 actor_rollout_ref.actor.use_kl_loss=True \
 actor_rollout_ref.actor.kl_loss_coef=0.001 \
 actor_rollout_ref.actor.kl_loss_type=low_var_kl \
 actor_rollout_ref.actor.ppo_max_token_len_per_gpu=39904\
 actor_rollout_ref.actor.fsdp_config.forward_prefetch=True\
 actor_rollout_ref.rollout.n=1\
 +actor_rollout_ref.ref.fsdp_config.model_dtype=bf16\
 +actor_rollout_ref.actor.fsdp_config.model_dtype=bf16\
 algorithm.use_kl_in_reward=False\
 trainer.val_before_train=False \
 trainer.n_gpus_per_node=8 \
 trainer.nnodes=1 \
 trainer.save_freq=10 \
 trainer.test_freq=100 \
 trainer.default_local_dir=$HOME/training/grpo/ckpt_phi4_rplus_SGLang_Omni-MATH_reward0813\
 trainer.logger=['console','wandb'] \
 trainer.project_name=$WANDB_PROJECT_NAME\
 trainer.experiment_name=$WANDB_RUN_NAME\
 trainer.total_epochs=2 2>&1 | tee verl_demo.lo
 
phi4_reward.py(クリックで展開)
###
#phi4-reasoningの報酬関数を実装
###
import math
import re
import signal
from collections import Counter

from sympy.parsing.latex import parse_latex
from transformers import AutoTokenizer

TOKENIZER = AutoTokenizer.from_pretrained("microsoft/Phi-4-reasoning-plus")

# 論文のセクション4.1および4.2から引用した定数
# 報酬関数で使用する長さのパラメータ
L_MAX = 31744
L_POS_CONTROL = 25600
L_NEG_CONTROL = 3702

# 報酬値の範囲
R_MAX_POS = 1.0
R_MIN_POS = 0.5
R_MAX_NEG = -0.5
R_MIN_NEG = -1.0

# 最終的な報酬の重み
W_ACC = 8 / 13
W_REP = 1 / 13

# 繰り返しペナルティのパラメータ
NGRAM_SIZE = 5
NGRAM_FREQ_THRESHOLD = 5
_SOLUTION_CLIP_CHARS = 300

# タイムアウト時に呼び出され、例外を発生させる関数
def timeout_handler(signum, frame):
    raise TimeoutError("処理がタイムアウトしました。")
def extract_solution(solution_str, method="strict"):
    assert method in ["strict", "flexible"]

    # Optimization: Regular expression matching on very long strings can be slow.
    # For math problems, the final answer is usually at the end.
    # We only match on the last 300 characters, which is a safe approximation for 300 tokens.
    if len(solution_str) > _SOLUTION_CLIP_CHARS:
        solution_str = solution_str[-_SOLUTION_CLIP_CHARS:]

    if method == "strict":
        # this also tests the formatting of the model
        solutions = re.findall("####(\\-?[0-9\\.\\,]+)", solution_str)
        if len(solutions) == 0:
            final_answer = None
        else:
            # take the last solution
            final_answer = solutions[-1].replace(",", "").replace("$", "")
    elif method == "flexible":
        answer = re.findall("(\\-?[0-9\\.\\,]+)", solution_str)
        final_answer = None
        if len(answer) == 0:
            # no reward is there is no answer
            pass
        else:
            invalid_str = ["", "."]
            # find the last number that is not '.'
            for final_answer in reversed(answer):
                if final_answer not in invalid_str:
                    break
    return final_answer

def find_last_boxed_content(text: str) -> str:
    """
    文字列中の最後の "\\boxed{...}" の中身を、入れ子括弧を考慮して抽出します。
    エスケープされた括弧 \{ や \} は無視します。
    """
    try:
        # 最後の "\\boxed{" の開始インデックスを探します
        last_boxed_start_index = text.rfind("\\boxed{")
        if last_boxed_start_index == -1:
            return ""

        # コンテンツの実際の開始位置
        content_start_index = last_boxed_start_index + len("\\boxed{")

        # 対応する閉じ括弧 '}' を探します
        brace_level = 1
        for i in range(content_start_index, len(text)):
            char = text[i]

            # LaTeXでエスケープされた括弧 \{ や \} はレベル計算に含めません
            if text[i-1] == '\\' and (char == '{' or char == '}'):
                continue

            if char == '{':
                brace_level += 1
            elif char == '}':
                brace_level -= 1

            # brace_levelが0になったら、それが対応する閉じ括弧です
            if brace_level == 0:
                return text[content_start_index:i]

        # 最後まで見ても対応する閉じ括弧が見つからなかった場合
        return ""

    except Exception:
        # 何らかのエラーが発生した場合
        return ""
def extract_thought_and_answer(solution_str: str) -> tuple[str, str, bool]:
    """
    文字列から<think>...</think>と最後の\\boxed{...}を抽出します。
    \\boxed{...}内の入れ子括弧に対応しています。
    """
    # <think>...</think> の抽出ロジックは変更ありません
    think_match = re.search(r"<think>(.*?)</think>", solution_str, re.DOTALL)

    if think_match:
        thinking_process = think_match.group(1).strip()
        is_format_valid = True
    else:
        thinking_process = ""
        is_format_valid = False

    # \\boxed{...} の抽出を新しい堅牢な関数に置き換えます
    answer = find_last_boxed_content(solution_str)

    return thinking_process, answer, is_format_valid
def parse_solution(solution_str: str) -> tuple[str | None, str | None, bool]:
    """
    <think>...</think>{answer} という構造の文字列を解析する。

    Args:
        solution_str: モデルの生成出力。

    Returns:
        タプル (thinking_process, answer, is_format_valid)
    """
    # 正規表現を使用して<think>ブロックとそれに続く回答を抽出
    match = re.search(r"<think>(.*?)</think>(.*)", solution_str, re.DOTALL)

    if match:
        thinking_process = match.group(1).strip()
        answer = match.group(2).strip()
        return thinking_process, answer, True
    else:
        # thinkタグが見つからない、または形式が不正
        return None, None, False

def _compute_repetition_penalty(text: str) -> float:
    """
    n-gramの頻度に基づいて繰り返しペナルティを計算します。
    """
    words = text.split()
    if len(words) < NGRAM_SIZE:
        return 0.0
    # n-gramを生成
    ngrams = [" ".join(words[i:i+NGRAM_SIZE]) for i in range(len(words) - NGRAM_SIZE + 1)]
    if not ngrams:
        return 0.0

    ngram_counts = Counter(ngrams)
    frequent_ngrams = {k: v for k, v in ngram_counts.items() if v > NGRAM_FREQ_THRESHOLD}

    if not frequent_ngrams:
        return 0.0

    term1 = len(frequent_ngrams) / len(ngrams)
    max_freq = max(frequent_ngrams.values())
    total_possible_ngrams = len(words) / NGRAM_SIZE if len(words) > 0 else 1
    term2 = max_freq / total_possible_ngrams

    penalty = -max(term1, term2)
    return penalty

def compute_score(data_source: str, solution_str: str, ground_truth: str, extra_info=None):
    """
    Phi-4-reasoning論文で説明されている報酬関数に基づいて最終的なスコアを計算します。

    Args:
        solution_str: モデルから生成された完全なテキスト。(tokenizedではなく、文字列形式)
        ground_truth: 正解。
        data_source: データソースの名前。現在は "gsm8k" のみ対応。

    """
    # 1. 出力文字列を解析し、フォーマットを検証
    thinking_process, answer, is_format_valid = extract_thought_and_answer(solution_str)
    L=len(TOKENIZER.tokenize(solution_str))
    print("---solution_str---")
    print(solution_str)  # Debugging output

    signal.signal(signal.SIGALRM, timeout_handler)
    # 30秒でタイムアウトするように設定(必要に応じて変更)
    signal.alarm(30)
    try:
        latex_answer=parse_latex(str(answer).lower(),backend="lark")
        latex_ground_truth=parse_latex(str(ground_truth).lower(),backend="lark")
    except Exception as e:  # 必要に応じて全ての例外をキャッチ
        latex_answer=str(answer).lower().replace(" ","")
        latex_ground_truth=str(ground_truth).lower().replace(" ","")
    finally:
        signal.alarm(0)
    print("---is_format_valid---")
    print(is_format_valid)
    print("---answer---")
    print(answer)  # Debugging output
    print("---ground_truth---")
    print(ground_truth)  # Debugging output
    # 2. フォーマット違反のオーバーライドを処理
    # <think>タグが不正な場合は is_format_valid が False になる
    # answerが不適切な形の場合もフォーマット違反とする
    if not is_format_valid:
        r_acc_scaled = -1.0
    # 生成が不完全な場合
    elif L >= L_MAX-1:
        # imcomplete(eostokenなし)はこの関数では厳密な実装はできないので,max_lengthを超えた場合にフォーマット違反として扱う
        # ここでは、L_MAXを超える場合にフォーマット違反として扱う
                # (Lは開始トークンおよび終了トークンを含まず,L_MAXは終了トークンを含むため,L_MAX-1と比較)
        # TODO:imcompleteの完全な実装
        r_acc_scaled = -0.5
    else:
        # 3. 回答が正解かどうかを報酬に反映
        #ground_truthがlatex構文に適していなかった場合,元のanswerと比較する
        is_correct= (latex_answer is not None and latex_answer == latex_ground_truth)
        # 注記: 論文ではトークン長が使用されていますが、ここでは単語数を代理として使用します。
        # 正確な実装には、トークナイザが必要です。
        #L = len(solution_str.split())
        L= len(TOKENIZER.tokenize(solution_str))

        if is_correct:
            rho_plus = min(1.0, max(0, L - L_POS_CONTROL) / (L_MAX - L_POS_CONTROL))
            cos_term = 0.5 * (R_MAX_POS - R_MIN_POS) * (1 + math.cos(math.pi * rho_plus))
            r_acc_scaled = R_MIN_POS + cos_term
        else:
            rho_minus = min(1.0, L / L_NEG_CONTROL)
            cos_term = 0.5 * (R_MIN_NEG - R_MAX_NEG) * (1 + math.cos(math.pi * rho_minus))
            r_acc_scaled = R_MAX_NEG + cos_term

    # 4. 繰り返しペナルティを計算 (文字列全体を対象)
    r_rep = _compute_repetition_penalty(solution_str)

    # 5. 最終的な重み付きスコアを計算
    final_score = (W_ACC * r_acc_scaled) + (W_REP * r_rep)
    print("---final_score---")
    print(final_score)  # Debugging output

    return final_score
学習ログ(3step目)
step:3 - global_seqlen/min:6637.000 - global_seqlen/max:61182.000 - global_seqlen/minmax_diff:54545.000 - global_seqlen/balanced_min:32380.000 - global_seqlen/balanced_max:36841.000 - global_seqlen/mean:35474.000 - actor/entropy:0.874 - actor/kl_loss:0.000 - actor/kl_coef:0.001 - actor/pg_loss:0.462 - actor/pg_clipfrac:0.000 - actor/ppo_kl:0.000 - actor/pg_clipfrac_lower:0.000 - actor/grad_norm:0.016 - perf/mfu/actor:0.000 - perf/max_memory_allocated_gb:56.216 - perf/max_memory_reserved_gb:70.156 - perf/cpu_memory_used_gb:127.258 - actor/lr:0.000 - training/global_step:3.000 - training/epoch:0.000 - critic/score/mean:-0.020 - critic/score/max:0.615 - critic/score/min:-0.616 - critic/rewards/mean:-0.020 - critic/rewards/max:0.615 - critic/rewards/min:-0.616 - critic/advantages/mean:-0.326 - critic/advantages/max:0.615 - critic/advantages/min:-0.616 - critic/returns/mean:-0.326 - critic/returns/max:0.615 - critic/returns/min:-0.616 - response_length/mean:17435.688 - response_length/max:31744.000 - response_length/min:1680.000 - response_length/clip_ratio:0.250 - prompt_length/mean:301.312 - prompt_length/max:352.000 - prompt_length/min:258.000 - prompt_length/clip_ratio:0.000 - timing_s/generate_sequences:380.170 - timing_s/reshard:2.311 - timing_s/gen:382.913 - timing_s/reward:1.198 - timing_s/old_log_prob:4.186 - timing_s/ref:3.621 - timing_s/adv:0.002 - timing_s/update_actor:16.263 - timing_s/step:408.190 - timing_per_token_ms/update_actor:0.057 - timing_per_token_ms/adv:0.000 - timing_per_token_ms/ref:0.013 - timing_per_token_ms/gen:1.373 - perf/total_num_tokens:283792.000 - perf/time_per_step:408.190 - perf/throughput:86.906

上記のスクリプトと報酬関数はSGLangの環境構築後に使用したものであり、最終的にチームCaminoが提出したモデルの学習に使用したものとはパラメータ等が異なります。

7.SGLangのデメリット

SGLangにより学習を高速化できましたが、デメリットも存在しました。

課題
verl環境のGRPOにおいて、SGLangでLoRAは使用不可であり、70Bのモデルで学習を行う際、1ノード(8GPU)のメモリに収まらない。
(SGLang単体ではLoRAが使用可であることから、将来的にアップデートされ、verlでも、LoRAが使用可となる可能性もあります。)
対応
試行回数を増やす、学習を安定化させるため、1ノードで学習を行う方針であったため、以下のように推論エンジンを使い分けました。

  • 70Bモデル(LoRA必須の場合):1ノードでの学習を行うため、vLLM を使用
  • 14Bモデル:SGLangを使用し学習を高速化

8.まとめ

GRPOは、1つの入力に対して複数の出力を生成する仕組みを取っており、この推論部分が学習の中でかなりの時間を占めていました。この状況で、推論エンジンを vLLM から SGLang に切り替えたところ、学習時間を半分以下に短縮することができました。

環境構築にはライブラリのバージョン管理など多少の手間がかかりますが、それに見合う効果がありました。(Dockerなどを活用できれば、もっと簡単に環境構築できると思います。)
GRPO などの強化学習に取り組む際には、試行回数を増やすために、SGLang のような推論エンジンを使用し、推論時速度を上げることが重要であると感じました。

最後になりますが、チームCaminoの皆様については、GRPOの高速化に挑戦させていただきありがとうございました。松尾・岩澤研の方々については、コンペの運営ありがとうございました。さくらインターネット様についてはインフラのご提供ありがとうございました。

本プロジェクトは、国立研究開発法人新エネルギー・産業技術総合開発機構(以下「NEDO」)の「日本語版医療特化型LLMの社会実装に向けた安全性検証・実証」における基盤モデルの開発プロジェクトの一環として行われます。

6
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
6
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?