import os
import subprocess
import requests
import re
from datasets import load_dataset
import time
import json
from datetime import datetime
# BACKEND: switch between "ollama" and "openrouter"
# BACKEND = 'ollama'
BACKEND = 'openrouter'

# Ollama settings (local CLI model name)
OLLAMA_MODEL = os.getenv("OLLAMA_MODEL", "gpt-oss:latest")

# OpenRouter settings (API key must come from the environment; never hard-code it)
OPENROUTER_API_KEY = os.getenv(
    "OPENROUTER_API_KEY")
OPENROUTER_MODEL = os.getenv("OPENROUTER_MODEL", "openai/gpt-oss-20b")
OPENROUTER_BASE_URL = "https://openrouter.ai/api/v1/chat/completions"
def extract_last_number(text: str):
    """
    Return the last number appearing in *text* as a string (with an
    optional leading minus sign), or None if no number is found.

    Thousands separators are handled: "1,000" is extracted as "1000"
    instead of splitting at the comma and returning "000", which would
    mis-score answers that models format with separators.
    """
    if text is None:
        return None
    # A digit run optionally followed by comma-separated 3-digit groups,
    # so "1,234,567" is captured as one match. Inputs without separator
    # groups match exactly as the plain r"-?\d+" pattern did.
    matches = re.findall(r"-?\d+(?:,\d{3})*", text)
    if not matches:
        return None
    # Drop separators so the result compares equal to an unseparated form.
    return matches[-1].replace(",", "")
def build_prompt(question: str) -> str:
    """Assemble the math-tutor prompt that wraps *question* for the model."""
    parts = [
        "You are a helpful math tutor. Read the following grade-school math problem and answer it.",
        "Question:",
        question,
        "Please show a short reasoning and then give the final answer as a number at the end, prefixed with 'Answer: '.",
    ]
    return "\n".join(parts)
def openrouter_generate(prompt: str) -> str:
    """
    Call the configured model through OpenRouter's chat-completions
    endpoint and return the stripped assistant message text.

    Raises:
        RuntimeError: if OPENROUTER_API_KEY is unset, the HTTP status is
            not 200, or the response body has an unexpected shape.
    """
    if not OPENROUTER_API_KEY:
        raise RuntimeError("OPENROUTER_API_KEY is not set")
    resp = requests.post(
        OPENROUTER_BASE_URL,
        headers={
            "Authorization": f"Bearer {OPENROUTER_API_KEY}",
            "Content-Type": "application/json",
        },
        json={
            "model": OPENROUTER_MODEL,
            "messages": [
                {"role": "system", "content": "You are a helpful math tutor."},
                {"role": "user", "content": prompt},
            ],
            # Deterministic decoding for reproducible evaluation runs.
            "temperature": 0.0,
            "max_tokens": 512,
        },
        timeout=60,
    )
    if resp.status_code != 200:
        raise RuntimeError(f"OpenRouter error: {resp.status_code} {resp.text}")
    data = resp.json()
    try:
        return data["choices"][0]["message"]["content"].strip()
    except (KeyError, IndexError) as e:
        raise RuntimeError(
            f"Unexpected OpenRouter response format: {data}") from e
def generate_answer(prompt: str) -> str:
    """
    Dispatch *prompt* to the backend selected by the module-level
    BACKEND constant and return the model's raw text reply.

    Raises:
        RuntimeError: for an unrecognized BACKEND value (propagates any
            error raised by the chosen backend as well).
    """
    if BACKEND == "ollama":
        return ollama_generate(prompt)
    elif BACKEND == "openrouter":
        return openrouter_generate(prompt)
    else:
        # The message previously said "LLM_BACKEND", which is not the name
        # of any setting in this module; report the actual constant name.
        raise RuntimeError(f"Unknown BACKEND: {BACKEND}")
def ollama_generate(prompt: str) -> str:
    """
    Run the local Ollama CLI with OLLAMA_MODEL, feeding *prompt* on
    stdin, and return the model's stdout stripped of surrounding
    whitespace.

    Raises:
        RuntimeError: with the captured stderr when the CLI exits with a
            nonzero status.
    """
    command = ["ollama", "run", OLLAMA_MODEL]
    result = subprocess.run(
        command,
        input=prompt,
        capture_output=True,
        text=True,
        check=False,  # inspect returncode ourselves to raise a custom error
    )
    if result.returncode != 0:
        raise RuntimeError(f"Ollama error: {result.stderr}")
    return result.stdout.strip()
def main(num_samples: int = 5):
    """
    Evaluate the configured backend on the first *num_samples* GSM8K test
    questions, print per-sample results, and dump a JSON run log under log/.

    Args:
        num_samples: number of test examples to score. Defaults to 5,
            matching the previously hard-coded value. Assumed > 0 (used
            as a divisor for accuracy).
    """
    dataset = load_dataset("gsm8k", "main")
    test_split = dataset["test"]
    questions = [test_split[i]["question"] for i in range(num_samples)]
    ground_truths = [test_split[i]["answer"] for i in range(num_samples)]

    correct_count = 0
    per_sample_results = []
    start_time = time.time()

    for i, (question, ground_truth) in enumerate(zip(questions, ground_truths)):
        prompt = build_prompt(question)
        model_answer = generate_answer(prompt)

        # Score by comparing the last number appearing in each text; GSM8K
        # ground-truth answers end with "#### <number>", so that is the
        # value the extractor picks up.
        pred_num = extract_last_number(model_answer)
        gt_num = extract_last_number(ground_truth)
        is_correct = (
            pred_num is not None and gt_num is not None and pred_num == gt_num)
        if is_correct:
            correct_count += 1
        check = "OK" if is_correct else "NG"

        per_sample_results.append({
            "idx": i,
            "question": question,
            "model_answer": model_answer,
            "pred_num": pred_num,
            "ground_truth": ground_truth,
            "gt_num": gt_num,
            "is_correct": is_correct
        })

        print(f"=== Sample {i} ===")
        print("question:")
        print(question)
        print("\nmodel_answer:")
        print(model_answer)
        print("\nextracted_number:")
        print(pred_num)
        print("\nground_truth:")
        print(ground_truth)
        print("\ncheck:")
        print(check)
        print("\n" + "-" * 60 + "\n")

    end_time = time.time()
    print(f"Accuracy: {correct_count}/{num_samples}")
    print(f"Total time: {end_time - start_time:.3f} seconds")

    # Save log as JSON with timestamp
    os.makedirs("log", exist_ok=True)
    timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
    run_data = {
        "meta": {
            "timestamp": timestamp,
            "backend": BACKEND,
            "ollama_model": OLLAMA_MODEL,
            "openrouter_model": OPENROUTER_MODEL,
            "num_samples": num_samples
        },
        "prompt_config": {
            # The literal "{question}" placeholder records the template shape.
            "template": build_prompt("{question}"),
            "temperature": 0.0,
            "max_tokens": 512
        },
        "extract_config": {
            "method": "extract_last_number",
            # NOTE(review): keep this in sync with the pattern actually used
            # inside extract_last_number.
            "regex": r"-?\d+"
        },
        "summary": {
            "correct": correct_count,
            "accuracy": correct_count / num_samples,
            "total_time_sec": end_time - start_time
        },
        "samples": per_sample_results
    }
    log_path = os.path.join("log", f"run_{timestamp}.json")
    # ensure_ascii=False can emit non-ASCII characters; force UTF-8 so the
    # log is written correctly regardless of the platform's default locale
    # encoding (e.g. cp1252 on Windows would otherwise fail or corrupt it).
    with open(log_path, "w", encoding="utf-8") as f:
        json.dump(run_data, f, ensure_ascii=False, indent=2)
    print(f"Saved log to {log_path}")
# Script entry point: run the evaluation only when executed directly,
# not when imported as a module.
if __name__ == "__main__":
    main()