概要
- langchainのBM25Retrieverを高速化した(100Kのコーパス使用時で約50倍)
- 過去にBM25スコアの計算に使うライブラリをrank_bm25からscikit-learnベースのBM25Vectorizerに変更することで高速化できたが、検索結果が異なってしまう課題が見られたため、rank_bm25を使用し、APIや検索結果を維持したままでの高速化した。
リポジトリ
評価用のコードおよび手元のlaptopで検証した結果を以下で公開しています。
2024/10/29追記:リポジトリをアップデートしたため、最新のリポジトリと本記事の内容に齟齬があります。記事執筆当時のリポジトリのリンクは以下です。
https://github.com/jiroshimaya/fast-langchain-sparse-retriever/tree/v0.0.2
環境
- M1 Mac, macOS 14.5
- Python 3.11.2
以下の仮想環境上で実行しています。
uv venv
source .venv/bin/activate
uv pip install -r requirements.txt
langchain
langchain-community==0.3.1
scipy
rank-bm25
scikit-learn
datasets
概要
背景
- BM25Retrieverが使用するrank_bm25では、クエリごとに辞書操作でBM25スコアを計算しており、これが低速の原因となっている可能性がある。
- 例えば筆者の環境では10万件のコーパスの検索に1秒かかってしまう。これはscikit-learnベースのTFIDFVectorizerより50倍遅い。
- rank_bm25の代わりにscikit-learnベースのBM25Vectorizerを用いることで高速化できるが、rank_bm25と異なってしまい、後方互換性が保たれないことが課題だった。
解決策と実装
TFIDFRetrieverが高速なのは、コーパスの重みマトリックスを事前計算し、検索時に行列演算で類似度を計算しているため。
BM25Retrieverでも同様のアプローチを実装した。具体的には以下の変更を行った:
- rank_bm25.BM25Okapiを拡張し、文を重みベクトルに変換するメソッドを追加
- 疎行列の効率的な保存と計算のため、scipyライブラリ(pipでインストール可能)を追加で使用
- BM25Retriever.from_textsで、コーパスの各文書のBM25重みベクトルを作成し、プロパティとして保存
- BM25Retriever._get_relevant_documentsで、クエリの単語頻度ベクトルとコーパス文書の重みベクトルのドット積としてBM25スコアを計算
評価
- 100,000文書のコーパスで評価したところ、初期化は2-3倍遅くなったが、検索は約50倍高速化した。
- 初期化はユーザーが使い始める前に完了できるため、検索速度の方がユーザー体験に影響を与える可能性が高い。したがって、多くのユースケースでこの改善は有用と考えられる。
- 検索プロセスを変更したため、BM25スコアは変更前と完全には一致しなくなったが、その差は非常に小さかった。
- 具体的には、上位100件のBM25スコアの平均値が30程度であるのに対し、従来の方法と改善後の方法でのBM25スコアの絶対差は10のマイナス16乗のオーダーであった。
- 検索結果にも若干の影響があったが、実用上は無視できるレベルと思われる。
- 具体的には、100個のクエリに対してBM25Retriever.invokeの上位100件の検索結果を従来の方法と改善後の方法で比較したところ、94個のクエリで上位100件の検索結果が完全に一致した。残りの6個のクエリのうち、1個は6位の結果に差異があり、他の5個はすべて50位以降の結果に差異があった。したがって、上位の検索結果を重視するRAGなどのアプリケーションでは、この差異の影響は最小限であると考えられる。
実装
langchainのbm25.pyをベースにいくつ可能機能を追加します。
rank-bm25のBM25Okapiのサブクラスを作り、以下の機能を新たに実装します。
- コンストラクタをオーバーライドし、単語ID辞書を作成。重み行列の行番号と単語の対応をとるため。
- transformメソッドの追加。疎行列を前提とすることで計算を効率化する。
- count_transformメソッドの追加。BM25スコアはクエリの単語頻度ベクトルと、コーパスの重み行列の内積によって求められるため、行番号と単語が対応した単語頻度ベクトルの算出用メソッドが必要。
なお、langchainの既存の書かれ方に合わせて、ライブラリを関数の中でimport(いわゆるlazy import)するようにしています。このため、関数の中で動的にクラスを作成するというややこしい実装になっています。lazy importを諦めれば関数の外で素直にクラス定義できるのですが、どっちがいいのかはよくわかりません。
def create_bm25_vectorizer(corpus, **bm25_params):
try:
from rank_bm25 import BM25Okapi
except ImportError:
raise ImportError(
"Could not import rank_bm25, please install with `pip install "
"rank_bm25`."
)
class BM25Vectorizer(BM25Okapi):
def __init__(self, corpus, **bm25_params):
super().__init__(corpus, **bm25_params)
self.vocabulary = list(self.idf.keys())
self.word_to_id = {word: i for i, word in enumerate(self.vocabulary)}
def transform(self, queries: list[list[str]]) -> scipy.sparse.csr_matrix:
try:
from scipy.sparse import csr_matrix
except ImportError:
raise ImportError(
"Could not import scipy, please install with `pip install "
"scipy`."
)
rows = []
cols = []
data = []
for i, query in enumerate(queries):
query_len = len(query)
for word in set(query):
if word in self.word_to_id:
word_id = self.word_to_id[word]
tf = query.count(word)
idf = self.idf.get(word, 0)
# BM25 scoring formula
numerator = idf * tf * (self.k1 + 1)
denominator = tf + self.k1 * (1 - self.b + self.b * query_len / self.avgdl)
score = numerator / denominator
rows.append(i)
cols.append(word_id)
data.append(score)
return csr_matrix((data, (rows, cols)), shape=(len(queries), len(self.vocabulary)))
def count_transform(self, queries: list[list[str]]) -> scipy.sparse.csr_matrix:
try:
from scipy.sparse import csr_matrix
except ImportError:
raise ImportError(
"Could not import scipy, please install with `pip install "
"scipy`."
)
rows = []
cols = []
data = []
for i, query in enumerate(queries):
for word in query:
if word in self.word_to_id:
word_id = self.word_to_id[word]
rows.append(i)
cols.append(word_id)
data.append(1) # Count is always 1 for each occurrence
return csr_matrix((data, (rows, cols)), shape=(len(queries), len(self.vocabulary)))
return BM25Vectorizer(corpus, **bm25_params)
改良したBM25Retrieverクラスを定義します。
- from_texts
- create_bm25_vectorizerを呼び出す
- bm25_arrayを作成
- _get_relevant_documents
- クエリを単語頻度ベクトルに変換
- bm25_arrayと単語頻度ベクトルの内積を計算
class BM25Retriever(BaseRetriever):
"""`BM25` retriever without Elasticsearch."""
vectorizer: Any = None
""" BM25 vectorizer."""
docs: List[Document] = Field(repr=False)
""" List of documents."""
bm25_array: Any = None
"""BM25 array."""
k: int = 4
""" Number of documents to return."""
preprocess_func: Callable[[str], List[str]] = default_preprocessing_func
""" Preprocessing function to use on the text before BM25 vectorization."""
model_config = ConfigDict(
arbitrary_types_allowed=True,
)
@classmethod
def from_texts(
cls,
texts: Iterable[str],
metadatas: Optional[Iterable[dict]] = None,
bm25_params: Optional[Dict[str, Any]] = None,
preprocess_func: Callable[[str], List[str]] = default_preprocessing_func,
**kwargs: Any,
) -> BM25Retriever:
"""
Create a BM25Retriever from a list of texts.
Args:
texts: A list of texts to vectorize.
metadatas: A list of metadata dicts to associate with each text.
bm25_params: Parameters to pass to the BM25 vectorizer.
preprocess_func: A function to preprocess each text before vectorization.
**kwargs: Any other arguments to pass to the retriever.
Returns:
A BM25Retriever instance.
"""
texts_processed = [preprocess_func(t) for t in texts]
bm25_params = bm25_params or {}
vectorizer = create_bm25_vectorizer(texts_processed, **bm25_params)
bm25_array = vectorizer.transform(texts_processed)
metadatas = metadatas or ({} for _ in texts)
docs = [Document(page_content=t, metadata=m) for t, m in zip(texts, metadatas)]
return cls(
vectorizer=vectorizer, docs=docs, bm25_array=bm25_array, preprocess_func=preprocess_func, **kwargs
)
@classmethod
def from_documents(
cls,
documents: Iterable[Document],
*,
bm25_params: Optional[Dict[str, Any]] = None,
preprocess_func: Callable[[str], List[str]] = default_preprocessing_func,
**kwargs: Any,
) -> BM25Retriever:
"""
Create a BM25Retriever from a list of Documents.
Args:
documents: A list of Documents to vectorize.
bm25_params: Parameters to pass to the BM25 vectorizer.
preprocess_func: A function to preprocess each text before vectorization.
**kwargs: Any other arguments to pass to the retriever.
Returns:
A BM25Retriever instance.
"""
texts, metadatas = zip(*((d.page_content, d.metadata) for d in documents))
return cls.from_texts(
texts=texts,
bm25_params=bm25_params,
metadatas=metadatas,
preprocess_func=preprocess_func,
**kwargs,
)
def _get_relevant_documents(
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
) -> List[Document]:
processed_query = self.preprocess_func(query)
query_vec = self.vectorizer.count_transform([processed_query])
results = query_vec.dot(self.bm25_array.T).toarray()[0]
return_docs = [self.docs[i] for i in results.argsort()[-self.k :][::-1]]
return return_docs
評価と考察
リポジトリ以外に記載する内容はないので省略します。