LoginSignup
4
3

langchainとDatabricksで(私が)学ぶRAG : LangGraphとローカルLLMによるAgentを使ったRAG

Last updated at Posted at 2024-03-06

最近、複雑な実装内容が増えてきたので、もうちょっと手軽な記事を増やしたい。

導入

LangGraphのExampleの中に、エージェントを活用したRAGの実装例があることに気づきました。

こちらを土台に、ローカルLLMを使う形に魔改造修正・実装してみます。

実装および検証はDatabricks on AWSを利用しました。
DBRは14.3ML、クラスタタイプはg5.xlarge(GPUクラスタ)です。

これは何?

Retrieval Agentによる関連文書取得の判定、および文書の品質チェックと再処理を組み込んだRAGのパイプラインになります。
Retrieval Agentはクエリ内容に基づいて文書検索の必要有無を判定し、必要に応じて文書検索を行うエージェントです。これによって、より柔軟に文書検索を制御することができます。

処理の流れ

以下のような流れとなります。

image.png

各ノードと条件分岐エッジの処理は以下のようになります。

種類 名称 処理概要
ノード Agent Retrieverをツールとして利用するかどうかを判断するノード処理。今回はRetrieverを使う・使わないの判定に加えて、複数のRetrieverから利用するRetrieverを選定する処理も担います。
条件分岐エッジ Should Retrieve Retrieverの実行をするかどうかの判定結果から、遷移するノードを決定します。Retrieverを使わない判定の場合、そのまま処理を終了します。
ノード Tool 選択したRetrieverを実行し、関連文書を取得します。
条件分岐エッジ Check Relevance Retrieverから取得した文書が、ユーザからのクエリと実際に関連するかどうかを判定し、その結果に応じて遷移先のノードを決定します。関連する場合はGenerateノードに遷移し、関連しない場合はRewriteノードに遷移してクエリ変換を行った上で再度Agentからやり直す流れとなります。
ノード Generate Retrieverから取得した文書を用いて、最終的な回答を生成します。
ノード Rewrite より関連しやすい文書取得になるようにクエリを書き換えます。

それでは、実装&実際の動作を見ていきましょう。

Step1. パッケージのインストール

必要なパッケージをインストール。
今回もSGLangをローカルLLLMの推論パッケージとして利用します。
このあたりの解説はこちらと同じです。

# torch, xformers
# pytorchのリポジトリから直接インストールする場合
# %pip install -U https://download.pytorch.org/whl/cu118/torch-2.1.2%2Bcu118-cp310-cp310-linux_x86_64.whl
# %pip install -U https://download.pytorch.org/whl/cu118/xformers-0.0.23.post1%2Bcu118-cp310-cp310-manylinux2014_x86_64.whl

%pip install -U /Volumes/training/llm/tmp/torch-2.1.2+cu118-cp310-cp310-linux_x86_64.whl
%pip install -U /Volumes/training/llm/tmp/xformers-0.0.23.post1+cu118-cp310-cp310-manylinux2014_x86_64.whl

# vLLM
# GithubからvLLMを直接インストールする場合
# %pip install https://github.com/vllm-project/vllm/releases/download/v0.3.2/vllm-0.3.2+cu118-cp310-cp310-manylinux1_x86_64.whl
%pip install /Volumes/training/llm/tmp/vllm-0.3.2+cu118-cp310-cp310-manylinux1_x86_64.whl

# sglang==0.1.12の場合、outlines>0.030だとエラーが出るためoutlinesのバージョンを固定
%pip install "outlines>=0.0.27,<=0.0.30"
%pip install "sglang[srt]>=0.1.12"
# %pip install "triton>=2.2.0"
%pip install /Volumes/training/llm/tmp/triton-2.2.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
%pip install -U langchain langchain-openai langgraph langchainhub sqlalchemy chromadb

dbutils.library.restartPython()
import torch
torch.multiprocessing.set_start_method('spawn', force=True)

Step2. モデルのロード

Embedding Model

埋め込み用のモデルをLangChainのHuggingFaceBgeEmbeddingsクラスを使って読み込んでおきます。
このモデルは、ベクトルストアの作成や検索において利用します。

埋め込み用のモデルは、事前にダウンロードしておいたBGE-M3を利用しました。

import torch
from langchain_community.embeddings import HuggingFaceBgeEmbeddings

device = "cuda" if torch.cuda.is_available() else "cpu"

model_path = "/Volumes/training/llm/model_snapshots/models--BAAI--bge-m3"
model_kwargs = {"device": device}
encode_kwargs = {"normalize_embeddings": True}  # Cosine Similarity

embedding = HuggingFaceBgeEmbeddings(
    model_name=model_path,
    model_kwargs=model_kwargs,
    encode_kwargs=encode_kwargs,
)

LLM

次にSGLangを使って、LLMをロードします。
今回はAWQフォーマットで量子化されたQwen1.5 14Bを利用しました。

モデルは事前にダウンロードしたものを利用しています。

from sglang import function, system, user, assistant, gen, set_default_backend, Runtime

from sglang.lang.chat_template import (
    get_chat_template,
    register_chat_template,
    ChatTemplate,
)

# チャット用のプロンプトテンプレート
register_chat_template(
    ChatTemplate(
        name="qwen",
        default_system_prompt=None,
        role_prefix_and_suffix={
            "system": ("<|im_start|>system\n", "<|im_end|>\n"),
            "user": ("<|im_start|>user\n", "<|im_end|>\n"),
            "assistant": ("<|im_start|>assistant\n", "<|im_end|>\n"),
        },
    )
)

model_path = (
    "/Volumes/training/llm/model_snapshots/models--Qwen--Qwen1.5-14B-Chat-AWQ"
)

runtime = Runtime(model_path, mem_fraction_static=0.6)
runtime.endpoint.chat_template = get_chat_template("qwen")

set_default_backend(runtime)

Step3. Retrieverの準備

検索文書として、Wikipediaから文書を取得し、ベクトルストアへ格納することにします。
Wikipediaのタイトルを指定してベクトルストアに保管後、LangChainのRetrieverを取得する関数を準備。

from typing import Any
import requests

from langchain.schema import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import Chroma

class JapaneseCharacterTextSplitter(RecursiveCharacterTextSplitter):
    """句読点も句切り文字に含めるようにするためのシンプルなスプリッタ"""

    def __init__(self, **kwargs: Any):
        separators = ["\n\n", "\n", "", "", " ", ""]
        super().__init__(separators=separators, **kwargs)

def get_wikipedia_page(title: str):
    """
    日本語Wikipediaから情報を取得
    """

    URL = "https://ja.wikipedia.org/w/api.php"

    params = {
        "action": "query",
        "format": "json",
        "titles": title,
        "prop": "extracts",
        "explaintext": True,
    }

    # Wikipediaのベストプラクティスに則って、カスタムユーザエージェントをヘッダにに設定
    headers = {"User-Agent": "Langchain+Databricks+RAG/0.0.1"}

    response = requests.get(URL, params=params, headers=headers)
    data = response.json()

    # コンテンツを取得
    page = next(iter(data["query"]["pages"].values()))
    return page["extract"] if "extract" in page else None

def create_retriever_from_wikipedia(title:str, collection_name:str):
    """ Wikipediaの情報をもとにVectorstore/Retrieverを作成 """

    docs = [get_wikipedia_page(title)]
    docs_list = [Document(page_content=d) for d in docs]

    text_splitter = JapaneseCharacterTextSplitter(chunk_size=512, chunk_overlap=40)
    doc_splits = text_splitter.split_documents(docs_list)

    # VectorDBであるChromaに保管
    vectorstore = Chroma.from_documents(
        documents=doc_splits,
        collection_name=collection_name,
        embedding=embedding,
    )

    # Retriever取得。検索件数は1件固定
    return vectorstore.as_retriever(search_kwargs={"k": 1})    

前述のcreate_retriever_from_wikipedia関数を使って2種類の異なるWikipediaページを使ったベクトルストアをそれぞれ作成します。
その上で、LangChainのcreate_retriever_tool関数を使ってLangChainのツール化を行います。

今回はこの2種のツール(2種のRetriever)をクエリによってエージェントに使い分けさせます。

from langchain.tools.retriever import create_retriever_tool
from langgraph.prebuilt import ToolExecutor

tool1 = create_retriever_tool(
    create_retriever_from_wikipedia("葬送のフリーレン", "rag-chroma1"),
    "retriever1",
    "Search and return information about 葬送のフリーレン.",
)

tool2 = create_retriever_tool(
    create_retriever_from_wikipedia("Apache_Spark", "rag-chroma2"),
    "retriever2",
    "Search and return information about Apache Spark.",
)


tools = [tool1, tool2]

tool_executor = ToolExecutor(tools)

Step4. SGLangによる推論処理の定義

SGLangで推論処理させるための単純な関数を定義。
シンプルにinputパラメータの内容をユーザ指示として回答生成する推論処理です。
ただし、regexパラメータでJSONスキーマを指定することでJSONフォーマットの回答を生成するようにしています。

@function
def _llm(s, input, ai_message_prefix="", regex=None):
    s += user(input)
    s += assistant(
        ai_message_prefix
        + gen(
            "answer",
            max_tokens=2048,
            temperature=0,
            regex=regex,
        )
    )

Step5. Function Calling模倣処理用のプロンプト定義

オリジナルのサンプルノートブックでは、OpenAIのfunction callingを使ってエージェント処理を実装していました。
ローカルLLMで同様の処理を実現するために、専用のプロンプトテンプレートを作るための処理を定義します。
処理としてはツールの名前とその説明をプロンプト内に埋め込み、その情報を使ってクエリに対応する適切なツールを選択、JSONで結果を返すようにしています。

def create_function_calling_prompt(tools):
    """ Tool選択のためのプロンプトテンプレートを生成 """

    tool_names = ", ".join([t.name for t in tools])
    tools_desc = "\n".join([t.name + ": " + t.description for t in tools])

    prompt = PromptTemplate(
        template=("TOOLS\n------\n"
                  "Assistant can ask the user to use tools to look up information that may be helpful in answering the users original question." 
                  "The tools the human can use are:\n\n"
                  "{tools_desc}\n\n"
                  "RESPONSE FORMAT INSTRUCTIONS\n----------------------------\n\n"
                  "When responding to me, please output a response in one of two formats:\n\n"
                  "**Option 1:**\n"
                  "Use this if you want the human to use a tool.\n"
                  "Markdown code snippet formatted in the following schema:\n\n"
                  "```json\n"
                  '{{\n'
                  '    "action": string, \\\\ The action to take. Must be one of tools: {tool_names}\n'
                  '    "action_input": string \\\\ The input to the action\n'
                  '}}\n'
                  '```\n\n'
                  "**Option #2:**\n"
                  "Use this if you can respond directly to the human after tool execution. "
                  "Markdown code snippet formatted in the following schema:\n\n"
                  "```json\n"
                  '{{\n'
                  '    "action": "Final Answer",\n'
                  '    "action_input": string \\\\ You should put what you want to return to use here\n'
                  '}}\n'
                  '```\n\n'
                  "USER\'S INPUT\n--------------------\n"
                  "Here is the user\'s input (remember to respond with a markdown code snippet of a json blob with a single action, and NOTHING else):\n\n"
                  "{input}"),
        input_variables=["input"],
        partial_variables={"tool_names":tool_names, "tools_desc":tools_desc},
    )

    return prompt

Step6. Graphの状態クラス定義

ここからグラフに関する処理に入ります。

まずはグラフの状態クラスを定義します。
状態としてメッセージの履歴を保持するクラスとなります。

import operator
from typing import Annotated, Sequence, TypedDict

from langchain_core.messages import BaseMessage


class AgentState(TypedDict):
    messages: Annotated[Sequence[BaseMessage], operator.add]

Step7. ノードとエッジ処理の定義

グラフのノードと条件分岐のエッジ処理を定義します。

まず、必要なモジュールをインポート。

import json
import operator
from typing import Annotated, Sequence, TypedDict
from enum import Enum

from langchain_core.prompts import PromptTemplate
from langchain_core.messages import (
    BaseMessage,
    FunctionMessage,
    AIMessage,
    HumanMessage,
)
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.runnables import RunnableLambda

from langgraph.prebuilt import ToolInvocation

from sglang.srt.constrained import build_regex_from_object

Nodes

ここからノードの処理定義です。

Agentノード

まずAgentノードに対応する処理を定義。
Step5で作成したFunction Callingを模倣するプロンプトを利用して、2種のツール(Retriever)から適正なもの、もしくは使用しないという判断を選択させます。

def agent(state):
    print("---CALL AGENT---")
    messages = state["messages"]

    def predict(input):
        action_names = "|".join([t.name for t in tools]) + "|Final Answer"
        tool_regex = (
            r"""\{\n"""
            + rf"""    "action": "({action_names})",\n"""
            + r"""    "action_input": ".*"\n"""
            + r"""\}"""
        )

        state = _llm(input.text, regex=tool_regex)
        return json.loads(state["answer"])

    prompt = create_function_calling_prompt(tools)
    chain = prompt | RunnableLambda(predict)

    response = chain.invoke({"input": messages})

    response_message = AIMessage(
        content=response["action_input"],
    )
    if response["action"] != "Final Answer":
        # function callingを模した結果を出力
        response_message.additional_kwargs = {
            "function_call": {
                "name": response["action"],
                "arguments": json.dumps({"query": response["action_input"]}),
            }
        }

    # We return a list, because this will get added to the existing list
    return {"messages": [response_message]}

Toolノード

選択したツールを実行するノードです。
実行結果(今回は関連文書の検索結果)をFunctionMessageに詰めてグラフの状態に追加しています。

def retrieve(state):
    print("---EXECUTE RETRIEVAL---")
    messages = state["messages"]
    # Based on the continue condition
    # we know the last message involves a function call
    last_message = messages[-1]
    # We construct an ToolInvocation from the function_call
    action = ToolInvocation(
        tool=last_message.additional_kwargs["function_call"]["name"],
        tool_input=json.loads(
            last_message.additional_kwargs["function_call"]["arguments"]
        ),
    )
    # We call the tool_executor and get back a response
    response = tool_executor.invoke(action)
    function_message = FunctionMessage(content=str(response), name=action.tool)

    # We return a list, because this will get added to the existing list
    return {"messages": [function_message]}

Rewriteノード

検索文書がクエリに対して関連しないものだった場合、クエリをより適したものに言い替える処理です。
言い替えたクエリを使って、再びAgentノードから実行し直すようにグラフを構築します。

def rewrite(state):
    print("---TRANSFORM QUERY---")
    messages = state["messages"]
    question = messages[0].content

    def predict(input):
        state = _llm(input.text, ai_message_prefix="Formulate an improved question:", regex=None)
        return state["answer"]
    
    prompt = PromptTemplate(
        template=(
            "Look at the input and try to reason about the underlying semantic intent / meaning. \n\n"
            "Here is the initial question:"
            "\n ------- \n"
            "{question}"
        ),
        input_variables=["question"],
    )
    chain = prompt | RunnableLambda(predict)
    response = AIMessage(content=chain.invoke({"question":question}))

    return {"messages": [response]}

Generateノード

クエリおよび検索した関連文書を使って最終的な回答を出力するノードです。

def generate(state):
    print("---GENERATE---")
    messages = state["messages"]
    question = messages[0].content
    last_message = messages[-1]

    question = messages[0].content
    docs = last_message.content

    def predict(input):
        state = _llm(input.text, regex=None)
        return state["answer"]

    prompt = PromptTemplate(
        template=(
            "You are an assistant for question-answering tasks. Use the following pieces of retrieved context to answer the question. "
            "If you don't know the answer, just say that you don't know. Use three sentences maximum and keep the answer concise.\n"
            "You must reply in Japanese.\n"
            "Question: {question} \n"
            "Context: {context} "
        ),
        input_variables=["question", "context"],
    )
    chain = prompt | RunnableLambda(predict)

    # Run
    response = AIMessage(content=chain.invoke({"context": docs, "question": question}))
    return {"messages": [response]}

Edges

ここから条件分岐のためのエッジ処理定義です。

Should Retrieve条件分岐エッジ

ツール利用の判断に基づいて、処理を継続するのか、それともそのまま終了するかの条件分岐処理を定義します。
検索する必要性がない、となった時点でこのグラフは終了となります。

def should_retrieve(state):
    print("---DECIDE TO RETRIEVE---")
    messages = state["messages"]
    last_message = messages[-1]

    # If there is no function call, then we finish
    if "function_call" not in last_message.additional_kwargs:
        print("---DECISION: DO NOT RETRIEVE / DONE---")
        return "end"
    # Otherwise there is a function call, so we continue
    else:
        print("---DECISION: RETRIEVE---")
        return "continue"

Check Relevance条件分岐エッジ

検索して取得した関連文書が、ユーザクエリと関連があるかどうかを判定します。
関連している場合は最終回答の生成を、そうでない場合はクエリを書き替えて検索判断からやり直す流れになります。
関連判断もLLMを用いて実施しており、結果をJSONフォーマットになるように強制しています。

def grade_documents(state):
    print("---CHECK RELEVANCE---")

    # Data model
    class yesno(str, Enum):
        yes = "yes"
        no = "no"

    class grade(BaseModel):
        """Binary score for relevance check."""

        binary_score: yesno = Field(description="Relevance score 'yes' or 'no'")

    def judge_grade(input):
        state = _llm(input.text, regex=build_regex_from_object(grade))
        return grade.parse_raw(state["answer"])

    # Prompt
    prompt = PromptTemplate(
        template=(
            "You are a grader assessing relevance of a retrieved document to a user question. \n"
            "Here is the retrieved document: \n\n {context} \n\n"
            "Here is the user question: {question} \n"
            "If the document contains keyword(s) or semantic meaning related to the user question, grade it as relevant. \n"
            "Give a binary score 'yes' or 'no' score to indicate whether the retrieved document is relevant to the question."
        ),
        input_variables=["context", "question"],
    )

    # Chain
    chain = prompt | RunnableLambda(judge_grade)

    messages = state["messages"]
    last_message = messages[-1]

    question = messages[0].content
    docs = last_message.content
    score = chain.invoke({"question": question, "context": docs})

    grade = score.binary_score

    if grade == "yes":
        print("---DECISION: DOCS RELEVANT---")
        return "yes"

    else:
        print("---DECISION: DOCS NOT RELEVANT---")
        return "no"

Step8. グラフの構築

ノードとエッジの処理定義が終わりましたので、それらを用いてLangGraphのグラフを構成します。
上段で示した処理の流れを構成しています。

from langgraph.graph import END, StateGraph

# Define a new graph
workflow = StateGraph(AgentState)

# Define the nodes we will cycle between
workflow.add_node("agent", agent)
workflow.add_node("retrieve", retrieve)
workflow.add_node("rewrite", rewrite)
workflow.add_node("generate", generate)

# Call agent node to decide to retrieve or not
workflow.set_entry_point("agent")

# Decide whether to retrieve
workflow.add_conditional_edges(
    "agent",
    # Assess agent decision
    should_retrieve,
    {
        # Call tool node
        "continue": "retrieve",
        "end": END,
    },
)

# Edges taken after the `action` node is called.
workflow.add_conditional_edges(
    "retrieve",
    # Assess agent decision
    grade_documents,
    {
        "yes": "generate",
        "no": "rewrite",  
    },
)

# Simple edges
workflow.add_edge("rewrite", "agent")
workflow.add_edge("generate", END)

# Compile
app = workflow.compile()

長かった。。。ここまででRAG用のグラフが構築できました。

Step9. 推論

では、構築したグラフを使って推論してみましょう。

まず、葬送のフリーレンwikipediaに回答が含まれる質問を投げてみます。

最終結果だけでなく、LangGraphのstream機能を使って各ノードごとの状態も出力してみます。

from pprint import pprint

inputs = {
    "messages": [
        HumanMessage(
            content="フリーレンの弟子は誰?"
            # content="adsfabba"            
        )
    ]
}

for output in app.stream(inputs):
    for key, value in output.items():
        pprint(f"Output from node '{key}':")
        pprint("---")
        pprint(value, indent=2, width=80, depth=None)
    print("\n---\n")

print("Result:")
print(output["__end__"]["messages"][-1].content)
出力
---CALL AGENT---
"Output from node 'agent':"
'---'
{ 'messages': [ AIMessage(content='フリーレンの弟子', additional_kwargs={'function_call': {'name': 'retriever1', 'arguments': '{"query": "\\u30d5\\u30ea\\u30fc\\u30ec\\u30f3\\u306e\\u5f1f\\u5b50"}'}})]}

---

---DECIDE TO RETRIEVE---
---DECISION: RETRIEVE---
---EXECUTE RETRIEVAL---
"Output from node 'retrieve':"
'---'
{ 'messages': [ FunctionMessage(content='さらに20年後、フリーレンはもうひとりの仲間であったハイターを訪ねる。ヒンメルと同じく老い先短い身であったハイターは、魔導書の解読と戦災孤児の少女フェルンを弟子にすることを依頼。その4年後に魔導書の解読を終えたフリーレンと、一人前の魔法使いに成長したフェルンは、ハイターの最期を看取ったあとに諸国をめぐる旅に出る。\nその後フリーレンたちは、最後に訪ねた仲間であるアイゼンの助力を受けて、フリーレンの師匠にして伝説の大魔法使いフランメの手記を入手。その手記には、かつての魔王城があった大陸北端の地エンデにあるという、死者の魂と対話できる場所・オレオールの存在が記されていた。オレオールで亡きヒンメルと再会するという新たな目的ができたフリーレンは、アイゼンの弟子である少年戦士シュタルクや、行方不明の親友との再会を望む僧侶ザインという新たな仲間たちを加えて北方を目指す。', name='retriever1')]}

---

---CHECK RELEVANCE---
---DECISION: DOCS RELEVANT---
---GENERATE---
"Output from node 'generate':"
'---'
{ 'messages': [ AIMessage(content='フリーレンの弟子は、戦災孤児の少女フェルンです。フェルンは4年後に一人前の魔法使いになり、ハイターの最期を看取った後、弗里倫と旅に出ました。\n')]}

---

"Output from node '__end__':"
'---'
{ 'messages': [ HumanMessage(content='フリーレンの弟子は誰?'),
                AIMessage(content='フリーレンの弟子', additional_kwargs={'function_call': {'name': 'retriever1', 'arguments': '{"query": "\\u30d5\\u30ea\\u30fc\\u30ec\\u30f3\\u306e\\u5f1f\\u5b50"}'}}),
                FunctionMessage(content='さらに20年後、フリーレンはもうひとりの仲間であったハイターを訪ねる。ヒンメルと同じく老い先短い身であったハイターは、魔導書の解読と戦災孤児の少女フェルンを弟子にすることを依頼。その4年後に魔導書の解読を終えたフリーレンと、一人前の魔法使いに成長したフェルンは、ハイターの最期を看取ったあとに諸国をめぐる旅に出る。\nその後フリーレンたちは、最後に訪ねた仲間であるアイゼンの助力を受けて、フリーレンの師匠にして伝説の大魔法使いフランメの手記を入手。その手記には、かつての魔王城があった大陸北端の地エンデにあるという、死者の魂と対話できる場所・オレオールの存在が記されていた。オレオールで亡きヒンメルと再会するという新たな目的ができたフリーレンは、アイゼンの弟子である少年戦士シュタルクや、行方不明の親友との再会を望む僧侶ザインという新たな仲間たちを加えて北方を目指す。', name='retriever1'),
                AIMessage(content='フリーレンの弟子は、戦災孤児の少女フェルンです。フェルンは4年後に一人前の魔法使いになり、ハイターの最期を看取った後、弗里倫と旅に出ました。\n')]}

---

Result:
フリーレンの弟子は、戦災孤児の少女フェルンです。フェルンは4年後に一人前の魔法使いになり、ハイターの最期を看取った後、弗里倫と旅に出ました。

グラフの状態から、Agentノードで適切なRetriever(retriever1)が選ばれ、文書検索・回答生成を行っていることがわかります。


状態の表示出力を最終状態と結果だけにして、別のバリエーションも試してみます。

今度はRetriever2に回答が含まれる、Apache Sparkに関する質問です。

inputs = {
    "messages": [
        HumanMessage(
            content="SparkSQLとはどのような特徴をもった機能?"
        )
    ]
}

result = app.invoke(inputs)
print(result["messages"][-1].content)
出力
---CALL AGENT---
---DECIDE TO RETRIEVE---
---DECISION: RETRIEVE---
---EXECUTE RETRIEVAL---
---CHECK RELEVANCE---
---DECISION: DOCS RELEVANT---
---GENERATE---
-----
{'messages': [HumanMessage(content='SparkSQLとはどのような特徴をもった機能?'),
              AIMessage(content='SparkSQLの特徴', additional_kwargs={'function_call': {'name': 'retriever2', 'arguments': '{"query": "SparkSQL\\u306e\\u7279\\u5fb4"}'}}),
              FunctionMessage(content='=== Spark SQL ===\nSpark Coreより上位のコンポーネントで、構造化データや半構造化データをサポートするDataFramesというデータ抽象化を導入した。Scala、Java、PythonのDataFramesを操作するためのドメイン固有言語(DSL)を提供しており、キャラクタユーザインタフェースとOpen Database Connectivity/JDBCサーバとのSQL言語サポートも実装している。DataFramesには、Spark 2.0のようにRDDによって提供されるコンパイル時型チェック機能はないが、強く型付けされたデータセットはSpark SQLでも完全にサポートされている。', name='retriever2'),
              AIMessage(content='SparkSQLは、DataFramesというデータ抽象化を提供し、Scala、Java、Pythonのドメイン固有言語(DSL)を介して操作可能で、またキャラクタユーザインタフェースとSQL言語サポートも実装しています。RDDのコンパイル時型チェック機能はspark 2.0以前にないが、型付けされたデータセットは完全にサポートされています。\n')]}
-----
SparkSQLは、DataFramesというデータ抽象化を提供し、Scala、Java、Pythonのドメイン固有言語(DSL)を介して操作可能で、またキャラクタユーザインタフェースとSQL言語サポートも実装しています。RDDのコンパイル時型チェック機能はspark 2.0以前にないが、型付けされたデータセットは完全にサポートされています。

最終の状態から、retriever2という適切なツールが選ばれていることがわかります。
想定した挙動になっていますね。


では、意味不明なクエリを渡すとどうなるでしょうか。

inputs = {
    "messages": [
        HumanMessage(
            content="asefff?"
        )
    ]
}

result = app.invoke(inputs)
pprint(result)
print("-----")
print(result["messages"][-1].content)
出力
---CALL AGENT---
---DECIDE TO RETRIEVE---
---DECISION: DO NOT RETRIEVE / DONE---
{'messages': [HumanMessage(content='asefff?'), AIMessage(content="It seems like your query is incomplete or unclear. Could you please provide more context or specify what information you are looking for regarding 'asefff' or '葬送のフリーレン' or 'Apache Spark'?")]}
-----
It seems like your query is incomplete or unclear. Could you please provide more context or specify what information you are looking for regarding 'asefff' or '葬送のフリーレン' or 'Apache Spark'?

Retrieverの選択を行わないという判断となり、グラフの処理が終了しました。
※ 最終出力するメッセージを工夫すると、よりよい処理になりそうですね。


今回準備した文書に答えがないことを聞いてみます。

inputs = {
    "messages": [
        HumanMessage(
            content="Github Copilotとは?"
        )
    ]
}

result = app.invoke(inputs)
print("-----")
pprint(result)
print("-----")
print(result["messages"][-1].content)
出力
---CALL AGENT---
---DECIDE TO RETRIEVE---
---DECISION: RETRIEVE---
---EXECUTE RETRIEVAL---
---CHECK RELEVANCE---
---DECISION: DOCS NOT RELEVANT---
---GENERATE---
---TRANSFORM QUERY---
---CALL AGENT---
---DECIDE TO RETRIEVE---
---DECISION: DO NOT RETRIEVE / DONE---
-----
{'messages': [HumanMessage(content='Github Copilotとは?'),
              AIMessage(content='Github Copilot', additional_kwargs={'function_call': {'name': 'retriever2', 'arguments': '{"query": "Github Copilot"}'}}),
              FunctionMessage(content='RDDの可用性は、ループ内で複数回データセットを参照する反復法アルゴリズム、および対話型/探索型データ分析、データ反復のデータベースクエリの両方の実装を容易にする。このようなアプリケーションのレイテンシ(Apache Hadoopスタックでは一般的であったMapReduce実装と比較して)は、桁違いに低下する可能性がある。反復アルゴリズムのクラスの中には、 機械学習のための訓練アルゴリズムがあり、Apache Sparkを開発の初期の刺激となった。', name='retriever2'),
              AIMessage(content=' \n\nWhat is Github Copilot and how does it work?\n'),
              AIMessage(content='Github Copilotは、Apache Sparkやそのようなデータ分析のためのツールをサポートするツールであり、特にデータ反復や対話型データ解析のための効率的な実装を提供します。機械学習のアルゴリズムも含むこのアシスタントは、開発者にコード生成や自動補完の助けを提供することを目的としています。\n'),
              AIMessage(content='Github Copilotは、GitHubが開発したコード自動補完ツールです。このツールは、プログラミング中にコードを自動的に提案し、入力を補完することで、開発者のタスクを効率化します。特に、Apache Sparkやデータ分析のフレームワークで、データ反復や対話型データ解析のためのコードの生成や最適なパターンの提案が可能です。機械学習アルゴリズムの一部もサポートし、コードの自動完成や自動完結を支援します。')]}
-----
Github Copilotは、GitHubが開発したコード自動補完ツールです。このツールは、プログラミング中にコードを自動的に提案し、入力を補完することで、開発者のタスクを効率化します。特に、Apache Sparkやデータ分析のフレームワークで、データ反復や対話型データ解析のためのコードの生成や最適なパターンの提案が可能です。機械学習アルゴリズムの一部もサポートし、コードの自動完成や自動完結を支援します。

わかりづらいですが、Retriever(retriever2)を使う判断を最初に行ったのですが、取得文書とクエリの関連性が低いということでRewriteノードを実行→Agentノードを再実行という流れになっています。
2周目では、結局Retrieverを使わないという判断となり、LLMが単体で考え生成した回答を得ています。


同様に文書に答えがない別の質問を聞いてみます。

inputs = {
    "messages": [
        HumanMessage(
            content="PowerBIとは?"
        )
    ]
}

result = app.invoke(inputs)
print("-----")
pprint(result)
print("-----")
print(result["messages"][-1].content)
出力
---CALL AGENT---
---DECIDE TO RETRIEVE---
---DECISION: DO NOT RETRIEVE / DONE---
-----
{'messages': [HumanMessage(content='PowerBIとは?'),
              AIMessage(content='PowerBIは、Microsoftによって提供されるビジネスインテリジェンスとデータ分析のプラットフォームです。ユーザーは、データを視覚化し、分析し、レポートを作成し、ビジネスの洞察を得るためにデータを可視化することができます。PowerBIは、オンラインのサービス(Power BI Desktop、Power BI Service)とオフラインのツール(Power BI Desktop)の両方を提供しており、データの整合、分析、および共有が可能です。')]}
-----
PowerBIは、Microsoftによって提供されるビジネスインテリジェンスとデータ分析のプラットフォームです。ユーザーは、データを視覚化し、分析し、レポートを作成し、ビジネスの洞察を得るためにデータを可視化することができます。PowerBIは、オンラインのサービス(Power BI Desktop、Power BI Service)とオフラインのツール(Power BI Desktop)の両方を提供しており、データの整合、分析、および共有が可能です。

今度は最初からツールの使用をしない判断となり、LLM単体で回答されました。
Qwen1.5はPowerBIをよく知っている模様。

まとめ

エージェントを使ったRAG処理をLangGraph+ローカルLLMで実装してみました。

かなり複雑ですが、応用すればRetrieverの使い分けをエージェントで判断させるなど、より柔軟かつ高品質なRAGのフローを構築することができそうです。
CRAGSelf-RAGとも組み合わせると、かなり高品質なパイプラインを構築できると思います。
RewriteのところはStep Back プロンプトを使うのもよさそう。

しかし、エージェント周りをローカルLLMでやるのはそれなりにしんどいですね。
SGLangのようなJSON形式を強制できるような仕組を備えていないと難しいものがあります。
一方で、function callingに近いものがローカルLLMでも出来るようになってきているということでもあります。

この領域の発展を期待しています。

4
3
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
4
3