LoginSignup
8
4

【langchain】ベクトルストアの保存と読込方法を整理する(save_local vs pickle)

Last updated at Posted at 2023-10-03

概要

langchainで、ベクトルストアを保存するとき、save_localを使う方がいいのか、pickleでまとめて保存する方がいいのかを考えてみました。
結論としては、公式が提供しているsave_local、load_localを素直に使うのが良さそうです。

langchainのバージョンは0.0.304です。

save_localとload_local

公式にサポートされている王道な方法です。

from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS

embedding = HuggingFaceEmbeddings(model_name = "oshizo/sbert-jsnli-luke-japanese-base-lite")
vectorstore = FAISS.from_texts(
    ["こんにちは","こんばんは","さようなら"]
    , embedding
)
# save
vectorstore.save_local("./vectorstore")

save_localを実行すると、第一引数に指定した名前のディレクトリが作成され、その中に、index.faissとindex.pklが保存されます。ソースコードを読むと、index.faissは埋め込みベクトルの情報、index.pklはdocstoreと埋め込みベクトルとdocstoreのidの対応を管理する情報(index_to_docstore_id)が保存されていることがわかります。

読込時は、load_localを使います。embeddingに使うクラスを初期化して、save_localしたディレクトリ名とともに与える必要があります。

from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS

embedding = HuggingFaceEmbeddings(model_name = "oshizo/sbert-jsnli-luke-japanese-base-lite")

# load
vectorstore = FAISS.load_local("./vectorstore", embedding)

ソースコードを読むとload_localは以下のように、index.faiss、docstore、index_to_docstore_id、埋め込み作成関数(embedding.embed_query)で初期化したベクトルストアクラスを返す関数であることがわかります。

return cls(
            embeddings.embed_query, index, docstore, index_to_docstore_id, **kwargs
        )

長所

公式で提供されているAPIのため安心感があります。特別な理由がなければこれを採用するのが良いと思います。

短所

読込時にembeddingのクラスを与える必要があるため、どのクラスのembeddingを作っていたか覚えておく必要があります。

まとめてpickle

save_localを使わずにまとめてpickleする方法もあります。

# 保存時
import pickle
with open("vectorstore.pkl", "rb") as f:
  pickle.dump(vectorstore, f)
# 読み込み時
import pickle
with open("vectorstore.pkl", "wb") as f:
  vectorstore = pickle.load(f)

長所

読込時にembeddingもまとめて読み込むことができ、個別の情報管理が不要です。

短所

保存時と読込時のlangchainのバージョンが違うと動かなくなりがちです。

試しに0.0.277で保存して、0.0.304で読込して、使ってみます。埋め込みはOpenAIEmbeddingsでやってみます。

pip install langchain==0.0.277
from langchain.embeddings import OpenAIEmbeddings
from langchain.vectorstores import FAISS
import pickle

embedding = OpenAIEmbeddings()
vectorstore = FAISS.from_texts(
    ["こんにちは","こんばんは","さようなら"]
    , embedding
)
# save
with open("vectorstore.pkl", "wb") as f:
  pickle.dump(vectorstore, f)
pip install langchain==0.0.304
import pickle
with open("vectorstore.pkl", "rb") as f:
  vectorstore = pickle.load(f)

vectorstore.similarity_search("こんにちは")
...
AttributeError: 'OpenAIEmbeddings' object has no attribute 'skip_empty'

OpenAIEmbeddingsにskip_emptyという属性がないというエラーが出ました。
おそらく0.0.304の間にOpenAIEmbeddingsにskip_emptyという属性が追加されているのですが、0.0.277の時点ではそれがなかったため、0.0.277でpickleしたものを0.0.304でloadするとエラーが起こったと考えられます。
load_localを使った場合は、embeddingをload時に初期化して与えるため、このエラーは起こりません。もちろんload_localも完璧ではありませんが、pickleを使うこの方法の方が、情報をまるごと保存するため、langchainのバージョンアップの影響は受けやすいです。

もうひとつ、割と致命的な問題があります。それはAPI_KEYが容易に読める形式で一緒に保存されてしまうことです。
読み込んだvectorstoreのうち、embedding_functionをprintすると以下のような感じでOpenAIEmbeddingの属性情報が一覧でき、下記では省略していますが、この中にopenai_api_keyも直書きされています。

print(vectorstore.embedding_function)
<bound method OpenAIEmbeddings.embed_query of OpenAIEmbeddings(client=<class 'openai.api_resources.embedding.Embedding'>, model='text-embedding-ada-002', ...

したがって、少なくともOpenAIEmbeddingsを使っている、かつ、他人と共有する可能性が少しでもあるときは、pickleではなくsave_localを使った方が良いと思います。

おまけ:save/load_localをラップ

save_localとload_localの短所、つまり、embeddingsの種類を記憶しておかなければならないという問題に対処するため、embeddingsの情報も合わせて保存/読込できる、save_localとload_localのラッパー関数を作ってみます。

ライブラリをimportして定数を定義します。

from langchain.vectorstores.faiss import FAISS
from langchain.schema.embeddings import Embeddings
from langchain.schema.vectorstore import VectorStore
import langchain.embeddings
from pydantic import BaseModel, field_validator
import json
from pathlib import Path
from typing import List, Dict, Any, Tuple

class Constants:
    DEFAULT_INDEX_NAME: str = "index" # VectorStore.save_localに与える引数の初期値
    DEFAULT_CONFIG_NAME: str = "embedding_config" # VectorStore.save_localに与える引数の初期値
    # Embeddingsの保存・読込に必要なkeyのリスト
    EMBEDDING_CONSTRUCT_KEYS: Dict[str, List[str]] = {
        "OpenAIEmbeddings": ["model"]
        , "HuggingFaceEmbeddings": ["model_name"]
        , "DEFAULT": ["model_name"] 
    }

embeddingsの設定を保存する関数を作ります。

class EmbeddingConfig(BaseModel):
    embedding_name: str
    options: Dict[str, Any] # Embeddingsのコンストラクタに与える引数の辞書

    @field_validator('embedding_name')
    def validate_embedding_type(cls, value):
        if not value.endswith("Embeddings") or not hasattr(langchain.embeddings, value):
            raise ValueError("Embedding name is invalid")
        return value
    

def save_embedding_config(embeddings: Embeddings, output_path: str):
    embedding_name = embeddings.__class__.__name__
    options = {}

    for key in Constants.EMBEDDING_CONSTRUCT_KEYS.get(embedding_name, Constants.EMBEDDING_CONSTRUCT_KEYS["DEFAULT"]):
        if getattr(embeddings, key):
            options[key] = getattr(embeddings, key)
    embedding_config = EmbeddingConfig(embedding_name=embedding_name, options=options)
    with open(output_path, "w") as f:
        json.dump(embedding_config.model_dump(), f)

def load_embedding_config(input_path: str) -> Embeddings:
    with open(input_path) as f:
        loaded_data = json.load(f)
    embedding_config = EmbeddingConfig(**loaded_data)
    embedding = getattr(langchain.embeddings, embedding_config.embedding_name)(**embedding_config.options)
    return embedding

保存するoptions(初期化時の引数)をConstants.EMBEDDING_CONSTRUCT_KEYSで制限することで、API_KEYなど公開したくない情報が保存されないようにしています。
保存・読込時にはembeddings.__class__.__name__やgetattrを使って、クラスから直接情報を取得することで、embeddingsのtypeによらずある程度汎用的に使えるように工夫しています。

最後にVectorStoreをsave、loadする関数を定義します。

def save_vectorstore(vectorstore: VectorStore, embeddings: Embeddings, dirpath: str
                     , *
                     , index_name = Constants.DEFAULT_INDEX_NAME
                     , config_name = Constants.DEFAULT_CONFIG_NAME) -> None:
    if not hasattr(vectorstore, "save_local"):
        raise ValueError("vectorstore must have save_local method")
    vectorstore.save_local(Path(dirpath), index_name)
    save_embedding_config(embeddings, Path(dirpath).joinpath(f"{Path(config_name).stem}.json"))

def load_vectorstore(dirpath: str, *
               , index_name = Constants.DEFAULT_INDEX_NAME
               , config_name=Constants.DEFAULT_CONFIG_NAME
               , vectorstore_class = FAISS
               , return_only_vectorstore = False
               )->Tuple[VectorStore, Embeddings]:
    
    embeddings = load_embedding_config(Path(dirpath).joinpath(f"{Path(config_name).stem}.json"))
    if not hasattr(vectorstore_class, "load_local"):
        raise ValueError("vectorstore_class must have load_local method")
    vectorstore = vectorstore_class.load_local(dirpath, embeddings, index_name)
    if return_only_vectorstore:
        return vectorstore
    else:
        return vectorstore, embeddings

vectorstoreは基本的にはFAISSクラスを想定しています。少なくともsave_local、load_localメソッドを持つことを前提としています。このあたりは完全な汎用的にはできなかったので、vectorstoreのクラスはある程度固定する運用を前提として、必要に応じて書き換える必要があります。

vectorstore自体にはEmbeddingsの情報は保存されていないため、save_vectorstoreにはembeddingsを一緒に与える必要があります。VectorStore.embeddingsやVectorStore.embedding_funcitonという属性はあるのですが、いずれも元のEmbeddingsクラスを推定するには不十分な情報でした。

load_vectorstoreの戻り値は、VectorStoreのみであるほうが自然に思えるのですが、上述のsave_vectorstoreでEmbeddingsを与える必要がある関係上、Embeddingsも一緒にloadできるほうが汎用性が高いと考えて、embeddingsも返すようにしています。一応、不要な場合には、return_only_vectorstore = Trueを指定することで、vectorstoreだけを返すようにしてあります。

長所

save時にEmbeddingsの情報も合わせて辞書で保存することで、load時にembeddingsを指定する必要がなくなっています。またpickleと違い、api_keyの保存は回避できています。要は、embeddingsが特定できれば良いだけなので、pydantic.BaseModelまで使うのは大袈裟かもしれず、この実装はあくまで一例です。

短所

save時にEmbeddingsも合わせて与える必要があるのはpickleと比べて面倒です。また関数が全体的に複雑です。loadでembeddingsを返す実装も直感的ではないです。

理想はsave時にvectorstoreのみを引数として与えて保存できることなのですが、api_keyが保存されることを回避しつつそれを実現することは、なかなか難しかったです。

補足: EmbeddingConfingの改良ポイント

現状は、pydantic.BaseModelを継承したEmbeddingConfigというクラスでembeddingの設定値を管理しており、embedding_nameがlangchain.embeddingsに存在しない名前の場合は、エラーとなるようにfield_validatorで設定しています。使うEmbeddingsのtypeが固定されているならば、より限定的に書いたほうが、管理しやすいと思います。汎用性とのトレードオフです。例えば以下のような感じです。

# OpenAIEmbeddings以外のembedding_nameに対し、エラーとしたい場合
class EmbeddingConfig(BaseModel):
    embedding_name: str
    options: Dict[str, Any] # Embeddingsのコンストラクタに与える引数の辞書

    @field_validator('embedding_name')
    def validate_embedding_type(cls, value):
        if value not in ["OpenAIEmbeddings"]:
            raise ValueError("Embedding name is invalid")
        return value
8
4
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
8
4