0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

【langchain】BM25Retrieverの高速なマージ

Last updated at Posted at 2024-10-01

概要

langchainのBM25Retrieverを高速にマージする方法を検討しました。

背景

BM25アルゴリズムはキーワード検索を実施する代表的なアルゴリズムであり、生成AIと検索機能を組み合わせたRAGにおいても使用されることがあります。

langchainのBM25RetrieverはBM25による検索を生成AIと簡易に組み合わせるのに適したライブラリです。今回は、複数のBM25Retrieverを1つにまとめる方法を検討しました。

参考までに、同じくlangchainのTfidfRetrieverのマージ方法については以下の記事で検討したことがあります。
【langchain】TFIDFRetrieverのマージ

環境

環境

  • M1 Mac, macOS 14.5
  • Python 3.11.2
  • rank-bm25 0.2.2
  • langchain-community 0.3.1

以下の仮想環境上で実行しています。

uv venv
source .venv/bin/activate
uv pip install rank-bm25

簡単な方法

BM25Retireverをマージするいちばん簡単な方法は各retrieverのコーパスをマージして新たなBM25Retrieverを作ることです。

サンプルのコーパスとクエリを用意します。

corpus_a = [
    "The quick brown fox jumps over the lazy dog.",
    "A journey of a thousand miles begins with a single step.",
    "To be or not to be, that is the question."
]

corpus_b = [
    "In the middle of difficulty lies opportunity.",
    "What we think, we become.",
    "The only limit to our realization of tomorrow is our doubts of today."
]

query = "The only limit to our realization of tomorrow is our doubts of today."

各retrieverからdocsを取り出してマージします。

from langchain_community.retrievers import BM25Retriever

retriever_a = BM25Retriever.from_texts(corpus_a)
retriever_b = BM25Retriever.from_texts(corpus_b)

def merge_bm25retriever_simple(retriever_list: list[BM25Retriever]) -> BM25Retriever:
  merged_docs = []
  for retriever in retriever_list:
    merged_docs += retriever.docs
  bm25_params = {}
  bm25_params["k1"] = retriever_list[0].vectorizer.k1
  bm25_params["b"] = retriever_list[0].vectorizer.b
  bm25_params["epsilon"] = retriever_list[0].vectorizer.epsilon
  return BM25Retriever.from_documents(merged_docs, 
                                      k=retriever_list[0].k, 
                                      preprocess_func=retriever_list[0].preprocess_func,
                                      bm25_params = bm25_params)

merged_retriever_simple = merge_bm25retriever_simple([retriever_a, retriever_b])
merged_retriever_simple.invoke(query)
[Document(metadata={}, page_content='The only limit to our realization of tomorrow is our doubts of today.'),
 Document(metadata={}, page_content='To be or not to be, that is the question.'),
 Document(metadata={}, page_content='The quick brown fox jumps over the lazy dog.'),
 Document(metadata={}, page_content='What we think, we become.')]

マージの速度に特に不満を感じなければ、このコードで良いと思います。

高速な方法

コーパスサイズによっては、簡単な方法だと、マージの速度が遅くなるかもしれません。
その場合は、各retrieverで計算済みのIDF辞書をうまく活用することで効率的にマージできるようになります。
BM25Retrieverでは主に以下の2つのプラパティをマージすれば、retriever全体をマージすることができます。

  • vectorizer: BM25のスコアリングに使用されるrank_bm25.BM25Okapiのインスタンス
  • docs: 検索対象の文書リスト

docsは単なるリストなので、プラスで結合すればよいです。
vectorizerのマージ方法については、別記事で検証しました。

参考: 【python】rank_bm25のインスタンスを高速にマージ

上記記事を参考に、vectorizerをマージする処理を実装していきます。

まずマージに必要な情報をプラパティとして取得できるようにするために、rank_bm25.BM25Okapiのメソッドをオーバーライドします。

from rank_bm25 import BM25Okapi
def _initialize(self, corpus):
    nd = {}  # word -> number of documents with word
    num_doc = 0
    for document in corpus:
        self.doc_len.append(len(document))
        num_doc += len(document)

        frequencies = {}
        for word in document:
            if word not in frequencies:
                frequencies[word] = 0
            frequencies[word] += 1
        self.doc_freqs.append(frequencies)

        for word, freq in frequencies.items():
            try:
                nd[word]+=1
            except KeyError:
                nd[word] = 1

        self.corpus_size += 1

    self.avgdl = num_doc / self.corpus_size
    # ndを保存
    self.nd = nd
    return nd
BM25Okapi._initialize = _initialize

次にBM25Okapiのインスタンスを効率的にマージする関数を作成します。
詳細は参考記事に譲りますが、各インスタンスで計算済みの文書頻度辞書を再利用することで、マージが高速化されます。

def merge_bm25(bm25_list: list[BM25Okapi]) -> BM25Okapi:
    if not bm25_list:
        raise ValueError("The bm25_list is empty")

    # Combine all tokenized corpora
    merged_corpus_size = 0
    merged_doc_freqs = []
    merged_doc_len = []
    
    for bm25 in bm25_list:
        merged_corpus_size += bm25.corpus_size
        merged_doc_freqs += bm25.doc_freqs
        merged_doc_len += bm25.doc_len
    
    merged_nd = {}
    for bm25 in bm25_list:
        for word, nd in bm25.nd.items():
            merged_nd[word] = merged_nd.get(word, 0) + nd
            
    merged_bm25 = BM25Okapi(["a"]
                            , tokenizer = bm25_list[0].tokenizer
                            , k1 = bm25_list[0].k1
                            , b = bm25_list[0].b
                            , epsilon = bm25_list[0].epsilon)
    
    merged_bm25.corpus_size = merged_corpus_size
    merged_bm25.doc_freqs = merged_doc_freqs
    merged_bm25.doc_len = merged_doc_len
    merged_bm25.avgdl = sum(merged_doc_len) / merged_corpus_size
    merged_bm25.nd = merged_nd
    
    merged_bm25._calc_idf(merged_nd)
    
    return merged_bm25

最後にBM25Retireverをマージする関数を作ります。この関数は_initializeメソッドがオーバーライドされndがプラパティに保存されるようになったBM25Okapiを内部で使用することを前提としています。

from langchain_community.retrievers import BM25Retriever

def merge_bm25retriever(retriever_list: list[BM25Retriever]) -> BM25Retriever:
    if not retriever_list:
        raise ValueError("The retriever_list is empty")
    
    merged_docs = []
    for retriever in retriever_list:
      merged_docs += retriever.docs
    
    vectorizer_list = [retriever.vectorizer for retriever in retriever_list]
    merged_vectorizer = merge_bm25(vectorizer_list)
    
    merged_retriever = BM25Retriever(
      vectorizer = merged_vectorizer,
      docs = merged_docs,
      k = retriever_list[0].k,
      preprocess_func = retriever_list[0].preprocess_func
    )
    return merged_retriever

動作確認

簡単な実装と比べて検索結果が同じになるかを確認します。
コーパスとクエリは「簡単な実装」で定義したものと同様です。

retriever_a = BM25Retriever.from_texts(corpus_a)
retriever_b = BM25Retriever.from_texts(corpus_b)

merged_retriever = merge_bm25retriever([retriever_a, retriever_b])
merged_retriever.invoke(query)
[Document(metadata={}, page_content='The only limit to our realization of tomorrow is our doubts of today.'),
 Document(metadata={}, page_content='To be or not to be, that is the question.'),
 Document(metadata={}, page_content='The quick brown fox jumps over the lazy dog.'),
 Document(metadata={}, page_content='What we think, we become.')]

簡単な実装の場合と同じ結果になっているので、大丈夫そうです。

速度

速度差を見やすくするため、大量の文書で実験します。50000件のサブコーパスを2つ用意します。

from datasets import load_dataset

# Load the AG News dataset
dataset = load_dataset('ag_news')

# Define sub-corpus size
sub_corpus_size = 50000

# Extract the text data from the dataset
corpus_a = [item['text'] for item in dataset['train'].select(range(sub_corpus_size))]  # First sub_corpus_size items for corpus_a
corpus_b = [item['text'] for item in dataset['train'].select(range(sub_corpus_size, 2 * sub_corpus_size))]  # Next sub_corpus_size items for corpus_b

query = "The only limit to our realization of tomorrow is our doubts of today."

2種類のやり方でのマージ速度を比較します。

import time
import numpy as np

iterations = 10

def measure_execution_time(func, *args, iterations=10, **kwargs):
    times = []
    for _ in range(iterations):
        start_time = time.time()
        func(*args, **kwargs)
        end_time = time.time()
        times.append(end_time - start_time)
    
    average_time = np.mean(times)
    std_dev_time = np.std(times)
    
    return average_time, std_dev_time

retriever_a = BM25Retriever.from_texts(corpus_a)
retriever_b = BM25Retriever.from_texts(corpus_b)
# Measure time for merging BM25
average_time_merge, std_dev_time_merge = measure_execution_time(merge_bm25retriever_simple, [retriever_a, retriever_b])
print(f"merge_bm25retriever_simple function over {iterations} iterations: {average_time_merge} ± {std_dev_time_merge} seconds")

# Measure time for initializing BM25
average_time_init, std_dev_time_init = measure_execution_time(merge_bm25retriever, [retriever_a, retriever_b])
print(f"merge_bm25retriever function over {iterations} iterations: {average_time_init} ± {std_dev_time_init} seconds")
merge_bm25retriever_simple function over 10 iterations: 2.793686556816101 ± 0.1977367136912639 seconds
merge_bm25retriever function over 10 iterations: 0.1386489152908325 ± 0.0025654566699783777 seconds

1/20くらいに時間が短縮されました。

おわりに

BM25Retrieverの高速なマージ方法を検討し、50000件のコーパスで実験したところ、簡単な方法に比べて、1/20ほどに時間短縮できました。

BM25Okapiの関数をオーバーライドするのはできればやりたくないため、別の方法も引き続き検討したいです。

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?