私が学ぶRAGの実質8回目です。シリーズ一覧はこちら。
今回はRewrite-Retrieve-ReadによるRAGの実践です。
以下のウォークスルーに近い内容となります。
これは何?
Langchainの下記Blog内 Rewrite-Retrieve-Read のセクションで解説されています。
ざっくり邦訳+図を掲載。
このホワイトペーパーでは、生のユーザークエリを使用して直接取得するのではなく、LLMを使用してユーザークエリを書き換えます。
元のクエリは、特に現実の世界では、LLMで取得するのに常に最適であるとは限らないためです...まず、LLMにクエリを書き直すように促し、次に検索拡張読み取りを実行します。
例えば、RAGのデータソースとしてWeb検索エンジンを利用する形態を想定するとします。
通常、ユーザが入力するクエリは「Langchainってどういう機能があるの?」というような自然言語の形態になりますが、Web検索エンジンに対しては「Langchain 機能」という指定の方がよい結果になることがあります。
Rewrite-Retrieve-Readは、生のユーザクエリをそのまま利用するのではなく、後続のRetrieverに向けて適したクエリに変換(Rewrite)する処理を挟んで適切なドキュメントを取得しやすくするテクニックです。
図だとさらに得られた結果を基にRewrite処理をアップデートするパイプラインまで提案されているようなのですが、今回はWeb検索エンジンに与えるクエリを変換する流れを実践してみます。
実践環境はDatabricks、DBRは14.1 ML、GPUクラスタで動作を確認しています。
実装に関しては、以下のLangchain Templateを参考にしました。
Step0. Package Install
必要なパッケージをインストール。
%pip install -U -qq transformers accelerate langchain
# CUDA 11.8用
%pip install https://github.com/casper-hansen/AutoAWQ/releases/download/v0.1.8/autoawq-0.1.8+cu118-cp310-cp310-linux_x86_64.whl
%pip install -U duckduckgo-search
dbutils.library.restartPython()
Web検索エンジンとしてDuckDuckGoを利用するため、関連するパッケージをインストールしています。
他のパッケージの説明については、準備編を参照してください。
Step1. LLM Loading
LLMをロードしておきます。
今回はOpenChat-3.5 1210を採用しました。いつものように、事前にダウンロードしておいたTheBloke兄貴のAWQ量子化したモデルを利用します。
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers_chat import ChatHuggingFaceModel
model_path = "/Volumes/training/llm/model_snapshots/models--TheBloke--openchat-3.5-1210-AWQ"
generator = AutoModelForCausalLM.from_pretrained(model_path, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(model_path)
Step2. DuckDuckGo Retriever
DuckDuckGoから検索結果を得るRetrieverを用意します。
_parse
関数は後ほどChainの中でRetrieverにクエリを渡す前に文字列から不要な記号等を除去するために使います。
from langchain.utilities import DuckDuckGoSearchAPIWrapper
search = DuckDuckGoSearchAPIWrapper()
def retriever(query):
return search.run(query)
def _parse(text):
return text.strip("**").strip()
Step3. Rewrite-Retrieve-Read Chain
Rewrite-Retrieve-ReadのRewriteを実行するChainを作成します。
ざっくり言えば、「ユーザの生クエリ(質問)を基に、Web検索に適したクエリを生成する」というプロンプトを作ってLLMに渡しているだけです。非常にシンプル。
from langchain.prompts import ChatPromptTemplate
from langchain.prompts.chat import (
AIMessagePromptTemplate,
HumanMessagePromptTemplate,
)
from langchain.schema.output_parser import StrOutputParser
template = """Provide a better search query for web search engine to answer the given question, \
just 1 answer, no quote, end the each queries with ’**’.
Question:
{x}
Answer:"""
rewrite_prompt = ChatPromptTemplate.from_messages(
[
HumanMessagePromptTemplate.from_template(template),
AIMessagePromptTemplate.from_template(""),
]
)
gen_model = ChatHuggingFaceModel(
generator=generator,
tokenizer=tokenizer,
human_message_template="GPT4 Correct User: {}<|end_of_turn|>",
ai_message_template="GPT4 Correct Assistant: {}",
repetition_penalty=1.2,
temperature=0.001,
max_new_tokens=1024,
)
rewriter = rewrite_prompt | gen_model | StrOutputParser() | _parse
実際にこのChainを通すと、どのように変換されるか確認してみましょう。
こちらの例にあるクエリで実行してみます。
print(rewriter.invoke({"x": "man that sam bankman fried trial was crazy! what is langchain?"}))
what is langchain in finance context
イマイチこの例はわかりづらいのですが、実際に確認したいのはlangchainとは何か?ということを検索したいようで、変換後の出力はそれに沿った検索ワードになっていることが分かります。
ちなみに、実際に両クエリで検索してみると以下のような結果となります。
生のクエリを使った検索
search.run("man that sam bankman fried trial was crazy! what is langchain?")
"Sam Bankman-Fried was found guilty Thursday following a monthlong trial that was almost as wild as the rapid rise and fall of cryptocurrency exchange FTX itself. He was found guilty of... Nov. 2, 2023. Sam Bankman-Fried, the tousle-haired mogul who founded the FTX cryptocurrency exchange, was convicted on Thursday of seven charges of fraud and conspiracy after a monthlong trial ... Nov. 3, 2023 Sam Bankman-Fried, crypto's Icarus, was convicted of seven counts of fraud and conspiracy on Thursday night after a trial that generated 10 million pages of documents and only... A look into Sam Bankman-Fried's trial so far. FTX founder Sam Bankman-Fried leaves the courthouse following his arraignment in New York City on Dec. 22, 2022. Bankman-Fried's trial just concluded ... The verdict caps a yearlong saga that took the 31-year-old Bankman-Fried from a billionaire living in a luxury apartment in the Bahamas to a defendant in one of the biggest white-collar crime..."
変換後のクエリを使った検索
search.run("what is langchain in finance context")
"LangChain is a Python framework designed to streamline AI application development, focusing on real-time data processing and integration with Large Language Models (LLMs). It offers features for data communication, generation of vector embeddings, and simplifies the interaction with LLMs, making it efficient for AI developers. It is a fundamental part of our lives and is used in many different fields, from engineering to finance. 🧬 Sequential Chains. ... maintaining context, or adding custom logic, LangChain provides the tools to elevate our interactions with language models. As the landscape of language processing continues to evolve, LangChain stands as a ... LangChain is a framework for developing applications powered by language models. The two major advantages of LangChain are: Easily connect a language model to other sources of data Allows a language model to interact with its environment LangChain provides multiple tools to work with LLM's. The one's used in this blog are: 1. LangChain is an open-source framework that's specially designed for developing applications powered by language models. [2] LangChain is an intuitive open-source framework created to simplify the development of applications using large language models (LLMs), such as OpenAI or Hugging Face. This allows you to build dynamic, data-responsive applications that harness the most recent breakthroughs in natural language processing."
というわけで、全然違う結果が得られます。
後者がきちんとLangchainに関する結果を得られていますね。
Step4. Chain
では、実際にRewriteのChainを組み込んだRAG用のChainを作成します。
from langchain.schema.output_parser import StrOutputParser
from langchain.schema.runnable import RunnablePassthrough, RunnableLambda
from langchain.prompts import ChatPromptTemplate
from langchain.prompts.chat import (
AIMessagePromptTemplate,
HumanMessagePromptTemplate,
)
template = """Answer the users question based only on the following context:
<context>
{context}
</context>
Question: {question}
"""
response_prompt = ChatPromptTemplate.from_messages(
[
HumanMessagePromptTemplate.from_template(template),
AIMessagePromptTemplate.from_template(""),
]
)
chat_model = ChatHuggingFaceModel(
generator=generator,
tokenizer=tokenizer,
human_message_template="GPT4 Correct User: {}<|end_of_turn|>",
ai_message_template="GPT4 Correct Assistant: {}",
repetition_penalty=1.2,
temperature=0.1,
max_new_tokens=1024,
)
chain = (
{
"context": {"x": RunnablePassthrough()} | rewriter | retriever,
"question": RunnablePassthrough(),
}
| response_prompt
| chat_model
| StrOutputParser()
)
context
に単純にRetrieverを使うのではなく、RewriterのChainを挟んでRetrieverを使うようになっています。
では、実際に実行してみましょう。
例その1: 英語クエリ
まずは、Langchain上の例を実行。
distracted_query = "man that sam bankman fried trial was crazy! what is langchain?"
for s in chain.stream(distracted_query):
print(s, end="", flush=True)
LangChain is an open-source Python framework designed to streamline AI application development, focusing on real-time data processing and integration with Large Language Models (LLMs). It offers features for data communication, generation of vector embeddings, and simplifies the interaction with LLMs, making it efficient for AI developers. It is used in various fields, from engineering to finance.
What is langchain?に答えられる内容を得られました。
例その2: 日本語クエリ
別のパターンも試してみます。
日本語で聞いてみましょう。
distracted_query = "昨日Databricksを初めて触ったんだけど、めっちゃよかった。Databricksの機能をもっと知りたい"
for s in chain.stream(distracted_query):
print(s, end="", flush=True)
Databricksは、クラウド上の統合分析プラットフォームであり、データ統合とデータ分析、AI活用を行うことができます。Databricksは多様な分野・業種で導入され、店舗需要予測やゲノム解析、品質管理などさまざまな場面で活用されています。
Databricksの特徴は、データの取り込みから分析までのプロセスを手軽に一元化できることです。Databricksは、BIから生成AIまでのソリューションを使用して、データセットの処理、保存、共有、分析、モデル化、収益化のために、データソースを1つのプラットフォームに接続するのに役立ちます。
Databricksワークスペースには、ほとんどのデータタスク用の統一されたインターフェイスとツールが用意されています。これには、データ処理ワークフローのスケジューリングと管理、ダッシュボードとビジュアライゼーションの生成、セキュリティー、ガバナンス、高可用性、およびディザスタリカバリーの管理、データの検出、アノテーション、探索、機械学習(ML)のモデリング、追跡、モデルサービング、生成AIソリューション、オープンソースとのマネージド統合が含まれます。
Azure Databricksは、データのソースを1つのプラットフォームに接続し、BIから生成AIまでのソリューションを使ってデータセットを処理、格納、共有、分析、モデル化、収益化するためのツールが用意されています。Azure Databricksワークスペースには、ほとんどのデータタスクに対応できる統合インターフェイスとツールが用意されています。
Databricksは、Apache Sparkなどのオープンソース開発者たちが2013年に創業したスタートアップであり、データレイクとデータウェアハウスを重ねた「レイクハウスカンパニー」を標榜している特徴があります。Databricksを導入する企業数は1万社を超え、従業員数はグローバルで5000名以上、日本リージョンでは現在約1500名と、企業規模・サービスともに急成長中のデータ/AIのリードカンパニーだ。
前半は質問ではありませんが、きちんと検索・回答してくれますね。
例その3: 脈絡のない日本語クエリ
では、もう少し脈絡のない内容を。
distracted_query = "昨日のお酒が残っていて、若干気持ち悪い。仕事したくないけどしないといけない。Pysparkのsum関数とは何?"
for s in chain.stream(distracted_query):
print(s, end="", flush=True)
かなり意味の分からない質問ですが、結果はどうでしょうか。
この質問は関連性のない情報を含んでいますが、Pysparkのsum関数について説明します。Pysparkのsum関数は、SparkのデータフレームやRDDの列データに対して、各要素の値を合計する関数です。例えば、以下のようなデータフレームがあるとして、sum関数を使って合計を計算できます。
データフレーム:
```
+----+-----+
| A | B |
+----+-----+
| 1 | 2 |
| 3 | 4 |
+----+-----+
```
Pysparkのsum関数を使って、列AとBの合計を計算することができます。
```python
from pyspark.sql import SparkSession
# スパークのセッションを作成
spark = SparkSession.builder.appName("example").getOrCreate()
# データフレームを作成
df = spark.createDataFrame([(1, 2), (3, 4)], ["A", "B"])
# 列AとBの合計を計算
sum_A = df.agg({"A": "sum"}).first()[0]
sum_B = df.agg({"B": "sum"}).first()[0]
print("列Aの合計:", sum_A)
print("列Bの合計:", sum_B)
```
この例では、列AとBの合計を計算して、それぞれの合計を表示しています。
正しく動いている・・・ように見えますね。
ただ、実はこれはLLM自体の知識で答えています。
検証のために、Rewriterにこのクエリをそのまま渡した場合の結果を見てみましょう。
ret = rewriter.invoke({"x": "昨日のお酒が残っていて、若干気持ち悪い。仕事したくないけどしないといけない。Pysparkのsum関数とは何?"})
print(ret)
昨日のお酒の後の気持ち悪さと仕事の不欲は、PySparkのsum関数については関連性がないので、別の検索クエリをお使いください。
Question: PySpark sum function explanation
検索用クエリを正しく得られていませんね。
念のため、上の文字列をそのままDuckDuckGoで検索してみます。
q = """昨日のお酒の後の気持ち悪さと仕事の不欲は、PySparkのsum関数については関連性がないので、別の検索クエリをお使いください。
Question: PySpark sum function explanation"""
search.run(q)
'今回は、飲み会の当日と飲み会中、飲み過ぎてしまった翌日に分けて、二日酔いを早く治す方法を解説するとともに、持っておくと安心な市販薬を紹介します。 さらに、そもそも二日酔いにならないための予防法についても説明していきます。 ※この情報は2023年6月時点で更新しています。 医師・薬剤師が 選んだ市販薬を紹介 経験① 現場でよく聞かれる質問 現場で聞かれる薬の効果や副作用、飲み合わせの注意点等をご説明します。 経験② 現場で教える医薬品 悩みに合った薬や普段から自分が案内する薬など、現場で案内するものを教えます。 視点① ユーザー目線で解説 実際にどう使うのかや、ユーザー目線で 必要な情報をお伝えします。 医師・薬剤師 が悩みにお答えします! 当コラムの掲載記事に関するご注意点 1. 速攻で二日酔いの頭痛や吐き気を治す方法. 二日酔いを改善させるためには、失った水分やビタミン、ミネラルの補給、アルコールの解毒に必要な糖分の摂取などが必要 です。. 二日酔いを素早く解消したい方は、まず以下の方法を試してください。. 水や ... 冬になるとアルコールの摂取量が増えるのはなぜ?. 残念ながら、アルコールは肝臓に有害なだけではない。. 「アルコールは胃の粘膜を刺激して ... しかし、お酒の飲みすぎたときに、身体の中でどのようなことが起こって体調不良が起こるのか、はっきりとしたメカニズムはわかっていません。 これまで行われてきた研究によると、以下のような現象が絡み合った結果、二日酔いの症状があらわれるので ... 二日酔いの吐き気や頭痛、辛いですよね…。前日飲み過ぎたことを後悔しても仕方がないので、少しでも早く治るよう対策しましょう!今回は二日酔いに効く食べ物・飲み物についてご紹介します。 また、辛い二日酔いを繰り返さないために予防のポイントに...'
SUM関数の説明は得られませんでした。
とはいえ、Question: PySpark sum function explanation
というワードをRewrite後に引き出せているので、うまくパース処理や、そもそも性能のよいLLMを使えば適切に変換できるんじゃないかと思われます。
まとめ
Rewrite-Retrieve-ReadによるRAGを実践してみました。
ユースケースにもよりますが、Web検索を使ったRAGを構築する場合、このテクニックは入れておいた方がいいと思います。
もちろんWeb検索以外でも、言い換えや多言語変換が必要な際などにも使えるかと思いました。