LoginSignup
1
2

langchainとDatabricksで(私が)学ぶRAG : Adaptive RAG 前編

Last updated at Posted at 2024-04-15

導入

LangGraphのexampleを眺めていたら、Adaptive RAGという新たなRAGストラテジーのサンプルが上がっていました。

今回はこちらを魔改造してウォークスルーします。

長くなりますので、前編・後編に分割します。今回は前編。

なお、検証はDatabricks on AWSで行いました。
DBRは14.3ML LTS、g4dn.xlargeのクラスタを利用しています。

Adaptive RAGとは?

論文やgithubはこちら。

Mediumで解説されている記事もありました。

Abstractを邦訳すると、

外部知識ベースからのノンパラメトリック知識をLLMに組み込む検索拡張大規模言語モデル(LLM)は、質問応答(QA)などのいくつかのタスクで応答精度を向上させるための有望なアプローチとして浮上しています。
ただし、複雑さの異なるクエリを処理するさまざまなアプローチがありますが、それらは不必要な計算オーバーヘッドを伴う単純なクエリを処理するか、複雑な複数ステップのクエリに適切に対処できません。
ただし、すべてのユーザー要求が単純または複雑なカテゴリのいずれかに分類されるわけではありません。本研究では、クエリの複雑さに基づいて、(検索拡張)LLMに最も適した戦略を、最も単純なものから最も洗練されたものまで動的に選択できる新しい適応型QAフレームワークを提案します。
また、この選択プロセスは、モデルの実際の予測結果とデータセットに固有の帰納バイアスから得られた、自動的に収集されたラベルを使用して受信クエリの複雑さレベルを予測するようにトレーニングされた小さなLMである分類器で運用されます。
このアプローチは、バランスの取れた戦略を提供し、さまざまなクエリの複雑さに対応して、反復型およびシングルステップの検索拡張 LLM と、検索なしの LLM をシームレスに適応させます。
複数のクエリの複雑さをカバーする一連のオープンドメインQAデータセットでモデルを検証し、適応型検索アプローチを含む関連するベースラインと比較して、QAシステムの全体的な効率と精度を向上させることを示しました。

adaptiverag.png

雑な理解で言うと、ユーザクエリの難易度に応じて適切な処理パイプラインを選択・実行することで、パフォーマンス・実行速度、両方を最適化する、というものです。
特に最初のClassifierが重要ぽく、本来は適切に訓練されたLMを利用するのだと思います。

今回は、LangGraphのGithub内にある以下のノートブックに基づいて、共通のLLMを使ってRAGパイプラインを構成・実行してみます。

構成するグラフは以下のようになります。(上記ノートブックより)

image.png

クエリの内容に応じて、①ベクトルストアを使ったRAG(Self-reflective RAG)、②Web Searchを使った回答生成 のパイプラインを使い分けます。
拡張の方向性として、③LLM単体での回答 というパイプラインを加えることができるでしょう。(今回は実施しません)

また、上のノートブックではOllamaGPT4All Embeddingを用いていますが、本記事では、以下のLM Format Enforcer+ExllamaV2やDatabricks Model Servingを使ったEmbeddingを使って実装してみます。

Step1. パッケージインストール・環境変数設定

LangGraphのビジュアライゼーション機能を利用したいため、Graphvizをインストールしておきます。

%sh sudo apt-get install graphviz libgraphviz-dev pkg-config --yes

その上で、各種パッケージをインストール。

%pip install torch==2.2.2 --index-url https://download.pytorch.org/whl/cu118
%pip install ninja
%pip install -U flash-attn --no-build-isolation

%pip install https://github.com/turboderp/exllamav2/releases/download/v0.0.18/exllamav2-0.0.18+cu118-cp310-cp310-linux_x86_64.whl

%pip install -U langchain langchain-chroma langgraph langchainhub lm-format-enforcer tavily-python
%pip install grandalf pygraphviz

dbutils.library.restartPython()

インターネット検索用にTavilyを利用するため、APIキーを環境変数に設定します。
APIキーは事前にDatabricks Secretsに登録しておいたものを取得します。
なお、APIキーの取得については、以下の記事が詳しいです。

import os

os.environ["TAVILY_API_KEY"] = dbutils.secrets.get("tavily", "api_key")

Step2. RAG処理用のベクトルストア作成

こちらの日本語Wikipediaからデータを取得し、ベクトルストアとして利用することにします。

Wikipediaデータの取得

import requests

def get_wikipedia_page(title: str):
    """
    Retrieve the full text content of a Wikipedia page.

    :param title: str - Title of the Wikipedia page.
    :return: str - Full text content of the page as raw string.
    """
    # Wikipedia API endpoint
    URL = "https://ja.wikipedia.org/w/api.php"

    # Parameters for the API request
    params = {
        "action": "query",
        "format": "json",
        "titles": title,
        "prop": "extracts",
        "explaintext": True,
    }

    # Custom User-Agent header to comply with Wikipedia's best practices
    headers = {"User-Agent": "tutorial/0.0.1"}

    response = requests.get(URL, params=params, headers=headers)
    data = response.json()

    # Extracting page content
    page = next(iter(data["query"]["pages"].values()))
    return page["extract"] if "extract" in page else None

# テキストデータの取得
full_document = get_wikipedia_page("葬送のフリーレン")

埋め込みモデルの準備

Databricks Model Servingに登録済みの埋め込みモデルをLangChain経由で呼び出す準備をします。
登録の仕方などの詳細は以下を参照ください。

別の埋め込みモデルへの置換ももちろん可能です。適切なものを選択してください。

from langchain_community.embeddings import DatabricksEmbeddings

endpoint_name = "bge_m3_endpoint "
embeddings = DatabricksEmbeddings(endpoint=endpoint_name)

テキストデータのチャンク化

Wikipediaから取得したテキストデータに対して、単純な日本語スプリッタを用いてチャンキングします。

from typing import Any
import requests

from langchain.schema import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import WebBaseLoader
from langchain_community.vectorstores import Chroma

class JapaneseCharacterTextSplitter(RecursiveCharacterTextSplitter):
    """句読点も句切り文字に含めるようにするためのシンプルなスプリッタ"""

    def __init__(self, **kwargs: Any):
        separators = ["\n\n", "\n", "", "", " ", "", "==="]
        super().__init__(separators=separators, **kwargs)

# split it into chunks
text_splitter = JapaneseCharacterTextSplitter(chunk_size=400, chunk_overlap=80)
docs = text_splitter.create_documents([full_document])

ベクトルストアの構成

Chromaを使ってチャンクデータと埋め込みモデルからベクトルストアを構築します。

from langchain_chroma import Chroma

# チャンク化済みドキュメントと埋め込みモデルを利用してChromaでベクトルストア作成
vectorstore = Chroma.from_documents(docs, embeddings)

# Retrieverも作っておく
retriever = vectorstore.as_retriever(search_type="mmr", search_kwargs={"k": 3})

Retrieverを使って試しに関連文書を取得してみます。

retriever.get_relevant_documents("フリーレンの声優")
出力
[Document(page_content='=== フリーレン一行(主要人物) ===\nフリーレン (Frieren)\n声 - 種﨑敦美'),
 Document(page_content='=== 関連番組 ===\n葬送のフリーレン×ZIP!待望のアニメ化!魅力解剖SP\nテレビアニメ本放送前の2023年9月24日に日本テレビにて放送された『ZIP!』とコラボした特別番組。その後一部地域を除く日本テレビ系列各局でも順次放送された。同番組のコーナー「?よミトく!」特別版として、作品の内容や制作現場の模様が公開された。出演は水卜麻美(日本テレビアナウンサー)。インタビュー出演はかまいたち、若月佑美、種﨑敦美(フリーレン役)、市ノ瀬加那(フェルン役)、小林千晃(シュタルク役)、YOASOBI(OPアーティスト)、milet(EDアーティスト)。\nフラアニ特別編 『葬送のフリーレン』⼤感謝祭 〜⼈の⼼を知る軌跡\n2024年3月29日に日本テレビ系列にて放送された特別番組。公式サイトで募集した「もう⼀度⾒たい名シーン」のアンケートを元に振り返る名場面集。ナレーションはmilet。'),
 Document(page_content='フリーレンに師事している人間の女性魔法使い。9歳→19歳。南側諸国の戦災孤児であり、両親の死に絶望して飛び降り自殺を図ろうとしたところを勇者パーティーの僧侶ハイターに救われ、「一人で生きていける力」を得るために彼から魔法を教わっていた。9歳時にハイターを訪ねてきたフリーレンに弟子入りを志願し、4年間の修業を経て一人前の魔法使いに成長する。ハイターの死後の15歳時にフリーレンの旅に同行する。フリーレンを師として尊敬しつつも、時間感覚が人間とかけ離れている彼女が一か所に長期滞在することに辟易したり、魔法以外の生活水準が低過ぎる彼女を母親のように世話するなど、気苦労が絶えない。フリーレンと同様にあまり感情を出さない一方で結構な毒舌家であり、怒ると雰囲気で周囲を圧倒するので、フリーレンやシュタルクからは恐れられている。甘い食べ物が好きで、フリーレンと一緒に結構な量の菓子を食す場面も多い')]

Step3. LLMの読み込み/各種Chainの作成

各種推論に使うLLMをExllamaV2を使ってロードします。
LangChain + LM Format Enforcerを利用するために、下記記事で作成したカスタムクラスを利用します。

モデルはダウンロード済のEXL2フォーマットで量子化された以下を利用しました。
モデルサイズの割に性能の良いモデルです(多少日本語も使える)。

from exllamav2_json_chat import ChatExllamaV2Model

model_path = "/Volumes/training/llm/model_snapshots/models--LoneStriker--Starling-LM-7B-beta-8.0bpw-h8-exl2/"

base_llm = ChatExllamaV2Model.from_model_dir(
    model_path,
    cache_max_seq_len=4096,
    cache_4bit=True,
    low_mem=True,
    max_new_tokens=128,
    temperature=0.0,
    no_flash_attn=False,
    human_message_template="GPT4 Correct User: {}<|end_of_turn|>GPT4 Correct Assistant: ",
)

以降、ここでロードしたモデルを再利用して、LangChainのチェーンを作成していきます。

Router

まず、クエリからどの処理パイプラインへ分岐するかを判断するチェーンを作成します。
ある種、全体の中で最も大事なチェーンです。

今回はベクトルストアの中身に関連する内容を聞かれた場合はSelf-RAGのパイプラインを、それ以外の場合はWeb検索を利用するパイプラインに分岐するように、結果を返します。
結果はLM Format Enforcerを使ってJSON形式となるようにフォーマットを強制します。

### Router

from langchain.prompts import PromptTemplate
from langchain_core.output_parsers import JsonOutputParser
from pydantic import BaseModel, conlist
from typing import Literal

class Route(BaseModel):
    datasource: Literal["web_search", "vectorstore"]

llm = ChatExllamaV2Model.from_model(base_llm)
llm.reset_json_schema(Route.schema())

prompt = PromptTemplate(
    template="""You are an expert at routing a user question to a vectorstore or web search. \n
    Use the vectorstore for questions related with 葬送のフリーレン. \n
    You do not need to be stringent with the keywords in the question related to these topics. \n
    Otherwise, use web-search. Give a binary choice 'web_search' or 'vectorstore' based on the question. \n
    Return the a JSON with a single key 'datasource' and no premable or explaination. \n
    Question to route: {question}""",
    input_variables=["question"],
)

question_router = prompt | llm | JsonOutputParser()

このチェーンを実行すると、結果として{'datasource': 'vectorstore'}のようなJSONデータを得られます。プロンプトにあるように、vectorstoreweb_searchのいずれかを得られ、これに基づいて実行するべきパイプラインを判断します。

Retrieval Grader

Self-RAGのパイプライン内で実行するチェーンです。
Retrieverが取得したドキュメントについて、質問と関連したものかどうかを判断します。
こちらも結果がJSONとして取得できるようにしています。

### Retrieval Grader
from langchain.prompts import PromptTemplate
from langchain_core.output_parsers import JsonOutputParser
from pydantic import BaseModel, conlist
from typing import Literal

class Grade(BaseModel):
    score: Literal["yes", "no"]

# LLM
llm = ChatExllamaV2Model.from_model(base_llm)
llm.reset_json_schema(Grade.schema())

# Prompt
prompt = PromptTemplate(
    template="""You are a grader assessing relevance of a retrieved document to a user question. \n 
    Here is the retrieved document: \n\n {document} \n\n
    Here is the user question: {question} \n
    If the document contains keywords related to the user question, grade it as relevant. \n
    It does not need to be a stringent test. The goal is to filter out erroneous retrievals. \n
    Give a binary score 'yes' or 'no' score to indicate whether the document is relevant to the question. \n
    Provide the binary score as a JSON with a single key 'score' and no premable or explaination.""",
    input_variables=["question", "document"],
)

retrieval_grader = prompt | llm | JsonOutputParser()

Generator

与えられたコンテキストを使って質問の回答を生成するチェーンです。

### Generate
from langchain import hub
from langchain_core.output_parsers import StrOutputParser

# Prompt
prompt = hub.pull("rlm/rag-prompt")

# LLM
llm = ChatExllamaV2Model.from_model(base_llm)
llm.max_new_tokens = 512

# Post-processing
def format_docs(docs):
    return "\n\n".join(doc.page_content for doc in docs)

# Chain
rag_chain = prompt | llm | StrOutputParser()

Hallucination Grader

RAGによる生成結果が、コンテキストを事実として生成した回答かどうかを判定するチェーンです。
事実と異なる回答を生成した場合、生成をやり直すかどうかを判定する際に利用します。

### Hallucination Grader 

class Grade(BaseModel):
    score: Literal["yes", "no"]

# LLM
llm = ChatExllamaV2Model.from_model(base_llm)
llm.reset_json_schema(Grade.schema())

# Prompt
prompt = PromptTemplate(
    template="""You are a grader assessing whether an answer is grounded in / supported by a set of facts. \n 
    Here are the facts:
    \n ------- \n
    {documents} 
    \n ------- \n
    Here is the answer: {generation}
    Give a binary score 'yes' or 'no' score to indicate whether the answer is grounded in / supported by a set of facts. \n
    Provide the binary score as a JSON with a single key 'score' and no preamble or explanation.""",
    input_variables=["generation", "documents"],
)

hallucination_grader = prompt | llm | JsonOutputParser()

Answer Grader

生成した回答が質問に対して適切な内容かどうかを判断するチェーンです。
質問に対して不適切な回答を生成した場合、クエリを変換して生成をやり直すかどうかを判定する際に利用します。

### Answer Grader

# LLM
llm = ChatExllamaV2Model.from_model(base_llm)
llm.reset_json_schema(Grade.schema())

# Prompt
prompt = PromptTemplate(
    template="""You are a grader assessing whether an answer is useful to resolve a question. \n 
    Here is the answer:
    \n ------- \n
    {generation} 
    \n ------- \n
    Here is the question: {question}
    Give a binary score 'yes' or 'no' to indicate whether the answer is useful to resolve a question. \n
    Provide the binary score as a JSON with a single key 'score' and no preamble or explanation.""",
    input_variables=["generation", "question"],
)

answer_grader = prompt | llm | JsonOutputParser()

Question Re-writer

質問をよりよい形へ変換するチェーンです。
Answer Graderで適切な回答生成ができなかったと判定された際に利用します。

### Question Re-writer
from operator import itemgetter

# 改善された質問のみ出力したいため、JSON形式で強制
class NewQuestion(BaseModel):
    improved_question: str

# LLM
llm = ChatExllamaV2Model.from_model(base_llm)
llm.max_new_tokens = 256
llm.reset_json_schema(NewQuestion.schema())

# Prompt 
re_write_prompt = PromptTemplate(
    template="""You a question re-writer that converts an input question to a better version that is optimized \n 
     for vectorstore retrieval. Look at the initial and formulate an improved question, 
     then just reply improved question in Japanese.\n
     Here is the initial question: \n\n {question}""",
    input_variables=["question"],
)

question_rewriter = re_write_prompt | llm | JsonOutputParser() | itemgetter("improved_question")

Step4. Toolの準備

Web検索を行う判定時に利用するツールを準備します。
Step1で設定したように、TavilyをWeb検索ツールとして利用します。

from langchain_community.tools.tavily_search import TavilySearchResults
web_search_tool = TavilySearchResults(k=3)

前編はここまで。

前編まとめ

Adaptive RAGサンプルウォークスルーの前編として、RAG用のベクトルストア準備から、LLM・各種チェーンの構築、Web検索用ツールの準備まで実施しました。
後編ではこれらのチェーンやツールを使って処理にパイプラインをLangGraphを使って構築します。
その上で、いくつか実際の推論を実行する予定です。

しかし、長い。。。


↓ 後編作成しました。

1
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
2