0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

llama.cppのバッチサイズと推論速度の関係性を検証した

0
Last updated at Posted at 2026-01-14

llama.cppのバッチサイズと推論速度の関係性を検証した

はじめに

LocalでMinimax-M2.1をホストしopencodeで動かすアプローチで、安定したコーディングエージェントになり、最近はvibeコーディングを楽しんでいます。
Prefillの速度が速いことからllama.cppを使ってMinimax-M2.1を動かしているのですが、
--batch-sizeオプションはlmstudioでllama.cppを動かす場合だと512で固定されてしまいます。batch sizeを変えることでPrefillの速度を改善できるという情報を知ったので、
batch sizeが推論速度に与える影響を、定量的に測定してみました。MiniMax-M2.1モデルを使用して、バッチサイズを128から8192まで7段階でテストしたので、その結果を共有します。

実験の目的

  • バッチサイズを上げると本当に速くなるのか?
  • どの程度まで上げると効果が増幅するのか?
  • 適切なバッチサイズはいくらか?

環境

項目 詳細
PC Mac Studio M3 Ultra (512GB メモリ)
GPU Apple Silicon (Metal使用)
モデル MiniMax-M2.1-UD-Q6_K_XL (unsloth)
ツール llama.cpp (llama-server)

バッチサイズとは

--batch-sizeは、一度に処理できるトークンの最大数を指定するオプションです。値が大きいほど、複数のリクエストを効率的に並列処理できます。

--batch-size N    : 論理的な最大バッチサイズ (デフォルト: 2048)
--ubatch-size N   : 物理的な最大バッチサイズ (デフォルト: 512)

実験条件

基本設定

  • 入力トークン数: 50,000 tokens
  • 出力トークン数: 1,000 tokens
  • 同時リクエスト数: 1(単一リクエスト)
  • コンテキストサイズ: 32,768
  • 連続バッチング: 有効 (--cont-batching)
  • 各設定でのテスト回数: 3回

テストしたバッチサイズ

128, 256, 512, 1024, 2048, 4096, 8192

実験方法

1. サーバー起動とログ取得

# 例: バッチサイズ2048でサーバーを起動
./build/bin/llama-server \
    --host 127.0.0.1 \
    --port 8080 \
    -m /path/to/model.gguf \
    --batch-size 2048 \
    --ubatch-size 2048 \
    -c 32768 \
    -n 1000 \
    --log-file server_debug_2048.log \
    --cont-batching

2. リクエスト実行

OpenAI互換APIを使用して推論を実行:

curl http://127.0.0.1:8080/v1/completions \
    -H "Content-Type: application/json" \
    -d '{
        "model": "model",
        "prompt": "長いプロンプトテキスト...",
        "n_predict": 1000,
        "temperature": 0.7
    }'

3. ログ解析

サーバーには--log-fileオプションでデバッグログを出力させ、そこからPrefill/Evalタイムを抽出:

# extract_eval_times.py でCSVに変換
python3 extract_eval_times.py server_debug_2048.log -o eval_times_2048.csv

# stats.py で統計情報を取得
python3 stats.py eval_times_2048.csv --detailed

実験1: 基本テスト(入力10K、出力10K)

実験条件

設定
入力トークン数 10,000
出力トークン数 10,000
同時リクエスト数 1
コンテキストサイズ 32,768

結果

バッチサイズ 総合TPS* Prefill TPS** Inference TPS
128 46.13 327 31.14
256 46.03 531 31.08
512 46.15 642 31.16
1024 45.99 729 31.05
2048 46.02 765 31.08

*総合TPSはリクエスト全体のスループット
**Prefill TPSはログから抽出した加重平均

分析

入力と出力が同量の場合、総合TPSはほぼ一定(約46) となりました。

理由を確認するため、時間構成を分解してみる:

Batch Size 128:
  Prefill時間:   30.6s (8.7%)
  Inference時間: 321.5s (91.3%)

Batch Size 2048:
  Prefill時間:   13.1s (3.9%)
  Inference時間: 321.5s (96.1%)

Inference(生成)時間が全体の約90%以上を占めるため、Prefillが改善しても総合時間への寄与は小さい。

実験2: 長文入力テスト(入力50K、出力1K)

実験条件

設定
入力トークン数 50,000
出力トークン数 1,000
同時リクエスト数 1
コンテキストサイズ 32,768

全バッチサイズ結果(7種類)

バッチサイズ Prefill TPS* Inference TPS 総合TPS
512 593 32.8 256.6
1024 663 33.0 257.5
2048 702 33.0 258.6
4096 700 33.1 259.5
8192 679 33.1 258.5

*Prefill TPSはログから抽出した加重平均

バッチサイズ別改善率(128 → 8192)

二つの実験の結果をまとめると

128 → 256: +62% (327 → 531)
256 → 512: +21% (531 → 642)
512 → 1024: +14% (642 → 729)
1024 → 2048: +6%  (663 → 702) ← 頭打ち傾向
2048 → 4096: -3%  (702 → 700) ← 横ばい
4096 → 8192: -3%  (700 → 679) ←   むしろ悪化

考察

バッチサイズを上げるとPrefillが速くなる理由

  • 大きな行列演算が効率的にGPUで処理できる
  • メモリアクセスのオーバーヘッドが相対的に減少する
  • ただし、一定値(2048程度)を超えるとメモリ管理コストが増大し、逆効果になる

総合TPSが変わらない理由(実験1より)

Prefillは速くなっても、Inference時間が支配的なため変わらないように見えている。

推奨設定

ユースケース 推奨バッチサイズ
メモリ制約あり 512
バランス重視 1024 - 2048
最大Prefill速度 2048
非推奨 4096以上

再現手順

1. 環境の準備

# llama.cppをビルド
git clone https://github.com/ggerganov/llama.cpp
cd llama.cpp
cmake -B build
cmake --build build --config Release -t llama-server

2. ベンチマークスクリプトの実行

# バッチサイズ512, 1024, 2048をテスト
python3 batch_bench.py

ログは server_debug_{batch_size}.log に出力されます。
batch_bench.pyの詳細は後述

3. 結果の解析

for size in 512 1024 2048; do
    python3 extract_eval_times.py server_debug_${size}.log -o eval_times_${size}.csv
    python3 stats.py eval_times_${size}.csv --detailed
done

結論

  1. 単一リクエストでは総合TPSはほぼ変化しない(約46 TPS)
  2. PrefillフェーズはBatch 128→2048で約2.4倍高速化
  3. 4096以上はむしろ遅くなる(メモリ管理オーバーヘッドのため)
  4. 推奨は1024〜2048
  5. 512でも十分速度は出ていることがわかったので、llama.cppを直接使ってLLMのサーバーを立てなくてもlmstudioで使うことになりそう

appendix

#!/usr/bin/env python3
"""
Batch size benchmark for llama-server.
Tests performance with different batch sizes: 128, 256, 512, 1024, 2048
"""

import subprocess
import time
import json
import os
import signal
import sys
import requests
from datetime import datetime

SERVER_HOST = "127.0.0.1"
SERVER_PORT = 8080
MODEL_PATH = "/Volumes/SSD-PHPU3A/lmstudio/unsloth/MiniMax-M2.1-GGUF/MiniMax-M2.1-UD-Q6_K_XL-00001-of-00004.gguf"

BATCH_SIZES = [128, 256, 512, 1024, 2048, 4096, 8192]
INPUT_TOKENS = 50000
OUTPUT_TOKENS = 1000

def wait_for_server(timeout=120):
    """Wait for server to be ready."""
    start = time.time()
    while time.time() - start < timeout:
        try:
            resp = requests.get(f"http://{SERVER_HOST}:{SERVER_PORT}/health", timeout=5)
            if resp.status_code == 200:
                return True
        except:
            pass
        time.sleep(1)
    return False

def start_server(batch_size, ubatch_size):
    """Start llama-server with specified batch size."""
    cmd = [
        "./build/bin/llama-server",
        "--host", SERVER_HOST,
        "--port", str(SERVER_PORT),
        "-m", MODEL_PATH,
        "--batch-size", str(batch_size),
        "--ubatch-size", str(ubatch_size),
        "-c", "32768",
        "-n", str(OUTPUT_TOKENS),
        "--log-file", f"server_debug_{batch_size}.log",
        "--cont-batching"
    ]
    print(f"\n{'='*60}")
    print(f"Starting server with batch-size={batch_size}, ubatch-size={ubatch_size}")
    print(f"{'='*60}")

    # Kill any existing server on the port
    subprocess.run(["lsof", "-ti", f":{SERVER_PORT}", "|", "xargs", "kill", "-9"],
                   shell=True, capture_output=True)
    time.sleep(2)

    log_file = open(f"server_debug_{batch_size}.log", "w")
    proc = subprocess.Popen(cmd, stdout=log_file, stderr=subprocess.STDOUT)

    if not wait_for_server():
        print(f"Server failed to start within timeout")
        proc.kill()
        return None

    print(f"Server ready!")
    return proc

def stop_server(proc):
    """Stop the server."""
    if proc:
        proc.terminate()
        try:
            proc.wait(timeout=10)
        except subprocess.TimeoutExpired:
            proc.kill()

def generate_test_prompt(tokens):
    """Generate a test prompt with approximately specified token count."""
    words = ["hello", "world", "test", "data", "sample", "text", "content",
             "generate", "process", "model", "token", "batch", "server"]
    prompt_parts = []
    current_tokens = 0

    while current_tokens < tokens:
        sentence = " ".join(words * 10)
        prompt_parts.append(sentence)
        current_tokens += len(sentence.split()) * 2

    return " ".join(prompt_parts)[:tokens]

def run_single_request():
    """Run a single request and measure timing."""
    prompt = generate_test_prompt(INPUT_TOKENS)

    start_time = time.time()

    response = requests.post(
        f"http://{SERVER_HOST}:{SERVER_PORT}/v1/completions",
        json={
            "model": "model",
            "prompt": prompt,
            "n_predict": OUTPUT_TOKENS,
            "temperature": 0.7,
            "stream": False
        },
        timeout=3600
    )

    elapsed = time.time() - start_time

    if response.status_code == 200:
        data = response.json()
        total_tokens = data.get("usage", {}).get("total_tokens", 0)
        prompt_tokens = data.get("usage", {}).get("prompt_tokens", 0)
        completion_tokens = data.get("usage", {}).get("completion_tokens", 0)

        return {
            "success": True,
            "elapsed": elapsed,
            "total_tokens": total_tokens,
            "prompt_tokens": prompt_tokens,
            "completion_tokens": completion_tokens,
            "tps_total": total_tokens / elapsed if elapsed > 0 else 0
        }
    else:
        return {
            "success": False,
            "error": response.text,
            "elapsed": elapsed
        }

def run_benchmark(batch_size, num_runs=3):
    """Run benchmark for a specific batch size."""
    results = []

    # Use same value for ubatch-size for simplicity
    ubatch_size = batch_size

    server_proc = start_server(batch_size, ubatch_size)
    if not server_proc:
        return None

    try:
        # Warmup run (shorter tokens to save time)
        print(f"\nWarmup run...")
        warmup_input = INPUT_TOKENS // 5
        prompt = generate_test_prompt(warmup_input)

        requests.post(
            f"http://{SERVER_HOST}:{SERVER_PORT}/v1/completions",
            json={
                "model": "model",
                "prompt": prompt,
                "n_predict": OUTPUT_TOKENS,
                "temperature": 0.7,
                "stream": False
            },
            timeout=600
        )
        time.sleep(2)

        # Actual benchmark runs
        print(f"\nRunning {num_runs} benchmark iterations...")
        for i in range(num_runs):
            print(f"  Run {i+1}/{num_runs}...", end=" ", flush=True)
            result = run_single_request()
            if result["success"]:
                results.append(result)
                tps = result.get("tps_total", 0)
                print(f"TPS: {tps:.2f}")
            else:
                print(f"Failed: {result.get('error', 'Unknown error')}")

    finally:
        stop_server(server_proc)

    return results

def parse_results_from_log(batch_size):
    """Parse eval times from server debug log."""
    log_file = f"server_debug_{batch_size}.log"
    if not os.path.exists(log_file):
        return []

    prefill_tps_list = []
    inference_tps_list = []

    with open(log_file, 'r') as f:
        for line in f:
            # Match patterns like: "prompt eval time =   31389.23 ms / 15722 tokens (    2.00 ms per token,   500.87 tokens per second)"
            if "prompt eval time" in line and "tokens per second" in line:
                import re
                match = re.search(r'\(\s*[\d.]+\s*ms per token,\s*([\d.]+)\s*tokens per second\)', line)
                if match:
                    tps = float(match.group(1))
                    prefill_tps_list.append(tps)

            elif "eval time" in line and "tokens per second" in line:
                import re
                match = re.search(r'\(\s*[\d.]+\s*ms per token,\s*([\d.]+)\s*tokens per second\)', line)
                if match:
                    tps = float(match.group(1))
                    inference_tps_list.append(tps)

    return prefill_tps_list, inference_tps_list

def main():
    print("=" * 70)
    print(f"Batch Size Benchmark")
    print(f"Model: {MODEL_PATH}")
    print(f"Input tokens: ~{INPUT_TOKENS}, Output tokens: ~{OUTPUT_TOKENS}")
    print(f"Test sizes: {BATCH_SIZES}")
    print("=" * 70)

    all_results = {}

    for batch_size in BATCH_SIZES:
        print(f"\n\n{'#' * 60}")
        print(f"# Testing Batch Size: {batch_size}")
        print(f"{'#' * 60}")

        results = run_benchmark(batch_size, num_runs=3)

        if results:
            all_results[batch_size] = results

            # Parse log file for TPS data
            prefill_tps_list, inference_tps_list = parse_results_from_log(batch_size)

            avg_elapsed = sum(r["elapsed"] for r in results) / len(results)
            avg_total_tps = sum(r.get("tps_total", 0) for r in results if r.get("tps_total")) / len([r for r in results if r.get("tps_total")])
            avg_prompt_tps = (sum(r["prompt_tokens"] for r in results) / avg_elapsed) if avg_elapsed > 0 else 0
            avg_completion_tps = (sum(r["completion_tokens"] for r in results) / avg_elapsed) if avg_elapsed > 0 else 0

            print(f"\n{'='*60}")
            print(f"Results for batch_size={batch_size}:")
            print(f"{'='*60}")
            print(f"Average elapsed time: {avg_elapsed:.2f}s")
            print(f"Average total TPS: {avg_total_tps:.2f}")
            if prefill_tps_list:
                print(f"Prefill TPS from log (mean): {sum(prefill_tps_list)/len(prefill_tps_list):.2f} (n={len(prefill_tps_list)})")
            if inference_tps_list:
                print(f"Inference TPS from log (mean): {sum(inference_tps_list)/len(inference_tps_list):.2f} (n={len(inference_tps_list)})")

    # Summary table
    print("\n" + "=" * 70)
    print("SUMMARY - Batch Size Comparison")
    print("=" * 70)
    print(f"{'Batch Size':<15} {'Avg TPS':>12} {'Status':>10}")
    print("-" * 40)

    for batch_size in BATCH_SIZES:
        if batch_size in all_results and all_results[batch_size]:
            results = all_results[batch_size]
            avg_tps = sum(r.get("tps_total", 0) for r in results) / len([r for r in results if r.get("tps_total")])
            print(f"{batch_size:<15} {avg_tps:>12.2f} {'OK':>10}")
        else:
            print(f"{batch_size:<15} {'N/A':>12} {'FAILED':>10}")

    # Save detailed results
    with open("benchmark_results.json", "w") as f:
        json.dump(all_results, f, indent=2)

    print("\nResults saved to benchmark_results.json")

if __name__ == "__main__":
    main()
0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?