0
0

Delete article

Deleted articles cannot be recovered.

The draft of this article will also be deleted.

Are you sure you want to delete this article?

LLMで数学のタスクを解いてみる

Posted at
  • GSM8Kを利用
    • サンプルとして5問のみを利用
  • モデル
    • local(ollama)とopenrouterを切り替えられる
  • プロンプト
    • build_prompt()として実装
      • 改善の余地あり
  • 解答の抽出
    • LLMはあくまで文字列を返却するので、解答を抽出してあげる処理が必要。シンプルに文字列中に出現する最後の数字列とした
      • 改善の余地あり
  • temperature
    • 数学の問題を解かせるので特にランダム性は不要のため、0.0とした
  • max_tokens
    • LLMの出力トークンに対するハードリミット。512に設定した。GSM8Kでは、この数値で十分だった
import os
import subprocess
import requests
import re
from datasets import load_dataset
import time
import json
from datetime import datetime

# Backend selector: "ollama" runs a local model through the ollama CLI,
# "openrouter" calls the hosted OpenRouter API. Edit this value to switch.
BACKEND = 'openrouter'

# Ollama settings (model name can be overridden via the environment).
OLLAMA_MODEL = os.environ.get("OLLAMA_MODEL", "gpt-oss:latest")

# OpenRouter settings. The API key has no default and is None when unset.
OPENROUTER_API_KEY = os.environ.get("OPENROUTER_API_KEY")
OPENROUTER_MODEL = os.environ.get("OPENROUTER_MODEL", "openai/gpt-oss-20b")
OPENROUTER_BASE_URL = "https://openrouter.ai/api/v1/chat/completions"


def extract_last_number(text: str) -> "str | None":
    """
    Return the last integer digit run found in ``text``, or None.

    Before scanning, common number formatting is normalized so that a
    formatted value is treated as a single number:

    * thousands separators are removed ("1,234" -> "1234");
    * a fractional part made only of zeros is removed ("18.00" -> "18").

    Without this, "1,234" would yield "234" and "18.00" would yield "00".
    A decimal with a non-zero fraction (e.g. "3.5") is not treated
    specially: the last digit run ("5") is still what gets returned.

    Args:
        text: The string to scan; may be None.

    Returns:
        The last match of ``-?\\d+`` in the normalized text as a string
        (including a leading "-" if present), or None when ``text`` is
        None or contains no digits.
    """
    if text is None:
        return None
    # A comma between a digit and a group of exactly three digits is taken
    # to be a thousands separator.
    normalized = re.sub(r"(?<=\d),(?=\d{3}(?!\d))", "", text)
    # ".0", ".00", ... directly after a digit carries no value; the negative
    # lookahead keeps fractions such as ".05" untouched.
    normalized = re.sub(r"(?<=\d)\.0+(?!\d)", "", normalized)
    matches = re.findall(r"-?\d+", normalized)
    return matches[-1] if matches else None


def build_prompt(question: str) -> str:
    """
    Build the prompt sent to the model for a single math problem.

    The prompt has three paragraphs separated by blank lines: a tutor
    instruction, the question (inserted verbatim under a "Question:"
    header), and a request to finish with the answer prefixed by
    'Answer: '.
    """
    instruction = (
        "You are a helpful math tutor. "
        "Read the following grade-school math problem and answer it."
    )
    question_part = f"Question:\n{question}"
    answer_format = (
        "Please show a short reasoning and then give the final answer "
        "as a number at the end, prefixed with 'Answer: '."
    )
    return "\n\n".join((instruction, question_part, answer_format))


def openrouter_generate(prompt: str) -> str:
    """
    Call the configured model through OpenRouter's /v1/chat/completions.

    Args:
        prompt: The user message to send. A fixed system message
            ("You are a helpful math tutor.") is sent alongside it.

    Returns:
        The text of the first choice's message, stripped of surrounding
        whitespace.

    Raises:
        RuntimeError: If OPENROUTER_API_KEY is not set, the HTTP status is
            not 200, the body is not valid JSON, or the response does not
            contain a string at choices[0].message.content.
    """
    if not OPENROUTER_API_KEY:
        raise RuntimeError("OPENROUTER_API_KEY is not set")

    headers = {
        "Authorization": f"Bearer {OPENROUTER_API_KEY}",
        "Content-Type": "application/json",
    }
    payload = {
        "model": OPENROUTER_MODEL,
        "messages": [
            {"role": "system", "content": "You are a helpful math tutor."},
            {"role": "user", "content": prompt},
        ],
        "temperature": 0.0,
        "max_tokens": 512,
    }

    resp = requests.post(OPENROUTER_BASE_URL,
                         headers=headers, json=payload, timeout=60)
    if resp.status_code != 200:
        raise RuntimeError(f"OpenRouter error: {resp.status_code} {resp.text}")

    # A 200 response with a non-JSON body would otherwise surface as a raw
    # decoding error; report it the same way as the other failures.
    try:
        data = resp.json()
    except ValueError as e:
        raise RuntimeError(
            f"OpenRouter returned a non-JSON body: {resp.text}") from e

    # TypeError covers a null/non-container value anywhere along the path.
    try:
        content = data["choices"][0]["message"]["content"]
    except (KeyError, IndexError, TypeError) as e:
        raise RuntimeError(
            f"Unexpected OpenRouter response format: {data}") from e

    # The content field itself may be null or otherwise not a string;
    # calling .strip() on it would raise an unrelated AttributeError.
    if not isinstance(content, str):
        raise RuntimeError(
            f"Unexpected OpenRouter response format: {data}")
    return content.strip()


def generate_answer(prompt: str) -> str:
    """
    Send ``prompt`` to the backend selected by the module-level BACKEND.

    "ollama" dispatches to ollama_generate and "openrouter" to
    openrouter_generate; any other value raises RuntimeError.
    """
    if BACKEND == "openrouter":
        return openrouter_generate(prompt)
    if BACKEND == "ollama":
        return ollama_generate(prompt)
    raise RuntimeError(f"Unknown LLM_BACKEND: {BACKEND}")


def ollama_generate(prompt: str) -> str:
    """
    Run the local model via ``ollama run OLLAMA_MODEL`` and return its output.

    The prompt is written to the process's stdin and the reply is read
    from stdout.

    Args:
        prompt: The text to feed to the model.

    Returns:
        The process's stdout, stripped of surrounding whitespace.

    Raises:
        RuntimeError: If the ``ollama`` executable cannot be found or the
            process exits with a non-zero return code (stderr is included
            in the message).
    """
    try:
        proc = subprocess.run(
            ["ollama", "run", OLLAMA_MODEL],
            input=prompt,
            text=True,
            # Pin the pipe encoding instead of relying on the locale default,
            # which is not UTF-8 on every platform.
            encoding="utf-8",
            capture_output=True,
            check=False,
        )
    except FileNotFoundError as e:
        raise RuntimeError(
            "Ollama error: 'ollama' executable not found") from e
    if proc.returncode != 0:
        raise RuntimeError(f"Ollama error: {proc.stderr}")
    return proc.stdout.strip()


def main():
    """
    Evaluate the configured model on the first few GSM8K test problems.

    For each sample the model is prompted, the last number is extracted
    from both the model output and the reference answer, and the two are
    compared as strings. Per-sample details and the final accuracy are
    printed, and the full run (settings, summary, samples) is saved as
    ``log/run_<timestamp>.json``.
    """
    dataset = load_dataset("gsm8k", "main")
    test_split = dataset["test"]

    num_samples = 5
    # Fetch each row once and split it into the two fields that are used.
    rows = [test_split[i] for i in range(num_samples)]
    questions = [row["question"] for row in rows]
    ground_truths = [row["answer"] for row in rows]

    correct_count = 0
    per_sample_results = []
    # perf_counter is monotonic, so the elapsed time cannot be skewed by
    # system clock adjustments during the run.
    start_time = time.perf_counter()

    for i, (question, ground_truth) in enumerate(
            zip(questions, ground_truths)):
        prompt = build_prompt(question)
        model_answer = generate_answer(prompt)

        # Compare the last number in the model output with the last number
        # in the reference answer.
        pred_num = extract_last_number(model_answer)
        gt_num = extract_last_number(ground_truth)

        is_correct = (
            pred_num is not None and gt_num is not None and pred_num == gt_num)
        if is_correct:
            correct_count += 1
        check = "OK" if is_correct else "NG"

        per_sample_results.append({
            "idx": i,
            "question": question,
            "model_answer": model_answer,
            "pred_num": pred_num,
            "ground_truth": ground_truth,
            "gt_num": gt_num,
            "is_correct": is_correct
        })

        print(f"=== Sample {i} ===")
        print("question:")
        print(question)
        print("\nmodel_answer:")
        print(model_answer)
        print("\nextracted_number:")
        print(pred_num)
        print("\nground_truth:")
        print(ground_truth)
        print("\ncheck:")
        print(check)
        print("\n" + "-" * 60 + "\n")

    elapsed = time.perf_counter() - start_time
    print(f"Accuracy: {correct_count}/{num_samples}")
    print(f"Total time: {elapsed:.3f} seconds")

    # Save log as JSON with timestamp
    os.makedirs("log", exist_ok=True)
    timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
    run_data = {
        "meta": {
            "timestamp": timestamp,
            "backend": BACKEND,
            "ollama_model": OLLAMA_MODEL,
            "openrouter_model": OPENROUTER_MODEL,
            "num_samples": num_samples
        },
        "prompt_config": {
            "template": build_prompt("{question}"),
            "temperature": 0.0,
            "max_tokens": 512
        },
        "extract_config": {
            "method": "extract_last_number",
            "regex": r"-?\d+"
        },
        "summary": {
            "correct": correct_count,
            "accuracy": correct_count / num_samples,
            "total_time_sec": elapsed
        },
        "samples": per_sample_results
    }
    log_path = os.path.join("log", f"run_{timestamp}.json")
    # ensure_ascii=False writes non-ASCII characters as-is, so the file must
    # be opened as UTF-8 rather than with the locale's default encoding.
    with open(log_path, "w", encoding="utf-8") as f:
        json.dump(run_data, f, ensure_ascii=False, indent=2)
    print(f"Saved log to {log_path}")


if __name__ == "__main__":
    main()
0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0

Delete article

Deleted articles cannot be recovered.

The draft of this article will also be deleted.

Are you sure you want to delete this article?