概要
langchainのBM25Retrieverを高速にマージする方法を検討しました。
背景
BM25アルゴリズムはキーワード検索を実施する代表的なアルゴリズムであり、生成AIと検索機能を組み合わせたRAGにおいても使用されることがあります。
langchainのBM25RetrieverはBM25による検索を生成AIと簡易に組み合わせるのに適したライブラリです。今回は、複数のBM25Retrieverを1つにまとめる方法を検討しました。
参考までに、同じくlangchainのTfidfRetrieverのマージ方法については以下の記事で検討したことがあります。
【langchain】TFIDFRetrieverのマージ
環境
環境
- M1 Mac, macOS 14.5
- Python 3.11.2
- rank-bm25 0.2.2
- langchain-community 0.3.1
以下の仮想環境上で実行しています。
uv venv
source .venv/bin/activate
uv pip install rank-bm25
簡単な方法
BM25Retireverをマージするいちばん簡単な方法は各retrieverのコーパスをマージして新たなBM25Retrieverを作ることです。
サンプルのコーパスとクエリを用意します。
corpus_a = [
"The quick brown fox jumps over the lazy dog.",
"A journey of a thousand miles begins with a single step.",
"To be or not to be, that is the question."
]
corpus_b = [
"In the middle of difficulty lies opportunity.",
"What we think, we become.",
"The only limit to our realization of tomorrow is our doubts of today."
]
query = "The only limit to our realization of tomorrow is our doubts of today."
各retrieverからdocsを取り出してマージします。
from langchain_community.retrievers import BM25Retriever
retriever_a = BM25Retriever.from_texts(corpus_a)
retriever_b = BM25Retriever.from_texts(corpus_b)
def merge_bm25retriever_simple(retriever_list: list[BM25Retriever]) -> BM25Retriever:
merged_docs = []
for retriever in retriever_list:
merged_docs += retriever.docs
bm25_params = {}
bm25_params["k1"] = retriever_list[0].vectorizer.k1
bm25_params["b"] = retriever_list[0].vectorizer.b
bm25_params["epsilon"] = retriever_list[0].vectorizer.epsilon
return BM25Retriever.from_documents(merged_docs,
k=retriever_list[0].k,
preprocess_func=retriever_list[0].preprocess_func,
bm25_params = bm25_params)
merged_retriever_simple = merge_bm25retriever_simple([retriever_a, retriever_b])
merged_retriever_simple.invoke(query)
[Document(metadata={}, page_content='The only limit to our realization of tomorrow is our doubts of today.'),
Document(metadata={}, page_content='To be or not to be, that is the question.'),
Document(metadata={}, page_content='The quick brown fox jumps over the lazy dog.'),
Document(metadata={}, page_content='What we think, we become.')]
マージの速度に特に不満を感じなければ、このコードで良いと思います。
高速な方法
コーパスサイズによっては、簡単な方法だと、マージの速度が遅くなるかもしれません。
その場合は、各retrieverで計算済みのIDF辞書をうまく活用することで効率的にマージできるようになります。
BM25Retrieverでは主に以下の2つのプラパティをマージすれば、retriever全体をマージすることができます。
- vectorizer: BM25のスコアリングに使用されるrank_bm25.BM25Okapiのインスタンス
- docs: 検索対象の文書リスト
docsは単なるリストなので、プラスで結合すればよいです。
vectorizerのマージ方法については、別記事で検証しました。
参考: 【python】rank_bm25のインスタンスを高速にマージ
上記記事を参考に、vectorizerをマージする処理を実装していきます。
まずマージに必要な情報をプラパティとして取得できるようにするために、rank_bm25.BM25Okapiのメソッドをオーバーライドします。
from rank_bm25 import BM25Okapi
def _initialize(self, corpus):
nd = {} # word -> number of documents with word
num_doc = 0
for document in corpus:
self.doc_len.append(len(document))
num_doc += len(document)
frequencies = {}
for word in document:
if word not in frequencies:
frequencies[word] = 0
frequencies[word] += 1
self.doc_freqs.append(frequencies)
for word, freq in frequencies.items():
try:
nd[word]+=1
except KeyError:
nd[word] = 1
self.corpus_size += 1
self.avgdl = num_doc / self.corpus_size
# ndを保存
self.nd = nd
return nd
BM25Okapi._initialize = _initialize
次にBM25Okapiのインスタンスを効率的にマージする関数を作成します。
詳細は参考記事に譲りますが、各インスタンスで計算済みの文書頻度辞書を再利用することで、マージが高速化されます。
def merge_bm25(bm25_list: list[BM25Okapi]) -> BM25Okapi:
if not bm25_list:
raise ValueError("The bm25_list is empty")
# Combine all tokenized corpora
merged_corpus_size = 0
merged_doc_freqs = []
merged_doc_len = []
for bm25 in bm25_list:
merged_corpus_size += bm25.corpus_size
merged_doc_freqs += bm25.doc_freqs
merged_doc_len += bm25.doc_len
merged_nd = {}
for bm25 in bm25_list:
for word, nd in bm25.nd.items():
merged_nd[word] = merged_nd.get(word, 0) + nd
merged_bm25 = BM25Okapi(["a"]
, tokenizer = bm25_list[0].tokenizer
, k1 = bm25_list[0].k1
, b = bm25_list[0].b
, epsilon = bm25_list[0].epsilon)
merged_bm25.corpus_size = merged_corpus_size
merged_bm25.doc_freqs = merged_doc_freqs
merged_bm25.doc_len = merged_doc_len
merged_bm25.avgdl = sum(merged_doc_len) / merged_corpus_size
merged_bm25.nd = merged_nd
merged_bm25._calc_idf(merged_nd)
return merged_bm25
最後にBM25Retireverをマージする関数を作ります。この関数は_initializeメソッドがオーバーライドされndがプラパティに保存されるようになったBM25Okapiを内部で使用することを前提としています。
from langchain_community.retrievers import BM25Retriever
def merge_bm25retriever(retriever_list: list[BM25Retriever]) -> BM25Retriever:
if not retriever_list:
raise ValueError("The retriever_list is empty")
merged_docs = []
for retriever in retriever_list:
merged_docs += retriever.docs
vectorizer_list = [retriever.vectorizer for retriever in retriever_list]
merged_vectorizer = merge_bm25(vectorizer_list)
merged_retriever = BM25Retriever(
vectorizer = merged_vectorizer,
docs = merged_docs,
k = retriever_list[0].k,
preprocess_func = retriever_list[0].preprocess_func
)
return merged_retriever
動作確認
簡単な実装と比べて検索結果が同じになるかを確認します。
コーパスとクエリは「簡単な実装」で定義したものと同様です。
retriever_a = BM25Retriever.from_texts(corpus_a)
retriever_b = BM25Retriever.from_texts(corpus_b)
merged_retriever = merge_bm25retriever([retriever_a, retriever_b])
merged_retriever.invoke(query)
[Document(metadata={}, page_content='The only limit to our realization of tomorrow is our doubts of today.'),
Document(metadata={}, page_content='To be or not to be, that is the question.'),
Document(metadata={}, page_content='The quick brown fox jumps over the lazy dog.'),
Document(metadata={}, page_content='What we think, we become.')]
簡単な実装の場合と同じ結果になっているので、大丈夫そうです。
速度
速度差を見やすくするため、大量の文書で実験します。50000件のサブコーパスを2つ用意します。
from datasets import load_dataset
# Load the AG News dataset
dataset = load_dataset('ag_news')
# Define sub-corpus size
sub_corpus_size = 50000
# Extract the text data from the dataset
corpus_a = [item['text'] for item in dataset['train'].select(range(sub_corpus_size))] # First sub_corpus_size items for corpus_a
corpus_b = [item['text'] for item in dataset['train'].select(range(sub_corpus_size, 2 * sub_corpus_size))] # Next sub_corpus_size items for corpus_b
query = "The only limit to our realization of tomorrow is our doubts of today."
2種類のやり方でのマージ速度を比較します。
import time
import numpy as np
iterations = 10
def measure_execution_time(func, *args, iterations=10, **kwargs):
times = []
for _ in range(iterations):
start_time = time.time()
func(*args, **kwargs)
end_time = time.time()
times.append(end_time - start_time)
average_time = np.mean(times)
std_dev_time = np.std(times)
return average_time, std_dev_time
retriever_a = BM25Retriever.from_texts(corpus_a)
retriever_b = BM25Retriever.from_texts(corpus_b)
# Measure time for merging BM25
average_time_merge, std_dev_time_merge = measure_execution_time(merge_bm25retriever_simple, [retriever_a, retriever_b])
print(f"merge_bm25retriever_simple function over {iterations} iterations: {average_time_merge} ± {std_dev_time_merge} seconds")
# Measure time for initializing BM25
average_time_init, std_dev_time_init = measure_execution_time(merge_bm25retriever, [retriever_a, retriever_b])
print(f"merge_bm25retriever function over {iterations} iterations: {average_time_init} ± {std_dev_time_init} seconds")
merge_bm25retriever_simple function over 10 iterations: 2.793686556816101 ± 0.1977367136912639 seconds
merge_bm25retriever function over 10 iterations: 0.1386489152908325 ± 0.0025654566699783777 seconds
1/20くらいに時間が短縮されました。
おわりに
BM25Retrieverの高速なマージ方法を検討し、50000件のコーパスで実験したところ、簡単な方法に比べて、1/20ほどに時間短縮できました。
BM25Okapiの関数をオーバーライドするのはできればやりたくないため、別の方法も引き続き検討したいです。