8
3

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

langchainとDatabricksで(私が)学ぶRAG : MultiQueryRetrieverを使ったRAG

Posted at

導入

私が学ぶRAGの実質4回目です。シリーズ一覧はこちら
今回はMultQueryRetrieverを使ったRAGです。

langchainの公式Docは以下です。

これは何?

上の公式Docより、ざっくり和訳。

距離ベースのベクトルデータベース検索は、クエリを高次元空間に埋め込み(表現し)、「距離」に基づいて類似の埋め込み文書を見つけます。しかし、クエリの表現が微妙に変わったり、埋め込みがデータのセマンティクスをうまく捉えていなかったりすると、検索結果が異なることがあります。このような問題に手動で対処するために、プロンプトエンジニアリングやチューニングが行われることもありますが、面倒な作業となります。

MultiQueryRetrieverは、 LLMを用いて、与えられたユーザ入力クエリに対して異なる視点から複数のクエリを生成することで、 プロンプトチューニングのプロセスを自動化します。各クエリに対して、関連するドキュメントのセットを取得し、すべてのクエリにまたがるユニークな結合を取ることで、より大きな関連する可能性のあるドキュメントのセットを取得します。同じ質問に対する複数の視点を生成することで、MultiQueryRetrieverは距離ベースの検索の制限のいくつかを克服し、より豊かな結果セットを得ることができるかもしれません。

というわけで、問い合わせからそのまま関連文書を取得するのではなく、問い合わせを使って似た表現の問い合わせを複数生成し、それぞれで文書を取得、得られた結果全てをマージして利用する、ということを行います。
これによって、表現の揺れによる検索結果の違いを抑制し、安定的な結果を得ることができるようになります。

というわけで、やってみましょう。

DatabricksのDBRは14.1 ML、GPUクラスタで動作を確認しています。

Step0. モジュールインストール

今後、使うモジュールをインストールします。

%pip install -U -qq transformers accelerate langchain faiss-cpu
%pip install https://github.com/casper-hansen/AutoAWQ/releases/download/v0.1.7/autoawq-0.1.7+cu118-cp310-cp310-linux_x86_64.whl
%pip install "databricks-feature-engineering" 

dbutils.library.restartPython()

AutoAWQはCUDA 11.8版をWheelを指定してインストールしています。
(DatabricksのDBR ML 14台はCUDAのバージョンが11.8のため)

Step1. Document Loading

準備編で保管した特徴量を取得します。
読み込んだデータは、docsという名前のビューから参照できるようにします。

from databricks.feature_engineering import FeatureEngineeringClient

fe = FeatureEngineeringClient()

feature_name = "training.llm.sample_doc_features"
df = fe.read_table(name=feature_name)

df.createOrReplaceTempView("docs")

Step2. Splitting

ごくごく単純なText Splitterを使って長文データをチャンキングします。
ベーシックなRAGのときと同じです。

from typing import Any
import pandas as pd
from pyspark.sql.functions import pandas_udf
from langchain.text_splitter import RecursiveCharacterTextSplitter


class JapaneseCharacterTextSplitter(RecursiveCharacterTextSplitter):
    """句読点も句切り文字に含めるようにするためのスプリッタ"""

    def __init__(self, **kwargs: Any):
        separators = ["\n\n", "\n", "", "", " ", ""]
        super().__init__(separators=separators, **kwargs)


@pandas_udf("array<string>")
def split_text(texts: pd.Series) -> pd.Series:

    # 適当なサイズとオーバーラップでチャンク分割する
    text_splitter = JapaneseCharacterTextSplitter(chunk_size=200, chunk_overlap=40)
    return texts.map(lambda x: text_splitter.split_text(x))


# チャンキング
df = spark.table("docs")
df = df.withColumn("chunk", split_text("page_content"))

# Pandas Dataframeに変換し、チャンクのリストデータを取得
pdf = df.select("chunk").toPandas()
texts = list(pdf["chunk"][0])

print(len(texts))
print(texts)

Step3. Storage

チャンクデータに埋め込み(Embedding)を行い、ベクトルストアへデータを保管します。

まずはEmbedding用のモデルをロード。
以下のモデルをダウンロードしたものを利用します。

import torch
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS

device = "cuda" if torch.cuda.is_available() else "cpu"
embedding_path = "/Volumes/training/llm/model_snapshots/models--intfloat--multilingual-e5-large"

embedding = HuggingFaceEmbeddings(
    model_name=embedding_path,
    model_kwargs={"device": device},
)

FAISSでベクトルストアを作成します。

vectorstore = FAISS.from_texts(texts, embedding)

Step4. LLM Preparation for Retriever

MultiQueryRetrieverで複数クエリを生成するためのLLMをロードします。
今回も、前回同様、以下のモデルを利用しました。

事前にダウンロードしてあるものを利用します。

from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers_chat import ChatHuggingFaceModel

model_path = "/Volumes/training/llm/model_snapshots/models--TheBloke--openchat_3.5-AWQ"

generator = AutoModelForCausalLM.from_pretrained(model_path, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(model_path)
retriever_llm = ChatHuggingFaceModel(
    generator=generator,
    tokenizer=tokenizer,
    human_message_template="GPT4 Correct User: {}<|end_of_turn|>GPT4 Correct Assistant: ",
    repetition_penalty=1.2,
    temperature=0.1,
    max_new_tokens=1024,
)

Step5A. MultiQueryRetriever preparation

今回のポイントです。

MultiQueryRetrieverを作成します。

from langchain.retrievers.multi_query import MultiQueryRetriever

retriever_from_llm = MultiQueryRetriever.from_llm(
    retriever=vectorstore.as_retriever(), llm=retriever_llm
)

どのようなクエリを生成したかを確認するために、loggerを設定します。

# Set logging for the queries
import logging

logging.basicConfig()
logging.getLogger("langchain.retrievers.multi_query").setLevel(logging.INFO)

それでは動作確認してみましょう。

# Test

question = "この契約において知的財産権はどのような扱いなのか?"
unique_docs = retriever_from_llm.get_relevant_documents(query=question)
len(unique_docs)
出力
INFO:langchain.retrievers.multi_query:Generated queries: ['1. この契約での知的財産権の管理方法はどのようなものですか?', '2. この契約において知的財産権の利用権はどのように決定されますか?', '3. この契約による知的財産権の保護方法はどのようなものですか?']
5

オリジナルの問い合わせに対して、若干異なる表現のクエリが3件生成され、それを使って5件の類似文書を取得できました。

Step5B. MultiQueryRetriever preparation with Custom Prompt

複数のクエリを生成する際に、オリジナルのプロンプトを指定して作成することもできます。

まずは、独自プロンプトを使うLLMChainを作成します。
ここは公式Doc通りに、類似クエリを5件生成する内容にしました。

from typing import List
from langchain.chains import LLMChain
from pydantic import BaseModel, Field
from langchain.prompts import PromptTemplate
from langchain.output_parsers import PydanticOutputParser


# Output parser will split the LLM result into a list of queries
class LineList(BaseModel):
    # "lines" is the key (attribute name) of the parsed output
    lines: List[str] = Field(description="Lines of text")


class LineListOutputParser(PydanticOutputParser):
    def __init__(self) -> None:
        super().__init__(pydantic_object=LineList)

    def parse(self, text: str) -> LineList:
        lines = text.strip().split("\n")
        return LineList(lines=lines)


output_parser = LineListOutputParser()

QUERY_PROMPT = PromptTemplate(
    input_variables=["question"],
    template="""You are an AI language model assistant. Your task is to generate five 
    different versions of the given user question to retrieve relevant documents from a vector 
    database. By generating multiple perspectives on the user question, your goal is to help
    the user overcome some of the limitations of the distance-based similarity search. 
    Provide these alternative questions separated by newlines.
    Original question: {question}""",
)

# Chain
llm_chain = LLMChain(llm=retriever_llm, prompt=QUERY_PROMPT, output_parser=output_parser)

作成したLLMChainを使ってMultiQueryRetrieverを作成します。

from langchain.retrievers.multi_query import MultiQueryRetriever

retriever_from_llm2 = MultiQueryRetriever(
    retriever=vectorstore.as_retriever(), llm_chain=llm_chain, parser_key="lines"
)

# Test
question = "この契約において知的財産権はどのような扱いなのか?"
unique_docs = retriever_from_llm2.get_relevant_documents(query=question)
len(unique_docs)
出力
INFO:langchain.retrievers.multi_query:Generated queries: ['1. この契約での知的財産権の管理方法はどのようなものですか?', '2. この契約において知的財産権についての規定はどのようなものがありますか?', '3. この契約によって知的財産権に関する制度はどのようなものですか?', '4. この契約に基づく知的財産権の利用方法はどのようなものですか?', '5. この契約による知的財産権の保護はどのような手順がありますか?']
8

指定したプロンプト通り、5件のクエリが生成されました。

では、このRetrieverを使って、実際にChainを組んで実行してみましょう。

Step6. Chain creation

まずは構成要素を準備していきます。

Prompt Template

簡単なチャットテンプレートを準備。

from langchain.prompts import ChatPromptTemplate
from langchain.prompts.chat import (
    AIMessagePromptTemplate,
    HumanMessagePromptTemplate,
)

template = """次のcontextの内容のみを使い、なるべく平易な文章を使って日本語で質問に回答してください。
{context}

Question: {question}
"""
prompt = ChatPromptTemplate.from_messages(
    [
        HumanMessagePromptTemplate.from_template(template),
        AIMessagePromptTemplate.from_template(""),
    ]
)

LLM

LLM(ChatModel)はRetrieverで利用しているものと同じにしました

chat_model = ChatHuggingFaceModel(
    generator=generator,
    tokenizer=tokenizer,
    human_message_template="GPT4 Correct User: {}<|end_of_turn|>",
    ai_message_template="GPT4 Correct Assistant: {}",
    repetition_penalty=1.2,
    temperature=0.1,
    max_new_tokens=1024,
)

Chain

これまで作成した構成要素を組み合わせてChainを作成します。

from langchain.schema.output_parser import StrOutputParser
from langchain.schema.runnable import RunnablePassthrough

chain = (
   {"context": retriever_from_llm2, "question": RunnablePassthrough()} 
   | prompt 
   | chat_model 
   | StrOutputParser()
)

Step7. Run

準備が整いましたので、ストリーミング出力で実行してみます。

for s in chain.stream("この契約において知的財産権はどのような扱いなのか?"):
    print(s, end="", flush=True)
出力
INFO:langchain.retrievers.multi_query:Generated queries: ['1. この契約での知的財産権の管理方法はどのようなものですか?', '2. この契約において知的財産権についての規定はどのようなものがありますか?', '3. この契約によって知的財産権に関する制度はどのようなものですか?', '4. この契約に基づく知的財産権の利用方法はどのようなものですか?', '5. この契約による知的財産権の保護はどのような手順がありますか?']
この契約では、新規知的財産権は、納入物の引渡しと同時に甲から乙に譲渡され、甲単独に帰属する。また、乙は、本契約終了後であっても、知的財産権の取扱いに関する本契約の約定を自ら遵守し、再委託先に遵守させることを約束する。

クエリを5個生成し、そこから得られた類似文書を使って回答を生成できました。

まとめ

MultiQueryRetrieverを使ったRAGを実践してみました。
個人的に、生成される類似クエリで大きな意味の揺らぎが起きないように気を付ける必要があるかなと思いました。素人なので、正直よくわかっていないところがあります。。。

なお、langchain templatesでは以下などでサンプルを確認することができます。

次回はRAG-Fusionを実践する予定です。

8
3
1

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
8
3

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?