本年もよろしくお願いします。そういえばHyDEやっていませんでした。
導入
私が学ぶRAGの実質9回目です。シリーズ一覧はこちら。
今回はHyDEによるRAGの改善です。
これは何?
日本語だと以下の記事が詳しいのかなと思います。
大本の論文はこちら。
ざっくり概要を記載すると、HyDE(Hypothetical Document Embeddings)は、検索を強化するための手法です。
一般的にはRAGにおいては、検索クエリをEmbeddingを使ってベクトル変換しますが、HyDEではクエリに基づいて仮想的な応答を作成し、それをベクトル変換した結果を用いて検索をかけます。
今回は以下のLangchain Templatesのコードを基に実践してみます。
DatabricksのDBRは14.1 ML、GPUクラスタで動作を確認しています。
なお、以下Step3まではシリーズ内の以前の記事と同じ内容となります。
Step0. パッケージインストール
使うモジュールをインストールします。
%pip install -U -qq transformers accelerate langchain faiss-cpu
# CUDA 11.8用
%pip install https://github.com/casper-hansen/AutoAWQ/releases/download/v0.1.8/autoawq-0.1.8+cu118-cp310-cp310-linux_x86_64.whl
%pip install "databricks-feature-engineering"
dbutils.library.restartPython()
AutoAWQはCUDA 11.8版をWheelを指定してインストールしています。
(DatabricksのDBR ML 14台はCUDAのバージョンが11.8のため)
Step1. Document Loading
準備編で保管した特徴量を取得します。
読み込んだデータは、docs
という名前のビューから参照できるようにします。
from databricks.feature_engineering import FeatureEngineeringClient
fe = FeatureEngineeringClient()
feature_name = "training.llm.sample_doc_features"
df = fe.read_table(name=feature_name)
df.createOrReplaceTempView("docs")
Step2. Splitting
ごくごく単純なText Splitterを使って長文データをチャンキングします。
ベーシックなRAGのときと同じです。
from typing import Any
import pandas as pd
from pyspark.sql.functions import pandas_udf
from langchain.text_splitter import RecursiveCharacterTextSplitter
class JapaneseCharacterTextSplitter(RecursiveCharacterTextSplitter):
"""句読点も句切り文字に含めるようにするためのスプリッタ"""
def __init__(self, **kwargs: Any):
separators = ["\n\n", "\n", "。", "、", " ", ""]
super().__init__(separators=separators, **kwargs)
@pandas_udf("array<string>")
def split_text(texts: pd.Series) -> pd.Series:
# 適当なサイズとオーバーラップでチャンク分割する
text_splitter = JapaneseCharacterTextSplitter(chunk_size=200, chunk_overlap=40)
return texts.map(lambda x: text_splitter.split_text(x))
# チャンキング
df = spark.table("docs")
df = df.withColumn("chunk", split_text("page_content"))
# Pandas Dataframeに変換し、チャンクのリストデータを取得
pdf = df.select("chunk").toPandas()
texts = list(pdf["chunk"][0])
Step3. Storage
チャンクデータに埋め込み(Embedding)を行い、ベクトルストアへデータを保管します。
Embeddingは事前にダウンロードしておいた以下のモデルを利用します。
まずはEmbeddingモデルのロード。
import torch
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS
device = "cuda" if torch.cuda.is_available() else "cpu"
embedding_path = "/Volumes/training/llm/model_snapshots/models--intfloat--multilingual-e5-large"
embedding = HuggingFaceEmbeddings(
model_name=embedding_path,
model_kwargs={"device": device},
)
FAISSでベクトルストアを作成します。
vectorstore = FAISS.from_texts(texts, embedding)
Step4. LLM Preparation
LLMをロードします。
以下のモデルを事前にダウンロードして利用しました。
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers_chat import ChatHuggingFaceModel
model_path = "/Volumes/training/llm/model_snapshots/models--TheBloke--openchat-3.5-1210-AWQ"
generator = AutoModelForCausalLM.from_pretrained(model_path, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(model_path)
Step5. HyDE Chain
今回のポイントです。
HyDEによる応答生成のためのChainChainを作成します。
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts.chat import (
AIMessagePromptTemplate,
HumanMessagePromptTemplate,
)
hyde_template = """Please write a passage to answer the question
Question: {question}
Passage:"""
hyde_prompt = ChatPromptTemplate.from_messages(
[
HumanMessagePromptTemplate.from_template(hyde_template),
AIMessagePromptTemplate.from_template(""),
]
)
hyde_chat_model = ChatHuggingFaceModel(
generator=generator,
tokenizer=tokenizer,
human_message_template="GPT4 Correct User: {}<|end_of_turn|>",
ai_message_template="GPT4 Correct Assistant: {}",
repetition_penalty=1.2,
temperature=0.1,
max_new_tokens=400,
)
hyde_chain = hyde_prompt | hyde_chat_model | StrOutputParser()
コード中にあるように、プロンプトはシンプルにPlease write a passage to answer the question
という指示にしています。
ChainはRetrieverは含まれておらず、LLMが内包する知識のみで応答を作成させるようになっています。
どのような応答が作成されるか、動作確認してみましょう。
hyde_chain.invoke({"question": "この契約における契約保証金はどのような扱いなのか?"})
この契約における契約保証金は、契約の対象となる仕事が完了した後に返金されるものであり、契約の成立に伴う押さえ金として機能します。契約保証金は、契約の成立時に支払われ、契約の期間中には特定の条件に従って管理されます。契約が成立し、仕事が完了した後、契約保証金は契約の対象となる仕事が正常に行われたことを確認した上で、返金されることができます。ただし、契約の条件によっては、契約保証金が返金されない状況があることもあります。このように、契約保証金は契約の成立や仕事の完了に関連するリスクを軽減するために使用されるものであり、契約の成功や失敗に関する重要な要素となります。
契約内容がどうなっているか、というより「契約保証金」そのものがどのようなものかを説明する応答が得られました。
※ OpenChatが単体でこれだけ破綻のない日本語内容を返すことに驚き。
HyDEを使ったRAGでは、この契約における契約保証金はどのような扱いなのか?
というクエリではなく、この生成された応答をRetrieverに与えて関連するドキュメントを取得する流れになります。
では、このHyDE用Chainを使ったRAGを実行してみましょう。
Step6. RAG Chain
HyDEを組み込んだChainは以下のようになります。
from langchain_core.runnables import RunnablePassthrough, RunnableLambda
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts.chat import (
AIMessagePromptTemplate,
HumanMessagePromptTemplate,
)
retriever = vectorstore.as_retriever()
# RAG prompt
template = """Answer the question based only on the following context:
{context}
Question: {question}
"""
response_prompt = ChatPromptTemplate.from_messages(
[
HumanMessagePromptTemplate.from_template(template),
AIMessagePromptTemplate.from_template(""),
]
)
chat_model = ChatHuggingFaceModel(
generator=generator,
tokenizer=tokenizer,
human_message_template="GPT4 Correct User: {}<|end_of_turn|>",
ai_message_template="GPT4 Correct Assistant: {}",
repetition_penalty=1.2,
temperature=0.1,
max_new_tokens=1024,
)
# RAG chain
chain = (
{
# Generate a hypothetical document and then pass it to the retriever
"context": hyde_chain | retriever,
"question": lambda x: x["question"],
}
| response_prompt
| chat_model
| StrOutputParser()
)
Chain自体はベーシックなRAGのChainです。ただし、Retrieverへの入力がhyde_chain
を通した出力結果となっている点が異なります。
Step7. Run
準備が整いましたので、ストリーミング出力で実行してみます。
ちなみに、今回の質問の答え(契約書内の該当条項抜粋)は以下のようになります。
甲は、本契約に係る乙が納付すべき契約保証金の納付を全額免除する。
for s in chain.stream({"question": "この契約における契約保証金はどのような扱いなのか?"}):
print(s, end="", flush=True)
この契約における契約保証金は、甲は本契約に係る乙が納付すべき契約保証金の納付を全額免除する。
回答としては正しい結果を得られました。
contextとして挿入されるhyde_chain | retriever
で得られた結果部分を確認してみましょう。
dummy_chain = hyde_chain | retriever
dummy_chain.invoke({"question": "この契約における契約保証金はどのような扱いなのか?"})
[Document(page_content='(契約保証金) \n第3条 甲は、本契約に係る乙が納付すべき契約保証金の納付を全額免除する。 \n \n (知的財産 権の帰属及び 使用) \n第4条 本契約の締結時に乙が既に所有又は管理していた 知的財産権(以下「 乙知的財産\n権」という。)を 乙が納入物に使用した場合には、甲は、当該乙知的財産権を、仕様書\n記載の「目的」のため、仕様書の「納入物」の項 に記載した利用方法に従い、本契約終'),
Document(page_content='があったとき。 \n(4)前 3号に定めるもののほか、乙が本契約の規定に違反したとき。 \n2 甲は、前項の規定により 本契約を解除した場合において、委託金の全部又は一部を 乙\nに支払っているときは、その全部又は一部を期限を定めて返還させることができる。 \n \n(延滞金) \n第20条 乙は、第 18条第 1項の規定によ り甲に確定額を超える 額を返納告知のあった'),
Document(page_content='る法律(平成 12年法律第100号)第6条第1項の規定に基づき定められた環境物品\n等の調達の推進に関する基本方針( 令和5年2月24日変更閣議決定)によ る紙類の印\n刷用紙及び役務の印刷の基準を満たすこととし、様式第1により作成した印刷物基準実\n績報告書を納入物とともに甲に提出しなければならない。 \n \n(契約保証金)'),
Document(page_content='甲に返還しなければならない。 \n2 乙が第16条第2項の規定 により概算払を受領している場合であっ て、当該概算払の\n合計額が確定額に満たない ときには、第16条第1項を準用する。 \n \n(契約の解除等) \n第19条 甲は、乙が次の各号のいずれかに該当するときは、催告を要さず 直ちに本契約\nの全部又は一部 を解除することができる。この場合、甲は乙に対して委託金その他これ')]
契約保証金の条項を含むチャンクがトップに得られています。
Step8. もし、HyDEを使わなかったら?
HyDEを使うことでRAGの性能はあがったのでしょうか?
使わない場合のChainを作って比較してみます。
from langchain_core.runnables import RunnablePassthrough, RunnableLambda
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts.chat import (
AIMessagePromptTemplate,
HumanMessagePromptTemplate,
)
retriever = vectorstore.as_retriever()
# RAG prompt
template = """Answer the question based only on the following context:
{context}
Question: {question}
"""
response_prompt = ChatPromptTemplate.from_messages(
[
HumanMessagePromptTemplate.from_template(template),
AIMessagePromptTemplate.from_template(""),
]
)
chat_model = ChatHuggingFaceModel(
generator=generator,
tokenizer=tokenizer,
human_message_template="GPT4 Correct User: {}<|end_of_turn|>",
ai_message_template="GPT4 Correct Assistant: {}",
repetition_penalty=1.2,
temperature=0.1,
max_new_tokens=1024,
)
# RAG chain
chain2 = (
{
# 単純にretrieverからコンテキストを取得するように変更
"context": retriever,
"question": RunnablePassthrough(),
}
| response_prompt
| chat_model
| StrOutputParser()
)
for s in chain2.stream("この契約における契約保証金はどのような扱いなのか?"):
print(s, end="", flush=True)
この契約における契約保証金は、全額免除される。
だいぶシンプルですが、正しく回答されているように見えます。
ちなみに、Retrieverから得られた結果は以下のようになります。
dummy_chain2 = retriever
dummy_chain2.invoke("この契約における契約保証金はどのような扱いなのか?")
[Document(page_content='(契約保証金) \n第3条 甲は、本契約に係る乙が納付すべき契約保証金の納付を全額免除する。 \n \n (知的財産 権の帰属及び 使用) \n第4条 本契約の締結時に乙が既に所有又は管理していた 知的財産権(以下「 乙知的財産\n権」という。)を 乙が納入物に使用した場合には、甲は、当該乙知的財産権を、仕様書\n記載の「目的」のため、仕様書の「納入物」の項 に記載した利用方法に従い、本契約終'),
Document(page_content='に譲渡し、又は承継させてはならない。ただし、信用保証協会、資産の流動化に関する\n法律(平成10年法律第105号)第2条第3項に規定する特定目的会社 又は中小企業\n信用保険法施行令(昭和25年政令第350号)第1条の 3に規定する金融機関 に対し\nて債権を譲渡する場合にあっては、この限りでない。 \n2 乙が本契約により行うこととされた 全ての給付を完了 する前に、乙が 前項ただし書に'),
Document(page_content='る法律(平成 12年法律第100号)第6条第1項の規定に基づき定められた環境物品\n等の調達の推進に関する基本方針( 令和5年2月24日変更閣議決定)によ る紙類の印\n刷用紙及び役務の印刷の基準を満たすこととし、様式第1により作成した印刷物基準実\n績報告書を納入物とともに甲に提出しなければならない。 \n \n(契約保証金)'),
Document(page_content='の納付の日までの日 数に応じ、年3パーセントの割合により計算 した利息を付すことが\nできる。 \n \n(乙による 公表の禁止) \n第25条 乙は、甲の許可を得ないで委託業務の内容を公表してはならない。 \n \n(情報セキュリティの確保) \n第26条 乙は、契約締結後速やかに、情報セキュリティを確保するための体制 並びに本')]
こちらでも、最上位に契約保証金条項を含むチャンクが得られていました。
というわけで、この質問ではHyDEを使用有無で精度は変わらなかった、ということになります。
ただ、他のドキュメントでは異なるものを取得してきたりしているので、ケースによってはHyDEの方がよい精度を得られる可能性はあります。
まとめ
HyDEを使ったRAGを実践しました。
HyDEは仮の応答を一度作るため、通常のRAGに比べて最終的な回答生成までに時間がかかります。(=レスポンス性能の悪化)
上で紹介したリンク先でも否定的に記載されていましたが、私も精度向上のメリットとレスポンス性能悪化のデメリットを比較するに、実用範囲は限定的かなという印象を持っています。
ただ、最近は7B(以下の)サイズのLLMでも十分な性能を持つモデルが増え、かつ推論速度も非常に高速になってきています。
そのため、HyDEだけではありませんが、LLMを使ってクエリ変換するRAG処理においては、メリット・デメリットのバランスがメリット側にどんどん傾く可能性もあり、実用範囲が拡大していくかもしれません。
うまくケースバイケースで使いこなせていきたいなと思います。