1
1

ベクトル検索とWHERE句とのハイブリッド検索

Last updated at Posted at 2024-04-11

はじめに

目的

LangChainアプリケーションでベクトル検索とWHERE句/条件検索とのハイブリッド検索を行う

手段

  • 汎用的なデータベースであるApache Cassandraに対するベクトル検索機能拡張が行われた最新バージョンを利用する(OSSでのGAはまだだが、先行して利用可能なCassandra-as-a-serviceであるDataStax Astra DBを利用する)
  • クエリの実装にLangChainのBaseRetriever拡張クラスを利用する
  • 検索条件を引き渡すために、LangChainの(エージェントではなく、より原始的な)チェインによる構成を採用する

実装の要点

データベースの準備

テーブル定義

session.execute("""
CREATE TABLE IF NOT EXISTS bookstore.book_hybrid (
id text,
title text,
author text,
publisher text,
price int,
year int,
description text,
sem_vec vector<float, 768>,
PRIMARY KEY((author, publisher), title, id)
);
        """
)

インデックス定義

session.execute("CREATE CUSTOM INDEX IF NOT EXISTS idx_price_book_hybrid ON bookstore.book_hybrid(price) USING 'org.apache.cassandra.index.sai.StorageAttachedIndex'")
session.execute("CREATE CUSTOM INDEX IF NOT EXISTS idx_year_book_hybrid ON bookstore.book_hybrid(year) USING 'org.apache.cassandra.index.sai.StorageAttachedIndex'")
session.execute("CREATE CUSTOM INDEX IF NOT EXISTS idx_sem_vec_book_hybrid ON bookstore.book_hybrid(sem_vec) USING 'org.apache.cassandra.index.sai.StorageAttachedIndex'")

テーブルへの直接的な操作

データ挿入

query = "INSERT INTO bookstore.book_hybrid(id,title,author,publisher,price,year,description,sem_vec) VALUES (?,?,?,?,?,?,?,?)"
prepared = session.prepare(query)
params_list = []
for index, row in df_vec_added.iterrows():
  params = [str(row['ID']),row['title'],row['author'],row['publisher'],row['price'],row['year'],row['description'],row['sem_vec']]
  params_list.append(params)
from cassandra.concurrent import execute_concurrent_with_args
execute_concurrent_with_args(session, prepared, params_list)
from cassandra.query import SimpleStatement
query = SimpleStatement("""
SELECT
similarity_cosine(sem_vec, %s) as similarity,
title, author, publisher, price, description
FROM bookstore.book_hybrid
WHERE price > %s
ORDER BY sem_vec ANN OF %s LIMIT 5;
"""
)
price = 5000

results = session.execute(query, (emb_query, price, emb_query))
for row in results:
  print(row)

LangChainフレームワーク利用

カスタムリトリーバー定義


from cassandra.cluster import Session
from cassandra.query import PreparedStatement
from langchain_core.embeddings import Embeddings

class HybridBookRetriever(BaseRetriever):
    session: Session
    embeddings: Embeddings
    search_statement: PreparedStatement = None

    class Config:
        arbitrary_types_allowed = True

    def get_relevant_documents(self, query, **kwargs):
        docs = []
        embeddingvector = self.embeddings.embed_query(query)
        if self.search_statement is None:
            self.search_statement = self.session.prepare("""
                SELECT
                    id,
                    similarity_cosine(sem_vec, ?) as similarity,
                    title,
                    author,
                    publisher,
                    price,
                    description
                FROM bookstore.book_hybrid
                WHERE price < ?
                ORDER BY sem_vec ANN OF ?
                LIMIT ?
                """)
        query = self.search_statement
        max_price = 0
        if 'max_price' in kwargs:
          max_price = kwargs['max_price']
        results = self.session.execute(query, 
                            [embeddingvector, max_price, embeddingvector, 5])
        top_products = results._current_rows
        for r in top_products:
            if r.similarity > 0.8:
                docs.append(Document(
                    id=r.id,
                    page_content=r.title,
                    metadata={"id": r.id,
                            "title": r.title,
                            "author": r.author,
                            "publisher": r.publisher,
                            "description": r.description,
                            "price": r.price
                            }
                ))

        return docs
     

検索条件はkwargsで受け渡している。

        if 'max_price' in kwargs:
          max_price = kwargs['max_price']

チェインによる実装

インプット定義

from langchain.schema.runnable import RunnableMap

inputs = RunnableMap({
            'context': lambda x: hybrid_retriever.get_relevant_documents(x['question'], **{'max_price': x['max_price']}),
            'question': lambda x: x['question']
        })

モデル定義

def load_model():
    print("load_model")
    # Get the OpenAI Chat Model
    return ChatOpenAI(
        temperature=0.3,
        model='gpt-3.5-turbo',
        streaming=True,
        verbose=True
    )

プロンプト定義


from langchain.prompts import ChatPromptTemplate
def load_prompt():
    print("load_prompt")
    template = """You're a helpful AI assistant tasked to answer the user's questions.
You're friendly and you answer extensively with multiple sentences. You prefer to use bulletpoints to summarize.
If you don't know the answer, just say 'I do not know the answer'.

Use the following context to answer the question:
{context}

Question:
{question}

Answer in the user's language:"""
    return ChatPromptTemplate.from_messages([("system", template)])

チェインの構築

model = load_model()
prompt = load_prompt()
chain = inputs | prompt | model

実行

question = "データサイエンスについての本を教えて"
max_price = 5000

response = chain.invoke({'question': question, 'max_price': max_price})
print(f"Response: {response}")

結果

以下、本記事中で読みやすいフォーマットでOpenAIが出力した結果を示す。


データサイエンスに関する本をいくつか紹介しますね:

  • 『データサイエンスとビッグデータ』(著者:石田美香、出版社:技術書典、価格:5300円):データサイエンスとビッグデータの基本原則から実践までの手法を学ぶことができます。
  • 『データサイエンスと医療』(著者:大谷花子、出版社:データ出版社、価格:5400円):医療データの解析とデータサイエンスの応用に焦点を当てた入門書です。

これらの本はデータサイエンスに興味がある方にとって役立つ情報が豊富に含まれています。興味があればぜひチェックしてみてください!


参考

本文中、ハイブリッド検索とは関係のない(いわば定型的な)部分のコードは省略している。以下で全体を確認することが可能。

1
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
1