1
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

【langchain】BM25Retrieverの高速化(rank_bm25使用) v0.0.2

Last updated at Posted at 2024-10-07

概要

リポジトリ

評価用のコードおよび手元のlaptopで検証した結果を以下で公開しています。
2024/10/29追記:リポジトリをアップデートしたため、最新のリポジトリと本記事の内容に齟齬があります。記事執筆当時のリポジトリのリンクは以下です。
https://github.com/jiroshimaya/fast-langchain-sparse-retriever/tree/v0.0.2

環境

  • M1 Mac, macOS 14.5
  • Python 3.11.2

以下の仮想環境上で実行しています。

uv venv
source .venv/bin/activate
uv pip install -r requirements.txt
requirements.txt
langchain
langchain-community==0.3.1
scipy
rank-bm25
scikit-learn
datasets

概要

背景

  • BM25Retrieverが使用するrank_bm25では、クエリごとに辞書操作でBM25スコアを計算しており、これが低速の原因となっている可能性がある。
  • 例えば筆者の環境では10万件のコーパスの検索に1秒かかってしまう。これはscikit-learnベースのTFIDFVectorizerより50倍遅い。
  • rank_bm25の代わりにscikit-learnベースのBM25Vectorizerを用いることで高速化できるが、rank_bm25と異なってしまい、後方互換性が保たれないことが課題だった。

解決策と実装

TFIDFRetrieverが高速なのは、コーパスの重みマトリックスを事前計算し、検索時に行列演算で類似度を計算しているため。
BM25Retrieverでも同様のアプローチを実装した。具体的には以下の変更を行った:

  1. rank_bm25.BM25Okapiを拡張し、文を重みベクトルに変換するメソッドを追加
  2. 疎行列の効率的な保存と計算のため、scipyライブラリ(pipでインストール可能)を追加で使用
  3. BM25Retriever.from_textsで、コーパスの各文書のBM25重みベクトルを作成し、プロパティとして保存
  4. BM25Retriever._get_relevant_documentsで、クエリの単語頻度ベクトルとコーパス文書の重みベクトルのドット積としてBM25スコアを計算

評価

  • 100,000文書のコーパスで評価したところ、初期化は2-3倍遅くなったが、検索は約50倍高速化した。
    • 初期化はユーザーが使い始める前に完了できるため、検索速度の方がユーザー体験に影響を与える可能性が高い。したがって、多くのユースケースでこの改善は有用と考えられる。
  • 検索プロセスを変更したため、BM25スコアは変更前と完全には一致しなくなったが、その差は非常に小さかった。
    • 具体的には、上位100件のBM25スコアの平均値が30程度であるのに対し、従来の方法と改善後の方法でのBM25スコアの絶対差は10のマイナス16乗のオーダーであった。
  • 検索結果にも若干の影響があったが、実用上は無視できるレベルと思われる。
    • 具体的には、100個のクエリに対してBM25Retriever.invokeの上位100件の検索結果を従来の方法と改善後の方法で比較したところ、94個のクエリで上位100件の検索結果が完全に一致した。残りの6個のクエリのうち、1個は6位の結果に差異があり、他の5個はすべて50位以降の結果に差異があった。したがって、上位の検索結果を重視するRAGなどのアプリケーションでは、この差異の影響は最小限であると考えられる。

実装

langchainのbm25.pyをベースにいくつ可能機能を追加します。

rank-bm25のBM25Okapiのサブクラスを作り、以下の機能を新たに実装します。

  • コンストラクタをオーバーライドし、単語ID辞書を作成。重み行列の行番号と単語の対応をとるため。
  • transformメソッドの追加。疎行列を前提とすることで計算を効率化する。
  • count_transformメソッドの追加。BM25スコアはクエリの単語頻度ベクトルと、コーパスの重み行列の内積によって求められるため、行番号と単語が対応した単語頻度ベクトルの算出用メソッドが必要。

なお、langchainの既存の書かれ方に合わせて、ライブラリを関数の中でimport(いわゆるlazy import)するようにしています。このため、関数の中で動的にクラスを作成するというややこしい実装になっています。lazy importを諦めれば関数の外で素直にクラス定義できるのですが、どっちがいいのかはよくわかりません。

def create_bm25_vectorizer(corpus, **bm25_params):
    try:
        from rank_bm25 import BM25Okapi
    except ImportError:
        raise ImportError(
            "Could not import rank_bm25, please install with `pip install "
            "rank_bm25`."
        )
    
    class BM25Vectorizer(BM25Okapi):
        def __init__(self, corpus, **bm25_params):            
            super().__init__(corpus, **bm25_params)
            self.vocabulary = list(self.idf.keys())
            self.word_to_id = {word: i for i, word in enumerate(self.vocabulary)}
        
        def transform(self, queries: list[list[str]]) -> scipy.sparse.csr_matrix:
            try:
                from scipy.sparse import csr_matrix
            except ImportError:
                raise ImportError(
                    "Could not import scipy, please install with `pip install "
                    "scipy`."
                )
            
            rows = []
            cols = []
            data = []
            
            for i, query in enumerate(queries):
                query_len = len(query)
                
                for word in set(query):
                    if word in self.word_to_id:
                        word_id = self.word_to_id[word]
                        tf = query.count(word)
                        idf = self.idf.get(word, 0)
                        
                        # BM25 scoring formula
                        numerator = idf * tf * (self.k1 + 1)
                        denominator = tf + self.k1 * (1 - self.b + self.b * query_len / self.avgdl)
                        
                        score = numerator / denominator
                        
                        rows.append(i)
                        cols.append(word_id)
                        data.append(score)
            
            return csr_matrix((data, (rows, cols)), shape=(len(queries), len(self.vocabulary)))
        
        def count_transform(self, queries: list[list[str]]) -> scipy.sparse.csr_matrix:
            try:
                from scipy.sparse import csr_matrix
            except ImportError:
                raise ImportError(
                    "Could not import scipy, please install with `pip install "
                    "scipy`."
                )
            
            rows = []
            cols = []
            data = []
            
            for i, query in enumerate(queries):
                for word in query:
                    if word in self.word_to_id:
                        word_id = self.word_to_id[word]
                        rows.append(i)
                        cols.append(word_id)
                        data.append(1)  # Count is always 1 for each occurrence
            
            return csr_matrix((data, (rows, cols)), shape=(len(queries), len(self.vocabulary)))
    
    return BM25Vectorizer(corpus, **bm25_params)

改良したBM25Retrieverクラスを定義します。

  • from_texts
    • create_bm25_vectorizerを呼び出す
    • bm25_arrayを作成
  • _get_relevant_documents
    • クエリを単語頻度ベクトルに変換
    • bm25_arrayと単語頻度ベクトルの内積を計算
class BM25Retriever(BaseRetriever):
    """`BM25` retriever without Elasticsearch."""

    vectorizer: Any = None
    """ BM25 vectorizer."""
    docs: List[Document] = Field(repr=False)
    """ List of documents."""
    bm25_array: Any = None
    """BM25 array."""
    k: int = 4
    """ Number of documents to return."""
    preprocess_func: Callable[[str], List[str]] = default_preprocessing_func
    """ Preprocessing function to use on the text before BM25 vectorization."""

    model_config = ConfigDict(
        arbitrary_types_allowed=True,
    )

    @classmethod
    def from_texts(
        cls,
        texts: Iterable[str],
        metadatas: Optional[Iterable[dict]] = None,
        bm25_params: Optional[Dict[str, Any]] = None,
        preprocess_func: Callable[[str], List[str]] = default_preprocessing_func,
        **kwargs: Any,
    ) -> BM25Retriever:
        """
        Create a BM25Retriever from a list of texts.
        Args:
            texts: A list of texts to vectorize.
            metadatas: A list of metadata dicts to associate with each text.
            bm25_params: Parameters to pass to the BM25 vectorizer.
            preprocess_func: A function to preprocess each text before vectorization.
            **kwargs: Any other arguments to pass to the retriever.

        Returns:
            A BM25Retriever instance.
        """
        texts_processed = [preprocess_func(t) for t in texts]
        bm25_params = bm25_params or {}
        vectorizer = create_bm25_vectorizer(texts_processed, **bm25_params)
        bm25_array = vectorizer.transform(texts_processed)
        metadatas = metadatas or ({} for _ in texts)
        docs = [Document(page_content=t, metadata=m) for t, m in zip(texts, metadatas)]
        return cls(
            vectorizer=vectorizer, docs=docs, bm25_array=bm25_array, preprocess_func=preprocess_func, **kwargs
        )

    @classmethod
    def from_documents(
        cls,
        documents: Iterable[Document],
        *,
        bm25_params: Optional[Dict[str, Any]] = None,
        preprocess_func: Callable[[str], List[str]] = default_preprocessing_func,
        **kwargs: Any,
    ) -> BM25Retriever:
        """
        Create a BM25Retriever from a list of Documents.
        Args:
            documents: A list of Documents to vectorize.
            bm25_params: Parameters to pass to the BM25 vectorizer.
            preprocess_func: A function to preprocess each text before vectorization.
            **kwargs: Any other arguments to pass to the retriever.

        Returns:
            A BM25Retriever instance.
        """
        texts, metadatas = zip(*((d.page_content, d.metadata) for d in documents))
        return cls.from_texts(
            texts=texts,
            bm25_params=bm25_params,
            metadatas=metadatas,
            preprocess_func=preprocess_func,
            **kwargs,
        )

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        processed_query = self.preprocess_func(query)
        query_vec = self.vectorizer.count_transform([processed_query])
        results = query_vec.dot(self.bm25_array.T).toarray()[0]
        return_docs = [self.docs[i] for i in results.argsort()[-self.k :][::-1]]
        return return_docs

評価と考察

リポジトリ以外に記載する内容はないので省略します。

1
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?