1
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

RRFのk=60は最適とは限らない

1
Posted at

TL;DR

  • RRF原著の k=60 は「パイロットで固定した値」で、near-optimal だが choice was not critical: 本質的ではない という位置づけ
  • kが小さいほど「どれか1つでランキング上位」 の文書のスコアが高くなり、kが大きいほど「複数ランキングでそこそこ上位」 の文書のスコアが高くなる

はじめに

こんにちは、株式会社IGSAでAIエンジニアをしている加藤です。
IGSAのプロジェクトにおいて、RAGを用いたアプリケーション実装を行うことがあります。
RAG(Retrieval Augmented Generation)の検索部分は、ハイブリッド検索がよく用いられます。例えばベクトル検索 + キーワードベースの検索(BM25)等のハイブリッドがあります。これらのランキングを統合する手法としてはRRF(Reciprocal Rank Fusion) がよく使われます。これは順位のみを使う手法のためスコアのスケール合わせや学習が不要であるという利点があります。

このRRFの式には定数k=60が定番として用いられますが、慣例的に使われることが多く根拠が曖昧なまま固定されがちです。この記事は、次の点を明確にします

  • なぜk=60が定番なのか
  • kを変更することによるランキングの変化
  • kの策定方法

RRF原著論文の確認

RRF(Reciprocal Rank Fusion)の原著は、より一般的なrank fusion(複数の順位リストの統合) の文脈で提案されています。論文の目的としては学習ベースのlearning-to-rankが注目される一方で、学習データ無しで強い融合手法が欲しいという動機があるようです。

原著が扱っている入力と目的

入力 :複数の「ランキング(順位リスト)」
ここでのランキングは、複数のIRシステム/同一システムの異なる設定/参加者が提出した run/learning-to-rank の出力など、とにかく順位さえあれば良いという扱いです。

目的 :教師なしで統合し、既存の融合手法より良いランキングを得ること
比較対象としてCondorcet FuseやCombMNZなど、従来の融合が挙げられています。

評価 :主にMAP(Mean Average Precision)
加えてP@k / R-precision / NDCG でも同様の傾向を述べています。

RRFパラメータ(k)の効果

RRFは次のスコアで順位を統合します:

\mathrm{RRFscore}(d)=\sum_{r \in R} \frac{1}{k + r(d)}

ここでkは順位差の効き方 を調整する定数で、原著はk=60を「パイロットで固定し、その後の検証で変えなかった」と明示しています。 
加えて、パイロット結果としてk=60がnear-optimalだがchoice was not critical(選択は本質的ではない)とも述べています。 
また原著は、kの役割として「外れ値的なシステム(outlier systems)が高順位を付けたときの影響を緩和する」と説明しています。 

直感的なkの影響

  • kが小さい:上位順位の寄与が相対的に大きくなる
    → どれか1つのランキングで極端に上位に来た文書が押し上げられやすい
  • kが大きい:上位順位の差がなだらかになり、寄与が均される
    → 複数ランキングでそこそこ上位に出る文書が有利になりやすい

具体例

ランキングが2つ(Vector / BM25)で、重みが同じとします。ある文書の順位が次の2パターンだとします。

  • パターンA: Vector=1位、BM25=100位
  • パターンB: Vector=10位、BM25=10位

RRFスコアは以下のようになります
k=10
• A: 1/(10+1) + 1/(10+100) ≈ 0.0909 + 0.0091 = 0.1000
• B: 1/(10+10) + 1/(10+10) = 0.05 + 0.05 = 0.1000
→ 尖ったランキングによりスコアが同格

k=60
• A: 1/(60+1) + 1/(60+100) ≈ 0.0164 + 0.00625 = 0.0227
• B: 1/(60+10) + 1/(60+10) ≈ 0.0143 + 0.0143 = 0.0286
→ 複数の順位で合意されている場合のスコアが高い

つまり、kを触るときのイメージは 尖ったランキングを拾う vs 複数の合意を強める のトレードオフといえます。

定数kを変化させた場合の変化の可視化

ここからは実際にkを変えると何が起きるかを、可視化とメトリクスで確認します。

使用データセット(SciFact / MTEB)

今回の検証では、Hugging Faceのmteb/scifact を使いました。SciFact科学的主張に対して、根拠となる論文アブストラクトを検索して支持/反証の根拠を集める科学的主張検証(scientific claim verification) のデータセットです。コーパスは論文アブストラクト、クエリは主張文で、qrels(関連判定)により「この主張に関連する根拠アブストラクト」が与えられています。

本記事においては検索タスク(クエリ→関連文書のランキング) として利用し、ベクトル検索+BM25の2ランキングをRRFで融合して評価しました。

各クエリにおけるRRFスコアの可視化

今回は ランカーが2つ(ベクトル検索とBM25) の場合を想定し、横軸をベクトル検索順位、縦軸をBM25順位にして、RRFスコアの大きさを色の濃さで表します。
(実装コードは記事末尾に記載しています)

rrf_rank_scatter_query_34.png

プロットは以下のように作っています。

  • RRFスコアを色の濃淡で表示
  • 正解文書かつ、RRF後の順位がTop10以内 → 青丸
  • 正解文書だが、RRF後の順位がTop10圏外 → 赤丸
  • 正解文書でないが、RRF後の順位がTop10以内 → 灰色の丸

この可視化により、次の点が読み取れます。

  • kが小さい:等高線(濃淡の境界)が軸方向に伸びやすく、どちらか片方の順位が極端に良い点が持ち上がりやすい
    尖ったランキングを重視しやすい
  • kが大きい:濃淡がよりなだらかで、片方だけ良い点が特段有利になりにくく、両方とも良い領域が強くなる
    ランキングの合意を重視しやすい
  • 今回は、k=10ではTop10内だが k=60,200ではTop10外になるケースが確認できます。プロット位置から分かる通り、この文書は BM25順位は高くないがベクトル順位が高いため、kが小さいと拾えて、kが大きいと落ちる、という挙動になっています。

rrf_rank_scatter_query_0.png
こちらは別クエリのプロットです。正解文書は両ランキングの順位が高いためRRFのパラメータは問題になっていませんが、top10に入っている文書の分布がk=10とそれ以外で変わっていることがわかります。

定数kの選定

次に、テストセット上でkを変化させた場合の検索性能指標をプロットします。

rrf_constant_metrics.png

各指標の意味は以下です。(全て高いほうが良い指標)

  • NDCG@10:Top-10の関連度の並びの良さ(上位ほど重みが高い)を図る指標
  • Capped Recall@10:取得できた正解数 / min(正解総数, 10)
    • 複数文書が正解となるためこのようにRecallを計算しました
  • Hit@10:Top-10に正解が1件でも入れば1、入らなければ0
  • MRR:最初に出てくる正解順位の逆数の平均

この図から、このテストケースでは kが0〜20程度の範囲で性能が良く、kを上げると低下する傾向が見られます。
つまりk=60は常に最適ではないことが確認できます

一般的にはこのようなプロセスでkを変化させた場合の目的指標を確認し、性能が良いkを選定することが良いと言えると思います。しかしこの選び方はテストセットの分布に最適化されるため、適切なデータ作成が特に重要になります。最良点より安定している点を選ぶなどの運用が良いこともあると思います。

考察

今回の結果で kが小さいほど性能が高いということは、少なくとも今回の設定ではどちらか一方のランキングで強く上位に来る文書を拾う方が有利だったという解釈になります。

経験則的にも、ベクトル検索は意味的近さ、BM25は表層一致を拾うため、どちらかが当たっていれば正解というクエリが一定数あるのは自然です。そういう場合、kを小さくしてどちらかが高い場合を拾いやすくすることは合理的となり得ます。

注意点としては、今回はSciFact(科学ドメイン)で英語向けの設定であるためkは低い方がいいという一般化はできない点です。最終的には、品質が高い評価セットを作りパラメータを決めるというプロセスが一番重要です。

まとめ

本記事では、RRFの定数kについて

  • 原著におけるk=60の位置づけ
  • kが小さい/大きいときの挙動(尖ったランキング vs 合意)
  • 可視化とメトリクス元にした選定方法

を整理しました。
本記事の内容は定数kを変えるだけなので実験コストが低く、キーワード的/ベクトル的に近いはずなのに、なぜか取れない文書があるといった症状のときに、調整して良いかと思います。

また、最近はAgentic Search的な構成で検索の役割が分散するケースも増えています。その場合、RRFの定数kなど単発のパラメータの重要度は相対的に下がることもあります。
ただし「評価セット作成 → パラメータ変更 → 検索指標測定(→ 可視化)」という流れ自体は、検索品質を人間が検証するうえで定式化しやすく、再利用性が高いパイプラインだと考えられます。最終的には高品質な評価セットを持つことが一番重要だと考えられます。

実装コード

コードは各種コーディングツールを一部用いて実装しています。
出力するクエリや実験設定は6) 実験パラメータ で変更可能です。

1) 依存関係インストール

RUN_PIP_INSTALL = True  # Colabで初回実行する場合は True に変更

if RUN_PIP_INSTALL:
    import subprocess
    import sys

    packages = [
        "numpy",
        "pandas",
        "matplotlib",
        "japanize-matplotlib",
        "datasets>=4.5",
        "rank-bm25",
        "sentence-transformers",
        "tqdm",
    ]
    subprocess.check_call([sys.executable, "-m", "pip", "install", "-q", *packages])
    print("依存関係をインストールしました")
else:
    print("依存関係インストールをスキップしました")

2) インポートと実行環境設定

import importlib.util
import json
import os
import re
import subprocess
from collections import defaultdict
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Sequence, Tuple

# キャッシュ設定
PROJECT_ROOT = Path.cwd()
CACHE_ROOT = PROJECT_ROOT / ".cache"
CACHE_ROOT.mkdir(parents=True, exist_ok=True)
(CACHE_ROOT / "matplotlib").mkdir(parents=True, exist_ok=True)
(CACHE_ROOT / "huggingface" / "transformers").mkdir(parents=True, exist_ok=True)
(CACHE_ROOT / "xdg").mkdir(parents=True, exist_ok=True)

os.environ.setdefault("MPLCONFIGDIR", str(CACHE_ROOT / "matplotlib"))
os.environ.setdefault("HF_HOME", str(CACHE_ROOT / "huggingface"))
os.environ.setdefault("TRANSFORMERS_CACHE", str(CACHE_ROOT / "huggingface" / "transformers"))
os.environ.setdefault("XDG_CACHE_HOME", str(CACHE_ROOT / "xdg"))

import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from datasets import load_dataset
from rank_bm25 import BM25Okapi
from sentence_transformers import SentenceTransformer
from tqdm import tqdm


def configure_japanese_font() -> None:
    # japanize-matplotlib を優先利用して日本語フォントを設定する
    try:
        import japanize_matplotlib  # noqa: F401
        print("japanize-matplotlib: already available")
    except ModuleNotFoundError:
        is_colab = importlib.util.find_spec("google.colab") is not None
        if is_colab:
            subprocess.check_call([
                os.sys.executable,
                "-m",
                "pip",
                "install",
                "-q",
                "japanize-matplotlib",
            ])
            import japanize_matplotlib  # noqa: F401
            print("japanize-matplotlib: installed")
        else:
            print("Warning: japanize-matplotlib が未導入です。pip install japanize-matplotlib を実行してください")

    plt.rcParams["axes.unicode_minus"] = False


configure_japanese_font()
TOKEN_PATTERN = re.compile(r"[A-Za-z0-9_]{2,}")

print(f"PROJECT_ROOT: {PROJECT_ROOT}")

3) データ読み込みロジック(SciFact)

@dataclass
class RetrievalDataset:
    dataset_name: str
    corpus_ids: np.ndarray
    corpus_texts: List[str]
    query_ids: np.ndarray
    query_texts: List[str]
    relevance_matrix: np.ndarray
    query_groups: List[str]
    metadata: Dict[str, str]


def clean_text(text: str) -> str:
    return " ".join(text.split())


def tokenize(text: str) -> List[str]:
    return TOKEN_PATTERN.findall(text.lower())


def parse_float_list(raw_values: str) -> List[float]:
    values = [float(v.strip()) for v in raw_values.split(",") if v.strip()]
    if not values:
        raise ValueError("At least one float value is required.")
    return values


def load_scifact_dataset(
    query_size: int,
    corpus_size: int,
    min_doc_chars: int,
    seed: int,
    cache_dir: Path,
    split: str = "test",
) -> RetrievalDataset:
    rng = np.random.default_rng(seed)
    hf_cache_dir = cache_dir / "hf_datasets"
    hf_cache_dir.mkdir(parents=True, exist_ok=True)

    corpus_ds = load_dataset(
        "mteb/scifact",
        "corpus",
        split="corpus",
        cache_dir=str(hf_cache_dir),
    )
    query_ds = load_dataset(
        "mteb/scifact",
        "queries",
        split="queries",
        cache_dir=str(hf_cache_dir),
    )
    qrels_ds = load_dataset(
        "mteb/scifact",
        "default",
        split=split,
        cache_dir=str(hf_cache_dir),
    )

    corpus_text_by_id: Dict[str, str] = {}
    for row in corpus_ds:
        corpus_id = str(row["_id"])
        title = clean_text(str(row.get("title", "")))
        text = clean_text(str(row.get("text", "")))
        merged = clean_text(f"{title} {text}".strip())
        if len(merged) >= min_doc_chars:
            corpus_text_by_id[corpus_id] = merged

    query_text_by_id: Dict[str, str] = {}
    for row in query_ds:
        query_id = str(row["_id"])
        text = clean_text(str(row.get("text", "")))
        if text:
            query_text_by_id[query_id] = text

    raw_qrels: Dict[str, set] = defaultdict(set)
    for row in qrels_ds:
        if float(row["score"]) <= 0:
            continue
        query_id = str(row["query-id"])
        corpus_id = str(row["corpus-id"])
        raw_qrels[query_id].add(corpus_id)

    valid_qrels: Dict[str, List[str]] = {}
    for query_id, rel_ids in raw_qrels.items():
        if query_id not in query_text_by_id:
            continue
        filtered_rel_ids = [corpus_id for corpus_id in rel_ids if corpus_id in corpus_text_by_id]
        if filtered_rel_ids:
            valid_qrels[query_id] = filtered_rel_ids

    candidate_query_ids = sorted(valid_qrels.keys())
    if len(candidate_query_ids) == 0:
        raise RuntimeError("No valid queries with positive qrels were found in mteb/scifact.")

    selected_query_size = min(query_size, len(candidate_query_ids))
    selected_query_ids = rng.choice(candidate_query_ids, size=selected_query_size, replace=False).tolist()

    relevant_corpus_ids = {
        corpus_id
        for query_id in selected_query_ids
        for corpus_id in valid_qrels[query_id]
    }
    all_corpus_ids = list(corpus_text_by_id.keys())

    if corpus_size > 0:
        if corpus_size < len(relevant_corpus_ids):
            raise ValueError(
                "corpus_size is smaller than the number of required relevant docs. "
                f"corpus_size={corpus_size}, required>={len(relevant_corpus_ids)}"
            )
        negative_pool = [corpus_id for corpus_id in all_corpus_ids if corpus_id not in relevant_corpus_ids]
        additional_needed = min(corpus_size - len(relevant_corpus_ids), len(negative_pool))
        sampled_negatives = (
            rng.choice(negative_pool, size=additional_needed, replace=False).tolist()
            if additional_needed > 0
            else []
        )
        selected_corpus_ids = sorted(list(relevant_corpus_ids) + sampled_negatives)
    else:
        selected_corpus_ids = sorted(all_corpus_ids)

    corpus_id_to_idx = {corpus_id: idx for idx, corpus_id in enumerate(selected_corpus_ids)}
    relevance_matrix = np.zeros((len(selected_query_ids), len(selected_corpus_ids)), dtype=bool)

    for query_idx, query_id in enumerate(selected_query_ids):
        for corpus_id in valid_qrels[query_id]:
            if corpus_id in corpus_id_to_idx:
                relevance_matrix[query_idx, corpus_id_to_idx[corpus_id]] = True

    return RetrievalDataset(
        dataset_name="mteb_scifact",
        corpus_ids=np.array(selected_corpus_ids, dtype=object),
        corpus_texts=[corpus_text_by_id[corpus_id] for corpus_id in selected_corpus_ids],
        query_ids=np.array(selected_query_ids, dtype=object),
        query_texts=[query_text_by_id[query_id] for query_id in selected_query_ids],
        relevance_matrix=relevance_matrix,
        query_groups=["scifact" for _ in selected_query_ids],
        metadata={
            "source": "mteb/scifact",
            "split": split,
            "note": "relevance=qrels (BEIR style, mostly 1 positive, max 5)",
        },
    )


4) ランキング計算ロジック

def scores_to_rank_matrix(scores: np.ndarray) -> np.ndarray:
    num_queries, num_docs = scores.shape
    rank_matrix = np.empty((num_queries, num_docs), dtype=np.int32)
    for query_idx in range(num_queries):
        order = np.argsort(-scores[query_idx], kind="mergesort")
        rank_matrix[query_idx, order] = np.arange(1, num_docs + 1, dtype=np.int32)
    return rank_matrix


def compute_vector_rank_matrix(
    model_name: str,
    corpus_texts: Sequence[str],
    query_texts: Sequence[str],
    batch_size: int,
    cache_folder: Path,
) -> np.ndarray:
    model = SentenceTransformer(model_name, cache_folder=str(cache_folder))

    corpus_embeddings = model.encode(
        list(corpus_texts),
        batch_size=batch_size,
        show_progress_bar=True,
        convert_to_numpy=True,
        normalize_embeddings=True,
    )
    query_embeddings = model.encode(
        list(query_texts),
        batch_size=batch_size,
        show_progress_bar=True,
        convert_to_numpy=True,
        normalize_embeddings=True,
    )

    scores = np.matmul(query_embeddings, corpus_embeddings.T)
    return scores_to_rank_matrix(scores)


def compute_bm25_rank_matrix(
    corpus_texts: Sequence[str],
    query_texts: Sequence[str],
) -> np.ndarray:
    tokenized_corpus = [tokenize(text) for text in corpus_texts]
    bm25 = BM25Okapi(tokenized_corpus)

    scores = np.zeros((len(query_texts), len(corpus_texts)), dtype=np.float32)
    for query_idx, query in enumerate(tqdm(query_texts, desc="BM25 scoring")):
        tokenized_query = tokenize(query)
        scores[query_idx] = bm25.get_scores(tokenized_query)

    return scores_to_rank_matrix(scores)


def compute_rrf_rank_for_query(
    vector_rank: np.ndarray,
    bm25_rank: np.ndarray,
    rrf_constant: float,
    vector_weight: float,
    bm25_weight: float,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    rrf_scores = (
        vector_weight / (rrf_constant + vector_rank.astype(np.float32))
        + bm25_weight / (rrf_constant + bm25_rank.astype(np.float32))
    )
    order = np.argsort(-rrf_scores, kind="mergesort")
    final_rank = np.empty_like(order)
    final_rank[order] = np.arange(1, len(order) + 1, dtype=np.int32)
    return order, final_rank, rrf_scores

5) 指標計算と可視化ロジック

def compute_ndcg_at_k(order: np.ndarray, relevance_mask: np.ndarray, k: int) -> float:
    rel_k = relevance_mask[order[:k]].astype(np.float32)
    if rel_k.sum() == 0:
        return 0.0
    discounts = 1.0 / np.log2(np.arange(2, len(rel_k) + 2, dtype=np.float32))
    dcg = float(np.sum(rel_k * discounts))

    ideal_rel = np.sort(relevance_mask.astype(np.float32))[::-1][:k]
    idcg = float(np.sum(ideal_rel * discounts[: len(ideal_rel)]))
    if idcg <= 0:
        return 0.0
    return dcg / idcg


def compute_recall_at_k(order: np.ndarray, relevance_mask: np.ndarray, k: int) -> float:
    total_relevant = int(np.sum(relevance_mask))
    if total_relevant == 0:
        return 0.0
    hit = int(np.sum(relevance_mask[order[:k]]))
    return float(hit / total_relevant)


def compute_capped_recall_at_k(order: np.ndarray, relevance_mask: np.ndarray, k: int) -> float:
    total_relevant = int(np.sum(relevance_mask))
    if total_relevant == 0:
        return 0.0
    hit = int(np.sum(relevance_mask[order[:k]]))
    denom = min(total_relevant, k)
    return float(hit / denom)


def compute_hit_at_k(order: np.ndarray, relevance_mask: np.ndarray, k: int) -> float:
    return float(np.any(relevance_mask[order[:k]]))


def compute_mrr(order: np.ndarray, relevance_mask: np.ndarray) -> float:
    hit_positions = np.flatnonzero(relevance_mask[order])
    if len(hit_positions) == 0:
        return 0.0
    return 1.0 / float(hit_positions[0] + 1)


def sweep_rrf_constants(
    vector_rank_matrix: np.ndarray,
    bm25_rank_matrix: np.ndarray,
    relevance_matrix: np.ndarray,
    rrf_constants: Sequence[float],
    vector_weight: float,
    bm25_weight: float,
    top_k: int,
) -> pd.DataFrame:
    num_queries = vector_rank_matrix.shape[0]
    metrics_rows: List[Dict[str, float]] = []

    for rrf_constant in rrf_constants:
        ndcg_scores = []
        recall_scores = []
        capped_recall_scores = []
        hit_scores = []
        mrr_scores = []

        for query_idx in range(num_queries):
            order, _, _ = compute_rrf_rank_for_query(
                vector_rank_matrix[query_idx],
                bm25_rank_matrix[query_idx],
                rrf_constant,
                vector_weight,
                bm25_weight,
            )
            relevance_mask = relevance_matrix[query_idx]

            ndcg_scores.append(compute_ndcg_at_k(order, relevance_mask, top_k))
            recall_scores.append(compute_recall_at_k(order, relevance_mask, top_k))
            capped_recall_scores.append(compute_capped_recall_at_k(order, relevance_mask, top_k))
            hit_scores.append(compute_hit_at_k(order, relevance_mask, top_k))
            mrr_scores.append(compute_mrr(order, relevance_mask))

        metrics_rows.append(
            {
                "c": float(rrf_constant),
                f"ndcg@{top_k}": float(np.mean(ndcg_scores)),
                f"recall@{top_k}": float(np.mean(recall_scores)),
                f"capped_recall@{top_k}": float(np.mean(capped_recall_scores)),
                f"hit@{top_k}": float(np.mean(hit_scores)),
                "mrr": float(np.mean(mrr_scores)),
            }
        )

    return pd.DataFrame(metrics_rows)


def calculate_rrf_score(
    vector_rank: int,
    bm25_rank: int,
    rrf_constant: float,
    vector_weight: float,
    bm25_weight: float,
) -> float:
    return (
        vector_weight / (rrf_constant + float(vector_rank))
        + bm25_weight / (rrf_constant + float(bm25_rank))
    )


def _create_rrf_heatmap_data(
    axis_limit: int,
    rrf_constant: float,
    vector_weight: float,
    bm25_weight: float,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    x_grid = np.arange(1, axis_limit + 1)
    y_grid = np.arange(1, axis_limit + 1)
    x_mesh, y_mesh = np.meshgrid(x_grid, y_grid)

    rrf_scores = np.zeros_like(x_mesh, dtype=np.float32)
    for row_idx in range(rrf_scores.shape[0]):
        for col_idx in range(rrf_scores.shape[1]):
            rrf_scores[row_idx, col_idx] = calculate_rrf_score(
                int(x_mesh[row_idx, col_idx]),
                int(y_mesh[row_idx, col_idx]),
                rrf_constant,
                vector_weight,
                bm25_weight,
            )

    return x_mesh, y_mesh, rrf_scores


def _plot_rrf_heatmap(ax: plt.Axes, x_mesh: np.ndarray, y_mesh: np.ndarray, rrf_scores: np.ndarray):
    return ax.pcolormesh(
        x_mesh,
        y_mesh,
        rrf_scores,
        cmap="Greens",
        alpha=0.7,
        shading="auto",
        zorder=1,
    )


def _setup_plot_axes(ax: plt.Axes, axis_limit: int):
    ax.set_xlim(1, axis_limit)
    ax.set_ylim(1, axis_limit)
    ax.set_xlabel("Vector rank")
    ax.set_ylabel("BM25 rank")
    ax.grid(True, alpha=0.25)


def select_expected_docs(
    relevance_mask: np.ndarray,
    vector_rank: np.ndarray,
    bm25_rank: np.ndarray,
    expected_docs: int,
    baseline_rrf_constant: float,
    vector_weight: float,
    bm25_weight: float,
) -> np.ndarray:
    relevant_ids = np.flatnonzero(relevance_mask)
    if len(relevant_ids) == 0:
        return np.array([], dtype=np.int32)

    _, baseline_final_rank, _ = compute_rrf_rank_for_query(
        vector_rank=vector_rank,
        bm25_rank=bm25_rank,
        rrf_constant=baseline_rrf_constant,
        vector_weight=vector_weight,
        bm25_weight=bm25_weight,
    )
    sort_idx = np.argsort(baseline_final_rank[relevant_ids], kind="mergesort")
    sorted_relevant_ids = relevant_ids[sort_idx]
    return sorted_relevant_ids[:expected_docs].astype(np.int32)


def plot_rrf_metrics(metrics_df: pd.DataFrame, top_k: int, output_path: Path):
    fig, ax = plt.subplots(1, 1, figsize=(8, 5))
    k_values = metrics_df["c"].to_numpy()

    ax.plot(k_values, metrics_df[f"ndcg@{top_k}"], marker="o", label=f"NDCG@{top_k}")
    ax.plot(k_values, metrics_df[f"capped_recall@{top_k}"], marker="o", label=f"Capped Recall@{top_k}")
    ax.plot(k_values, metrics_df[f"hit@{top_k}"], marker="o", label=f"Hit@{top_k}")
    ax.plot(k_values, metrics_df["mrr"], marker="o", label="MRR")

    ax.set_title("RRF定数kによる検索指標の変化")
    ax.set_xlabel("RRF定数 k")
    ax.set_ylabel("指標値")
    ax.grid(True, alpha=0.3)
    ax.legend(fontsize=8)

    plt.tight_layout()
    plt.savefig(output_path, dpi=300, bbox_inches="tight")
    plt.close(fig)


def plot_rank_visualization_for_query(
    query_index: int,
    query_external_id: str,
    query_text: str,
    vector_rank: np.ndarray,
    bm25_rank: np.ndarray,
    relevance_mask: np.ndarray,
    expected_doc_ids: np.ndarray,
    corpus_ids: np.ndarray,
    rrf_constants: Sequence[float],
    vector_weight: float,
    bm25_weight: float,
    top_k: int,
    output_path: Path,
) -> List[Dict[str, float]]:
    ncols = len(rrf_constants)
    fig, axes = plt.subplots(1, ncols, figsize=(6 * ncols, 6))
    if ncols == 1:
        axes = [axes]

    expected_set = set(int(doc_id) for doc_id in expected_doc_ids.tolist())
    query_summary: List[Dict[str, float]] = []

    for ax, k_value in zip(axes, rrf_constants):
        order, final_rank, _ = compute_rrf_rank_for_query(
            vector_rank,
            bm25_rank,
            k_value,
            vector_weight,
            bm25_weight,
        )

        expected_in_topk = [doc_id for doc_id in expected_doc_ids if final_rank[doc_id] <= top_k]
        expected_not_in_topk = [doc_id for doc_id in expected_doc_ids if final_rank[doc_id] > top_k]
        non_expected_in_topk = [doc_id for doc_id in order[:top_k] if int(doc_id) not in expected_set]

        plot_ids = expected_in_topk + expected_not_in_topk + non_expected_in_topk
        if len(plot_ids) == 0:
            plot_ids = order[:top_k].tolist()

        max_rank = max(
            int(np.max(vector_rank[plot_ids])),
            int(np.max(bm25_rank[plot_ids])),
        )
        axis_limit = max(int(max_rank * 1.2), top_k + 5)

        x_mesh, y_mesh, rrf_scores = _create_rrf_heatmap_data(
            axis_limit=axis_limit,
            rrf_constant=k_value,
            vector_weight=vector_weight,
            bm25_weight=bm25_weight,
        )
        mesh = _plot_rrf_heatmap(ax, x_mesh, y_mesh, rrf_scores)
        _setup_plot_axes(ax, axis_limit)
        fig.colorbar(mesh, ax=ax, fraction=0.046, pad=0.04, label="RRFスコア")

        if len(non_expected_in_topk) > 0:
            ax.scatter(
                vector_rank[non_expected_in_topk],
                bm25_rank[non_expected_in_topk],
                c="gray",
                alpha=0.5,
                s=60,
                edgecolors="black",
                linewidths=0.8,
                label=f"RRF Top-{top_k}(期待外)",
                zorder=3,
            )

        if len(expected_not_in_topk) > 0:
            ax.scatter(
                vector_rank[expected_not_in_topk],
                bm25_rank[expected_not_in_topk],
                c="red",
                alpha=0.8,
                s=100,
                edgecolors="black",
                linewidths=1.0,
                label=f"期待文書(>{top_k}位)",
                zorder=4,
            )

        if len(expected_in_topk) > 0:
            ax.scatter(
                vector_rank[expected_in_topk],
                bm25_rank[expected_in_topk],
                c="blue",
                alpha=0.85,
                s=100,
                edgecolors="black",
                linewidths=1.0,
                label=f"期待文書(<= {top_k}位)",
                zorder=5,
            )

        sorted_expected = sorted(expected_doc_ids.tolist(), key=lambda doc_id: int(final_rank[doc_id]))
        for doc_idx in sorted_expected[:10]:
            corpus_external_id = str(corpus_ids[doc_idx])
            ax.annotate(
                f"順位{int(final_rank[doc_idx])} / ID:{corpus_external_id[-5:]}",
                (int(vector_rank[doc_idx]), int(bm25_rank[doc_idx])),
                textcoords="offset points",
                xytext=(7, 7),
                fontsize=8,
                bbox={
                    "boxstyle": "round,pad=0.3",
                    "facecolor": "white",
                    "alpha": 0.75,
                    "edgecolor": "black",
                },
                zorder=6,
            )

        expected_hit_ratio = float(len(expected_in_topk) / max(len(expected_doc_ids), 1))
        hit_positions = np.flatnonzero(relevance_mask[order])
        first_rel_rank = int(hit_positions[0] + 1) if len(hit_positions) > 0 else int(len(order) + 1)
        query_summary.append(
            {
                "c": float(k_value),
                "expected_hit_ratio": expected_hit_ratio,
                "first_relevant_rank": float(first_rel_rank),
            }
        )

        ax.set_title(f"k={k_value:g}")
        ax.legend(loc="upper right", fontsize=8)

    query_title = f"query_index={query_index}, query_id={query_external_id}"
    if query_text:
        query_title += f" | {query_text[:80]}{'...' if len(query_text) > 80 else ''}"
    fig.suptitle(query_title, fontsize=11, y=0.98)
    plt.tight_layout(rect=(0, 0, 1, 0.94))
    plt.savefig(output_path, dpi=300, bbox_inches="tight")
    plt.close(fig)
    return query_summary

6) 実験パラメータ

MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
DATASET_SPLIT = "test"

CORPUS_SIZE = 3000
QUERY_SIZE = 120
MIN_DOC_CHARS = 80
BATCH_SIZE = 64
SEED = 42
TOP_K = 10
EXPECTED_DOCS = 1

RRF_CONSTANTS = parse_float_list("1,5,10,20,40,60,80,120,200")
SCATTER_CONSTANTS = parse_float_list("10,60,200")
VECTOR_WEIGHT = 0.5
BM25_WEIGHT = 0.5

FOCUS_QUERY_INDICES = [0, 34, 41, 86]  # 指定クエリ

timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
run_dir = PROJECT_ROOT / "outputs" / f"colab_demo_standalone_{timestamp}"
run_dir.mkdir(parents=True, exist_ok=True)

cache_dir = CACHE_ROOT
(cache_dir / "sentence_transformers").mkdir(parents=True, exist_ok=True)

print(f"run_dir: {run_dir}")

7) 実行: データ読み込み -> ランク計算 -> 可視化

# 7) 実行: データ読み込み -> ランク計算 -> 可視化
print("Loading dataset...")
dataset = load_scifact_dataset(
    query_size=QUERY_SIZE,
    corpus_size=CORPUS_SIZE,
    min_doc_chars=MIN_DOC_CHARS,
    seed=SEED,
    cache_dir=cache_dir,
    split=DATASET_SPLIT,
)

avg_rel_docs = float(np.mean(dataset.relevance_matrix.sum(axis=1)))
print(
    f"Dataset ready: name={dataset.dataset_name}, corpus={len(dataset.corpus_texts)}, "
    f"queries={len(dataset.query_texts)}, avg_relevant_per_query={avg_rel_docs:.2f}"
)

for qidx in FOCUS_QUERY_INDICES:
    if qidx < 0 or qidx >= len(dataset.query_texts):
        raise ValueError(f"FOCUS_QUERY_INDICES contains out-of-range value: {qidx}")

print("Computing vector ranks...")
vector_rank_matrix = compute_vector_rank_matrix(
    model_name=MODEL_NAME,
    corpus_texts=dataset.corpus_texts,
    query_texts=dataset.query_texts,
    batch_size=BATCH_SIZE,
    cache_folder=cache_dir / "sentence_transformers",
)

print("Computing BM25 ranks...")
bm25_rank_matrix = compute_bm25_rank_matrix(
    corpus_texts=dataset.corpus_texts,
    query_texts=dataset.query_texts,
)

print("Sweeping RRF k values...")
metrics_df = sweep_rrf_constants(
    vector_rank_matrix=vector_rank_matrix,
    bm25_rank_matrix=bm25_rank_matrix,
    relevance_matrix=dataset.relevance_matrix,
    rrf_constants=RRF_CONSTANTS,
    vector_weight=VECTOR_WEIGHT,
    bm25_weight=BM25_WEIGHT,
    top_k=TOP_K,
)

metrics_csv_path = run_dir / "rrf_constant_metrics.csv"
metrics_plot_path = run_dir / "rrf_constant_metrics.png"
metrics_df.to_csv(metrics_csv_path, index=False)
plot_rrf_metrics(metrics_df=metrics_df, top_k=TOP_K, output_path=metrics_plot_path)

print(f"Saved: {metrics_csv_path}")
print(f"Saved: {metrics_plot_path}")

baseline_rrf_constant = float(SCATTER_CONSTANTS[len(SCATTER_CONSTANTS) // 2])
scatter_paths: List[Path] = []
focus_analyses: Dict[str, Dict[str, object]] = {}

for query_idx in FOCUS_QUERY_INDICES:
    query_external_id = str(dataset.query_ids[query_idx])
    query_text = dataset.query_texts[query_idx]

    expected_doc_ids = select_expected_docs(
        relevance_mask=dataset.relevance_matrix[query_idx],
        vector_rank=vector_rank_matrix[query_idx],
        bm25_rank=bm25_rank_matrix[query_idx],
        expected_docs=EXPECTED_DOCS,
        baseline_rrf_constant=baseline_rrf_constant,
        vector_weight=VECTOR_WEIGHT,
        bm25_weight=BM25_WEIGHT,
    )
    if len(expected_doc_ids) == 0:
        raise RuntimeError(f"No expected docs found for query index={query_idx}")

    scatter_plot_path = run_dir / f"rrf_rank_scatter_query_{query_idx}.png"
    query_summary = plot_rank_visualization_for_query(
        query_index=query_idx,
        query_external_id=query_external_id,
        query_text=query_text,
        vector_rank=vector_rank_matrix[query_idx],
        bm25_rank=bm25_rank_matrix[query_idx],
        relevance_mask=dataset.relevance_matrix[query_idx],
        expected_doc_ids=expected_doc_ids,
        corpus_ids=dataset.corpus_ids,
        rrf_constants=SCATTER_CONSTANTS,
        vector_weight=VECTOR_WEIGHT,
        bm25_weight=BM25_WEIGHT,
        top_k=TOP_K,
        output_path=scatter_plot_path,
    )
    scatter_paths.append(scatter_plot_path)
    print(f"Saved: {scatter_plot_path}")

    analysis = {
        "focus_query_index": query_idx,
        "focus_query_id": query_external_id,
        "focus_query_group": dataset.query_groups[query_idx],
        "focus_query_text": query_text,
        "expected_doc_indices": expected_doc_ids.tolist(),
        "expected_doc_ids": [str(dataset.corpus_ids[int(doc_idx)]) for doc_idx in expected_doc_ids.tolist()],
        "scatter_constants": SCATTER_CONSTANTS,
        "query_summary": query_summary,
    }
    analysis_path = run_dir / f"focus_query_analysis_query_{query_idx}.json"
    analysis_path.write_text(json.dumps(analysis, ensure_ascii=False, indent=2), encoding="utf-8")
    focus_analyses[str(query_idx)] = analysis

run_summary = {
    "dataset": {
        "name": dataset.dataset_name,
        "source": dataset.metadata.get("source", ""),
        "split": dataset.metadata.get("split", ""),
        "note": dataset.metadata.get("note", ""),
        "corpus_size": len(dataset.corpus_texts),
        "query_size": len(dataset.query_texts),
        "avg_relevant_per_query": avg_rel_docs,
    },
    "parameters": {
        "model_name": MODEL_NAME,
        "top_k": TOP_K,
        "expected_docs": EXPECTED_DOCS,
        "rrf_constants": RRF_CONSTANTS,
        "scatter_constants": SCATTER_CONSTANTS,
        "vector_weight": VECTOR_WEIGHT,
        "bm25_weight": BM25_WEIGHT,
        "focus_query_indices": FOCUS_QUERY_INDICES,
    },
    "outputs": {
        "metrics_csv": str(metrics_csv_path),
        "metrics_plot": str(metrics_plot_path),
        "scatter_plots": [str(p) for p in scatter_paths],
    },
}

summary_path = run_dir / "notebook_run_summary.json"
summary_path.write_text(json.dumps(run_summary, ensure_ascii=False, indent=2), encoding="utf-8")
print(f"Saved: {summary_path}")

required_paths = [metrics_csv_path, metrics_plot_path, *scatter_paths, summary_path]
missing = [str(p) for p in required_paths if not p.exists()]
if missing:
    raise RuntimeError(f"Expected files are missing: {missing}")

print("All outputs generated.")
print(metrics_df)

8) 出力デモコード

print("=== Output files ===")
for p in required_paths:
    print(p)

try:
    from IPython.display import Image, display

    display(Image(filename=str(metrics_plot_path), width=1100))
    for p in scatter_paths:
        display(Image(filename=str(p), width=1100))
except Exception as e:
    print("画像表示をスキップしました:", e)

IGSAについて

IGSAは、社会を温かく柔らかく持続的に支えるAIシステムにより、持続可能な幸せを目指す、東京大学松尾・岩澤研究室発のAIカンパニーです。
脳の健康管理アプリ「はなしてね」や、中古品の画像解析SaaS「スグトリ」などのAIプロダクト提供に加え、潜在的な課題に対し柔軟な開発支援を行うパートナー事業を展開。センシングAI技術を活用した状態の定量化と分析により、人の意思決定をサポートしています。

1
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?