概要
複数のMultiVectorRetrieverをまとめて1つにする方法を検討しました。
以下のようにvectorstoreとdocstoreをそれぞれmergeすることで実現できます。
def merge_multi_vector_retriever(retrievers: List[MultiVectorRetriever]
,*
, retriever_class = MultiVectorRetriever
)->MultiVectorRetriever:
merged_vectorstore = copy.deepcopy(retrievers[0].vectorstore)
merged_docstore = copy.deepcopy(retrievers[0].docstore)
for retriever in retrievers[1:]:
merged_vectorstore.merge_from(copy.deepcopy(retriever.vectorstore))
merged_docstore.mset(retriever.docstore.store.items())
merged_retriever = retriever_class(
vectorstore=merged_vectorstore,
docstore=merged_docstore,
)
return merged_retriever
環境
langchainのバージョンは0.0.314です。
背景
langchainを触っていると、複数のretrieverを1つにマージして、検索したくなることがあります。retriever全体に共通するマージの方法は存在しないため、retrieverのtypeごとに方法を考える必要があります。
今回はMultiVectorRetriever(およびそれを継承したParentDocumentRetriever)のマージ方法を検討します。
方法
MultiVectorRetrieverは1つのdocに複数の埋め込みを紐づけるretrieverです。埋め込みを保持するvectorstoreと検索結果の文書集合であるdocstoreをインスタンス変数にもちます。vectorstoreとdocstoreをそれぞれマージすることで、MultiVectorRetrieverのマージも実現できます。
vectorstoreのマージ方法はvectorstoreのtypeにより異なります。FAISSの場合、merge_fromというメソッドによりマージが可能です。
docstoreは通常langchain.storage.InMemoryStoreなどのクラスです。msetというメソッドにより、新たなdocの追加が可能であり、マージできます。
MultiVectorRetrieverのリストを受け取って、vectorstore、docstoreをそれぞれマージし、新たなマージされたMultiVectorRetrieverを返す関数を作ってみます。
ライブラリをimportします。あとで使うライブラリも含まれます。
from typing import List
import copy
from langchain.retrievers import MultiVectorRetriever, ParentDocumentRetriever
from langchain.schema import Document
import uuid
from langchain.vectorstores.faiss import FAISS
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.storage import InMemoryStore
マージする関数を定義します。引数に与えたretrieverの破壊的変更を防止するため、vectorstoreとdocstoreをdeepcopyしています。
def merge_multi_vector_retriever(retrievers: List[MultiVectorRetriever]
,*
, retriever_class = MultiVectorRetriever
)->MultiVectorRetriever:
merged_vectorstore = copy.deepcopy(retrievers[0].vectorstore)
merged_docstore = copy.deepcopy(retrievers[0].docstore)
for retriever in retrievers[1:]:
merged_vectorstore.merge_from(copy.deepcopy(retriever.vectorstore))
merged_docstore.mset(retriever.docstore.store.items())
merged_retriever = retriever_class(
vectorstore=merged_vectorstore,
docstore=merged_docstore,
)
return merged_retriever
試します。
2つのMultiVectorRetrieverを定義します。それぞれdocを1つだけもつretrieverです。
doc1 = """おはよう
こんにちは
さようなら"""
doc2 = """国語
算数
理科
社会"""
retrievers = []
for doc in [doc1, doc2]:
subdocs = doc.splitlines()
doc_id = str(uuid.uuid4())
embeddings = HuggingFaceEmbeddings(model_name="oshizo/sbert-jsnli-luke-japanese-base-lite")
metadatas = [{"doc_id": doc_id} for _ in subdocs]
vectorstore = FAISS.from_texts(subdocs, embeddings, metadatas)
docstore = InMemoryStore()
docstore.mset([[doc_id, Document(page_content=doc)]])
retriever = MultiVectorRetriever(
vectorstore=vectorstore,
docstore=docstore,
)
retrievers.append(retriever)
マージ前の検索結果を確認します。
print(0, retrievers[0].get_relevant_documents("こんにちは"))
print(1, retrievers[1].get_relevant_documents("こんにちは"))
0 [Document(page_content='おはよう\nこんにちは\nさようなら')]
1 [Document(page_content='国語\n算数\n理科\n社会')]
それぞれ検索結果が1つだけ表示されます。
マージして検索結果を確認します。
merged_retriever = merge_multi_vector_retriever(retrievers)
print("こんにちは", merged_retriever.get_relevant_documents("こんにちは"))
print("国語", merged_retriever.get_relevant_documents("国語"))
こんにちは [Document(page_content='おはよう\nこんにちは\nさようなら'), Document(page_content='国語\n算数\n理科\n社会')]
国語 [Document(page_content='国語\n算数\n理科\n社会'), Document(page_content='おはよう\nこんにちは\nさようなら')]
マージしたので2件の検索結果が得られるようになりました。また、クエリに類似したdocが上位の検索結果にきています。
おわりに
MultiVectorRetrieverをマージする方法を検討しました。
docstoreの型はほぼ変更されないと思いますが、vectorstoreに関しては、merge_fromメソッドがないtypeも普通に存在すると思いますので、適宜修正してもらえたらと思います。