概要
東京大学松尾・岩澤研究室主催のLLM開発コンペ2025に参加し,チームcaminoで報酬関数の実装を担当したため,そのあたりの知見を本記事でまとめておきます.特に数学タスクを解くLLMの学習・評価設計の参考になれば幸いです.
コンペについて
HLE(Human Last Exam) という非常に難しい問題を事後学習を用いてスコアを競うコンペティションです.HLEは様々な分野の問題が含有されていますが,特に数学タスク が多いことが特徴です.この数学タスクに対して事後学習であるRLHFを行う際の報酬関数の実装方法についてまとめていきます.なお,強化学習手法にGRPO を使用しました.
phi4-reasoning plusの報酬関数実装
チームcaminoではphi4-reasoning-plusというモデルを用いて,学習を行いました.phi4-reasoning-plusはdeepseek-r1等のreasoningモデル(答えだけではなく思考過程も出力するモデル)とは少し異なる報酬関数を設計しています.phi4-reasoning-plusでは,以下の式を報酬として,GRPOを学習しています.(詳細はテクレポ4.1節を参照).
\displaylines{R_{final}=w_{acc}R_{acc\_scaled}+w_{rep}R_{rep},\\
R_{acc\_scaled}=R^-_{max}+0.5\dot(R^-_{min}-R^-_{max})\dot(1+cos(\pi\rho_-)),\\
R_{rep}=-\max \left(\frac{\#\{5-\text{grams with freq.}>5\}}{\#\{5-\text{grams}\}},\frac{\text{max freq. of 5-grams with freq}. >5}{\#\{\text{words}\}/5}\right)\\
}
まあ,数式だけでは分かりづらいですよね.なので実際使用したpythonコードを見てみましょう.数式よりコードを見た方が分かりやすいと思います.
###
#phi4-reasoningの報酬関数を実装
###
###
#phi4-reasoningの報酬関数を実装
###
import math
import re
from collections import Counter
from sympy.parsing.latex import parse_latex
from transformers import AutoTokenizer
TOKENIZER = AutoTokenizer.from_pretrained("microsoft/Phi-4-reasoning-plus")
# 論文のセクション4.1および4.2から引用した定数
# 報酬関数で使用する長さのパラメータ
L_MAX = 31744
L_POS_CONTROL = 25600
L_NEG_CONTROL = 3702
# 報酬値の範囲
R_MAX_POS = 1.0
R_MIN_POS = 0.5
R_MAX_NEG = -0.5
R_MIN_NEG = -1.0
# 最終的な報酬の重み
W_ACC = 8 / 13
W_REP = 1 / 13
# 繰り返しペナルティのパラメータ
NGRAM_SIZE = 5
NGRAM_FREQ_THRESHOLD = 5
_SOLUTION_CLIP_CHARS = 300
def find_last_boxed_content(text: str) -> str:
"""
文字列中の最後の "\\boxed{...}" の中身を、入れ子括弧を考慮して抽出します。
エスケープされた括弧 \{ や \} は無視します。
"""
try:
# 最後の "\\boxed{" の開始インデックスを探します
last_boxed_start_index = text.rfind("\\boxed{")
if last_boxed_start_index == -1:
return ""
# コンテンツの実際の開始位置
content_start_index = last_boxed_start_index + len("\\boxed{")
# 対応する閉じ括弧 '}' を探します
brace_level = 1
for i in range(content_start_index, len(text)):
char = text[i]
# LaTeXでエスケープされた括弧 \{ や \} はレベル計算に含めません
if text[i-1] == '\\' and (char == '{' or char == '}'):
continue
if char == '{':
brace_level += 1
elif char == '}':
brace_level -= 1
# brace_levelが0になったら、それが対応する閉じ括弧です
if brace_level == 0:
return text[content_start_index:i]
# 最後まで見ても対応する閉じ括弧が見つからなかった場合
return ""
except Exception:
# 何らかのエラーが発生した場合
return ""
def extract_thought_and_answer(solution_str: str) -> tuple[str, str, bool]:
"""
文字列から<think>...</think>と最後の\\boxed{...}を抽出します。
\\boxed{...}内の入れ子括弧に対応しています。
"""
# <think>...</think> の抽出
think_match = re.search(r"<think>(.*?)</think>", solution_str, re.DOTALL)
if think_match:
thinking_process = think_match.group(1).strip()
is_format_valid = True
else:
thinking_process = ""
is_format_valid = False
# \\boxed{...} の抽出
answer = find_last_boxed_content(solution_str)
return thinking_process, answer, is_format_valid
def _compute_repetition_penalty(text: str) -> float:
"""
同じ単語が繰り返している場合のペナルティをn-gramの頻度に基づいて計算します。
"""
words = text.split()
if len(words) < NGRAM_SIZE:
return 0.0
# n-gramを生成
ngrams = [" ".join(words[i:i+NGRAM_SIZE]) for i in range(len(words) - NGRAM_SIZE + 1)]
if not ngrams:
return 0.0
ngram_counts = Counter(ngrams)
frequent_ngrams = {k: v for k, v in ngram_counts.items() if v > NGRAM_FREQ_THRESHOLD}
if not frequent_ngrams:
return 0.0
term1 = len(frequent_ngrams) / len(ngrams)
max_freq = max(frequent_ngrams.values())
total_possible_ngrams = len(words) / NGRAM_SIZE if len(words) > 0 else 1
term2 = max_freq / total_possible_ngrams
penalty = -max(term1, term2)
return penalty
def compute_score(data_source: str, solution_str: str, ground_truth: str, extra_info=None):
"""
Phi-4-reasoning論文で説明されている報酬関数に基づいて最終的なスコアを計算します。
Args:
solution_str: モデルから生成された完全なテキスト。(tokenizedではなく、文字列形式)
ground_truth: 正解。
data_source: データソースの名前。現在は "gsm8k" のみ対応。
"""
# 1. 出力文字列を解析し、フォーマットを検証
thinking_process, answer, is_format_valid = extract_thought_and_answer(solution_str)
L=len(TOKENIZER.tokenize(solution_str))
print("---solution_str---")
print(solution_str) # Debugging output
# 2. フォーマット違反のオーバーライドを処理
# <think>タグが不正な場合は is_format_valid が False になる
if not is_format_valid:
r_acc_scaled = -1.0
# 生成が不完全な場合
elif L >= L_MAX-1:
# imcomplete(eostokenなし)はこの関数では厳密な実装はできないので,max_lengthを超えた場合にフォーマット違反として扱う
# ここでは、L_MAXを超える場合にフォーマット違反として扱う
# (Lは開始トークンおよび終了トークンを含まず,L_MAXは終了トークンを含むため,L_MAX-1と比較)
# TODO:imcompleteの完全な実装
r_acc_scaled = -0.5
else:
print("---answer---")
print(answer) # Debugging output
print("---ground_truth---")
print(ground_truth) # Debugging output
# 3. フォーマットが正常な場合、長さ認識型の正解度報酬を計算
is_correct = (answer is not None and answer.lower() == ground_truth.lower())
#回答のトークン数を求める
L= len(TOKENIZER.tokenize(solution_str))
if is_correct:
rho_plus = min(1.0, max(0, L - L_POS_CONTROL) / (L_MAX - L_POS_CONTROL))
cos_term = 0.5 * (R_MAX_POS - R_MIN_POS) * (1 + math.cos(math.pi * rho_plus))
r_acc_scaled = R_MIN_POS + cos_term
else:
rho_minus = min(1.0, L / L_NEG_CONTROL)
cos_term = 0.5 * (R_MIN_NEG - R_MAX_NEG) * (1 + math.cos(math.pi * rho_minus))
r_acc_scaled = R_MAX_NEG + cos_term
# 4. 繰り返しペナルティを計算 (文字列全体を対象)
r_rep = _compute_repetition_penalty(solution_str)
# 5. 最終的な重み付きスコアを計算
final_score = (W_ACC * r_acc_scaled) + (W_REP * r_rep)
print("---final_score---")
print(final_score) # Debugging output
return final_score
主な関数は,
-
find_last_boxed_content:\boxed{}で囲まれている中身を抽出して,最終回答を取り出す関数 -
extract_thought_and_answer:<think>タグで囲まれているかの判定および思考過程,最終回答の抽出 -
compute_score: メイン処理.ここで,回答の一致判定や報酬を計算する.
となっています.コメントアウトに詳しく書いてあるのできちんと理解したい方は読んでみてください.
問題点: 一致判定の不正確さ
これでうまくいって万事OK!となれば良かったのですが,数学問題では,回答が数式となる場合,ground truthがtex形式の場合が多い一方で,モデルの回答は必ずしもtex形式であるとは限らず意味は同じでも不一致となってしまうことが多々ありました.
model answer: 1/2
ground truth: \frac{1}{2}
このような場合,単純な一致判定だと,Falseとなりますが,表記が異なるだけで意味は同じなはずです.この表記ゆれを解決する必要がありました.
Latexを吸収するための工夫1: latex parserの導入
そこで,sympyのlatex parserを使用して,texコードの表記を判定できるようにしてみました.latex parserとはpythonで数式を扱うためのモジュールで,latex表記を吸収するために用いています.詳しくは以下のブログがわかりやすいです.
https://tokibito.hatenablog.com/entry/2024/06/28/221009
なお,導入するには,larkライブラリが必要です.
from sympy.parsing.latex import parse_latex
#-----------------------------------
#(中略)
#-----------------------------------
# latex parser部分のみを抽出
try:
latex_answer=parse_latex(str(answer).lower(),backend="lark")
latex_ground_truth=parse_latex(str(ground_truth).lower(),backend="lark")
except Exception as e: # 必要に応じて全ての例外をキャッチ
latex_answer=str(answer).lower().replace(" ","")
latex_ground_truth=str(ground_truth).lower().replace(" ","")
finally:
signal.alarm(0)
これにより,大体のlatex形式は吸収できるようになりました.
latex parser導入後の報酬関数プログラム全コード
###
#phi4-reasoningの報酬関数を実装
###
import math
import re
import signal
from collections import Counter
from sympy.parsing.latex import parse_latex
from transformers import AutoTokenizer
TOKENIZER = AutoTokenizer.from_pretrained("microsoft/Phi-4-reasoning-plus")
# 論文のセクション4.1および4.2から引用した定数
# 報酬関数で使用する長さのパラメータ
L_MAX = 31744
L_POS_CONTROL = 25600
L_NEG_CONTROL = 3702
# 報酬値の範囲
R_MAX_POS = 1.0
R_MIN_POS = 0.5
R_MAX_NEG = -0.5
R_MIN_NEG = -1.0
# 最終的な報酬の重み
W_ACC = 8 / 13
W_REP = 1 / 13
# 繰り返しペナルティのパラメータ
NGRAM_SIZE = 5
NGRAM_FREQ_THRESHOLD = 5
_SOLUTION_CLIP_CHARS = 300
# タイムアウト時に呼び出され、例外を発生させる関数
def timeout_handler(signum, frame):
raise TimeoutError("処理がタイムアウトしました。")
def find_last_boxed_content(text: str) -> str:
"""
文字列中の最後の "\\boxed{...}" の中身を、入れ子括弧を考慮して抽出します。
エスケープされた括弧 \{ や \} は無視します。
"""
try:
# 最後の "\\boxed{" の開始インデックスを探します
last_boxed_start_index = text.rfind("\\boxed{")
if last_boxed_start_index == -1:
return ""
# コンテンツの実際の開始位置
content_start_index = last_boxed_start_index + len("\\boxed{")
# 対応する閉じ括弧 '}' を探します
brace_level = 1
for i in range(content_start_index, len(text)):
char = text[i]
# LaTeXでエスケープされた括弧 \{ や \} はレベル計算に含めません
if text[i-1] == '\\' and (char == '{' or char == '}'):
continue
if char == '{':
brace_level += 1
elif char == '}':
brace_level -= 1
# brace_levelが0になったら、それが対応する閉じ括弧です
if brace_level == 0:
return text[content_start_index:i]
# 最後まで見ても対応する閉じ括弧が見つからなかった場合
return ""
except Exception:
# 何らかのエラーが発生した場合
return ""
def extract_thought_and_answer(solution_str: str) -> tuple[str, str, bool]:
"""
文字列から<think>...</think>と最後の\\boxed{...}を抽出します。
\\boxed{...}内の入れ子括弧に対応しています。
"""
# <think>...</think> の抽出ロジックは変更ありません
think_match = re.search(r"<think>(.*?)</think>", solution_str, re.DOTALL)
if think_match:
thinking_process = think_match.group(1).strip()
is_format_valid = True
else:
thinking_process = ""
is_format_valid = False
# \\boxed{...} の抽出を新しい堅牢な関数に置き換えます
answer = find_last_boxed_content(solution_str)
return thinking_process, answer, is_format_valid
def _compute_repetition_penalty(text: str) -> float:
"""
n-gramの頻度に基づいて繰り返しペナルティを計算します。
"""
words = text.split()
if len(words) < NGRAM_SIZE:
return 0.0
# n-gramを生成
ngrams = [" ".join(words[i:i+NGRAM_SIZE]) for i in range(len(words) - NGRAM_SIZE + 1)]
if not ngrams:
return 0.0
ngram_counts = Counter(ngrams)
frequent_ngrams = {k: v for k, v in ngram_counts.items() if v > NGRAM_FREQ_THRESHOLD}
if not frequent_ngrams:
return 0.0
term1 = len(frequent_ngrams) / len(ngrams)
max_freq = max(frequent_ngrams.values())
total_possible_ngrams = len(words) / NGRAM_SIZE if len(words) > 0 else 1
term2 = max_freq / total_possible_ngrams
penalty = -max(term1, term2)
return penalty
def compute_score(data_source: str, solution_str: str, ground_truth: str, extra_info=None):
"""
Phi-4-reasoning論文で説明されている報酬関数に基づいて最終的なスコアを計算します。
Args:
solution_str: モデルから生成された完全なテキスト。(tokenizedではなく、文字列形式)
ground_truth: 正解。
data_source: データソースの名前。現在は "gsm8k" のみ対応。
"""
# 1. 出力文字列を解析し、フォーマットを検証
thinking_process, answer, is_format_valid = extract_thought_and_answer(solution_str)
L=len(TOKENIZER.tokenize(solution_str))
print("---solution_str---")
print(solution_str) # Debugging output
signal.signal(signal.SIGALRM, timeout_handler)
# 30秒でタイムアウトするように設定(必要に応じて変更)
signal.alarm(30)
try:
latex_answer=parse_latex(str(answer).lower(),backend="lark")
latex_ground_truth=parse_latex(str(ground_truth).lower(),backend="lark")
except Exception as e: # 必要に応じて全ての例外をキャッチ
latex_answer=str(answer).lower().replace(" ","")
latex_ground_truth=str(ground_truth).lower().replace(" ","")
finally:
signal.alarm(0)
print("---is_format_valid---")
print(is_format_valid)
print("---answer---")
print(answer) # Debugging output
print("---ground_truth---")
print(ground_truth) # Debugging output
# 2. フォーマット違反のオーバーライドを処理
# <think>タグが不正な場合は is_format_valid が False になる
# answerが不適切な形の場合もフォーマット違反とする
if not is_format_valid:
r_acc_scaled = -1.0
# 生成が不完全な場合
elif L >= L_MAX-1:
# imcomplete(eostokenなし)はこの関数では厳密な実装はできないので,max_lengthを超えた場合にフォーマット違反として扱う
# ここでは、L_MAXを超える場合にフォーマット違反として扱う
# (Lは開始トークンおよび終了トークンを含まず,L_MAXは終了トークンを含むため,L_MAX-1と比較)
r_acc_scaled = -0.5
else:
# 3. 回答が正解かどうかを報酬に反映
#ground_truthがlatex構文に適していなかった場合,元のanswerと比較する
is_correct= (latex_answer is not None and latex_answer == latex_ground_truth)
#トークン数を求める
L= len(TOKENIZER.tokenize(solution_str))
if is_correct:
rho_plus = min(1.0, max(0, L - L_POS_CONTROL) / (L_MAX - L_POS_CONTROL))
cos_term = 0.5 * (R_MAX_POS - R_MIN_POS) * (1 + math.cos(math.pi * rho_plus))
r_acc_scaled = R_MIN_POS + cos_term
else:
rho_minus = min(1.0, L / L_NEG_CONTROL)
cos_term = 0.5 * (R_MIN_NEG - R_MAX_NEG) * (1 + math.cos(math.pi * rho_minus))
r_acc_scaled = R_MAX_NEG + cos_term
# 4. 繰り返しペナルティを計算 (文字列全体を対象)
r_rep = _compute_repetition_penalty(solution_str)
# 5. 最終的な重み付きスコアを計算
final_score = (W_ACC * r_acc_scaled) + (W_REP * r_rep)
print("---final_score---")
print(final_score) # Debugging output
return final_score
問題点: latex parserの限界
sympyのlatex parserでは,読み込めない数式がいくつかありました.
2010^{2010^{2010}}
このような数式は,無限ループのようなものに陥るらしく,応答が返ってきませんでした.
(他にもいくつかあった気がします....)
問題点: 数式としては等価だが表記が違うケース
数学的には等価でも,表記差のため一致判定が False になることがあります.
---answer---
\left(1-10^{-1/9}\right)
---ground_truth---
\left(1-\left(\frac{1}{10}\right)^{\frac{1}{9}}\right)
わかりづらいので,数式に書き直してみます.
\text{answer:}
\left(1-10^{-1/9}\right)
\text{ground truth:}\left(1-\left(\frac{1}{10}\right)^{\frac{1}{9}}\right)
この2つの数式は表記が異なるものの,数式的には等価です.そのため,判定としてはTrueであることが望ましいですが,lantex parserではこの辺りを吸収できずFalseとなってしまいます.
LLMを用いた判定
そこで,外部LLMを用いた正誤判定を行いました.外部LLMはコンペの規約上無料のものを使用しなければならなかったため,Groq APIを使用しました.
# Groqの判定関数のみ抽出
def groq_complex_match(answer,ground_truth):
client = Groq(api_key=os.environ.get("GROQ_API_KEY"))
#Rate limitsを回避するため,system promptで,True かFalseのみを返すように指定する.
system_input='''
Are the `answer` and `ground_truth` below semantically the same? Answer only with `True` or `False` and Put your final answer inside \\boxed{}\n
'''
# Set the system prompt
system_prompt = {
"role": "system",
"content": system_input
}
# Set the user prompt
user_input = f'''###answer:\n
{answer}\n
###ground_truth:\n
{ground_truth}\n
###judge:\n
'''
user_prompt = {
"role": "user", "content": user_input
}
# Initialize the chat history
chat_history = [system_prompt, user_prompt]
#Requenst limits対策のため,1.5秒ほど止める
time.sleep(1.5)
response = client.chat.completions.create(model="meta-llama/llama-4-scout-17b-16e-instruct",
messages=chat_history,
max_tokens=1000,
temperature=0.01)
# Print the response
llm_judge=response.choices[0].message.content
llm_judge=find_last_boxed_content(llm_judge)
return llm_judge
これにより,ほとんどの回答を正しく判定することができました.
ただ,GroqのRate limitsにより,学習中Groqが使えなくなる時がありました.今コンペでは,無料APIのみの使用が許可されていたため,Groqにしましたが,もし有料OKであれば,Openrouter 等の代替も選択肢になると思います.
(Openrouterは10ドル課金すると,1日に1000リクエストまで可能)
最終的な報酬関数全コード
以下のコードには回答が短すぎないようにするために,テクレポの式にshort answer penaltyが追加されています.これは,Phi4-reasoningの報酬関数の仕様上,不正解でもトークン数が短い方が報酬が大きく設定されることから,回答自体を返さない(answerのトークン数0)になることが散見されたからです.
###
#phi4-reasoningの報酬関数を実装
###
import math
import os
import re
import signal
import time
from collections import Counter
from groq import Groq
from sympy.parsing.latex import parse_latex
from transformers import AutoTokenizer
TOKENIZER = AutoTokenizer.from_pretrained("microsoft/Phi-4-reasoning-plus")
# 論文のセクション4.1および4.2から引用した定数
# 報酬関数で使用する長さのパラメータ
L_MAX = 31744
L_POS_CONTROL = 25600
L_NEG_CONTROL = 3702
# 報酬値の範囲
R_MAX_POS = 1.0
R_MIN_POS = 0.5
R_MAX_NEG = -0.5
R_MIN_NEG = -1.0
# 最終的な報酬の重み
W_ACC = 8 / 13
W_REP = 1 / 13
# 繰り返しペナルティのパラメータ
NGRAM_SIZE = 5
NGRAM_FREQ_THRESHOLD = 5
_SOLUTION_CLIP_CHARS = 300
# ★追加:short answerペナルティの重み(reward hack対策)
W_LEN = 1 / 13
ANSWER_LEN_THRESHOLD = 100
SHORT_ANSWER_PENALTY = -1.0
# タイムアウト時に呼び出され、例外を発生させる関数
def timeout_handler(signum, frame):
raise TimeoutError("処理がタイムアウトしました。")
def find_last_boxed_content(text: str) -> str:
"""
文字列中の最後の "\\boxed{...}" の中身を、入れ子括弧を考慮して抽出します。
エスケープされた括弧 \{ や \} は無視します。
"""
try:
# 最後の "\\boxed{" の開始インデックスを探します
last_boxed_start_index = text.rfind("\\boxed{")
if last_boxed_start_index == -1:
return ""
# コンテンツの実際の開始位置
content_start_index = last_boxed_start_index + len("\\boxed{")
# 対応する閉じ括弧 '}' を探します
brace_level = 1
for i in range(content_start_index, len(text)):
char = text[i]
# LaTeXでエスケープされた括弧 \{ や \} はレベル計算に含めません
if text[i-1] == '\\' and (char == '{' or char == '}'):
continue
if char == '{':
brace_level += 1
elif char == '}':
brace_level -= 1
# brace_levelが0になったら、それが対応する閉じ括弧です
if brace_level == 0:
return text[content_start_index:i]
# 最後まで見ても対応する閉じ括弧が見つからなかった場合
return ""
except Exception:
# 何らかのエラーが発生した場合
return ""
def groq_complex_match(answer,ground_truth):
client = Groq(api_key=os.environ.get("GROQ_API_KEY"))
system_input='''
Are the `answer` and `ground_truth` below semantically the same? Answer only with `True` or `False` and Put your final answer inside \\boxed{}\n
'''
# Set the system prompt
system_prompt = {
"role": "system",
"content": system_input
}
# Set the user prompt
user_input = f'''###answer:\n
{answer}\n
###ground_truth:\n
{ground_truth}\n
###judge:\n
'''
user_prompt = {
"role": "user", "content": user_input
}
# Initialize the chat history
chat_history = [system_prompt, user_prompt]
#Requenst limits対策
time.sleep(1.5)
response = client.chat.completions.create(model="meta-llama/llama-4-scout-17b-16e-instruct",
messages=chat_history,
max_tokens=1000,
temperature=0.01)
# Print the response
llm_judge=response.choices[0].message.content
llm_judge=find_last_boxed_content(llm_judge)
return llm_judge
def extract_thought_and_answer(solution_str: str) -> tuple[str, str, bool]:
"""
文字列から<think>...</think>と最後の\\boxed{...}を抽出します。
\\boxed{...}内の入れ子括弧に対応しています。
"""
# <think>...</think> の抽出ロジックは変更ありません
think_match = re.search(r"<think>(.*?)</think>", solution_str, re.DOTALL)
if think_match:
thinking_process = think_match.group(1).strip()
is_format_valid = True
else:
thinking_process = ""
is_format_valid = False
# \\boxed{...} の抽出を新しい堅牢な関数に置き換えます
answer = find_last_boxed_content(solution_str)
return thinking_process, answer, is_format_valid
def _compute_repetition_penalty(text: str) -> float:
"""
n-gramの頻度に基づいて繰り返しペナルティを計算します。
"""
words = text.split()
if len(words) < NGRAM_SIZE:
return 0.0
# n-gramを生成
ngrams = [" ".join(words[i:i+NGRAM_SIZE]) for i in range(len(words) - NGRAM_SIZE + 1)]
if not ngrams:
return 0.0
ngram_counts = Counter(ngrams)
frequent_ngrams = {k: v for k, v in ngram_counts.items() if v > NGRAM_FREQ_THRESHOLD}
if not frequent_ngrams:
return 0.0
term1 = len(frequent_ngrams) / len(ngrams)
max_freq = max(frequent_ngrams.values())
total_possible_ngrams = len(words) / NGRAM_SIZE if len(words) > 0 else 1
term2 = max_freq / total_possible_ngrams
penalty = -max(term1, term2)
return penalty
def compute_score(data_source: str, solution_str: str, ground_truth: str, extra_info=None):
"""
Phi-4-reasoning論文で説明されている報酬関数に基づいて最終的なスコアを計算します。
Args:
solution_str: モデルから生成された完全なテキスト。(tokenizedではなく、文字列形式)
ground_truth: 正解。
data_source: データソースの名前。現在は "gsm8k" のみ対応。
"""
# 1. 出力文字列を解析し、フォーマットを検証
thinking_process, answer, is_format_valid = extract_thought_and_answer(solution_str)
L=len(TOKENIZER.tokenize(solution_str))
print("---solution_str---")
print(solution_str) # Debugging output
# ★追加: </think>以降のanswerを抽出してトークン長を測る
post_think_match = re.search(r"</think>(.*)$", solution_str, re.DOTALL)
answer_segment = post_think_match.group(1).strip() if post_think_match else ""
answer_token_len = len(TOKENIZER.tokenize(answer_segment))
signal.signal(signal.SIGALRM, timeout_handler)
# 30秒でタイムアウトするように設定(必要に応じて変更)
signal.alarm(30)
try:
latex_answer=parse_latex(str(answer).lower(),backend="lark")
latex_ground_truth=parse_latex(str(ground_truth).lower(),backend="lark")
except Exception as e: # 必要に応じて全ての例外をキャッチ
latex_answer=str(answer).lower().replace(" ","")
latex_ground_truth=str(ground_truth).lower().replace(" ","")
finally:
signal.alarm(0)
print("---is_format_valid---")
print(is_format_valid)
print("---answer---")
print(answer) # Debugging output
print("---ground_truth---")
print(ground_truth) # Debugging output
# 2. フォーマット違反のオーバーライドを処理
# <think>タグが不正な場合は is_format_valid が False になる
# answerが不適切な形の場合もフォーマット違反とする
if not is_format_valid:
r_acc_scaled = -1.0
# 生成が不完全な場合
elif L >= L_MAX-1:
# imcomplete(eostokenなし)はこの関数では厳密な実装はできないので,max_lengthを超えた場合にフォーマット違反として扱う
# ここでは、L_MAXを超える場合にフォーマット違反として扱う
# (Lは開始トークンおよび終了トークンを含まず,L_MAXは終了トークンを含むため,L_MAX-1と比較)
# TODO:imcompleteの完全な実装
r_acc_scaled = -0.5
else:
# 3. 回答が正解かどうかを報酬に反映
#ground_truthがlatex構文に適していなかった場合,元のanswerと比較する
is_correct= (latex_answer is not None and latex_answer == latex_ground_truth)
if not is_correct:
signal.alarm(30)
try:
llm_correct_judge=groq_complex_match(answer,ground_truth)
except Exception as e: # 必要に応じて全ての例外をキャッチ
llm_correct_judge="False"
finally:
signal.alarm(0)
if llm_correct_judge=="True":
is_correct=True
else:
is_correct=False
# 全トークン数を求める
L= len(TOKENIZER.tokenize(solution_str))
if is_correct:
rho_plus = min(1.0, max(0, L - L_POS_CONTROL) / (L_MAX - L_POS_CONTROL))
cos_term = 0.5 * (R_MAX_POS - R_MIN_POS) * (1 + math.cos(math.pi * rho_plus))
r_acc_scaled = R_MIN_POS + cos_term
else:
rho_minus = min(1.0, L / L_NEG_CONTROL)
cos_term = 0.5 * (R_MIN_NEG - R_MAX_NEG) * (1 + math.cos(math.pi * rho_minus))
r_acc_scaled = R_MAX_NEG + cos_term
# 4. 繰り返しペナルティを計算 (文字列全体を対象)
r_rep = _compute_repetition_penalty(solution_str)
# ★ 追加: short answerペナルティ(短いアンサーの場合ペナルティ)
r_len = SHORT_ANSWER_PENALTY if answer_token_len <= ANSWER_LEN_THRESHOLD else 0.0
print("---answer_token_len---")
print(answer_token_len)
# 5. 最終的な重み付きスコアを計算
final_score = (W_ACC * r_acc_scaled) + (W_REP * r_rep)+(W_LEN*r_len)
print("---final_score---")
print(final_score) # Debugging output
return final_score
最後に
本記事では報酬関数の数式判定方法についてlatex parserおよび外部LLMを用いた判定方法についてまとめました.
-
latex parserで表記揺れ吸収 → 多くのケースで有効
-
それでも数学的等価性の判定は難しく、外部 LLM を併用すると精度が上がる
-
無料API前提だと Rate Limits がボトルネックになりうる
複雑な数学問題は,単に数字が出てくるだけでなく,数式が回答として出てくるため,外部LLMを使用しないと,正確な判定は難しい印象でした.また,HLEは問題が難しく,回答も複雑なものになりがちなので,判定にも大きなモデルを使用した方が良いと思いますが,GroqのRate Limitsに引っかかり,無料では難しかったです.
最後になりますが,滅多にない機械を提供して下さった松尾・岩沢研の皆様,インフラ周りを提供してくださったさくらインターネットはじめ関係者の皆様には心より感謝申し上げます.
本プロジェクトは,国立研究開発法人新エネルギー・産業技術総合開発機構(以下「NEDO」)の「日本語版医療特化型LLMの社会実装に向けた安全性検証・実証」における基盤モデルの開発プロジェクトの一環として行われます.