0
0
生成AIに関する記事を書こう!
Qiita Engineer Festa20242024年7月17日まで開催中!

DatabricksとLangGraphで学ぶAgenticアプローチ: 会話履歴の管理

Posted at

導入

LangGraphのHow-to Guideウォークスルーの5回目です。

今回は、こちらの内容である「会話履歴を管理する方法」をウォークスルーしてみます。

検証はDatabricks on AWS、DBRは15.3MLを使っています。

会話履歴の管理

前回、エージェントに状態の永続性を追加する方法を試してみました。永続性の具体的なユースケースとしては会話履歴を管理することだと思います。

以下、公式ドキュメントの序文を邦訳。

会話履歴を管理する方法

永続化の最も一般的な使用例の 1 つは、会話履歴を追跡するために使用することです。
これは素晴らしいことであり、会話を続けるのが簡単になります。
ただし、会話が長くなるにつれて、この会話履歴が蓄積され、コンテキスト ウィンドウをますます占有する可能性があります。
これは、LLMへのより高価で長い呼び出しにつながり、エラーが発生する可能性があるため、多くの場合、望ましくない場合があります。
このノートブックでは、これに対処する方法に関するいくつかの戦略について説明します。

というわけで、より実践的な会話履歴を管理する方法です。公式ドキュメントのコードをウォークスルーしてみましょう。

Step1. パッケージインストール

LangGraphやLangChainなど、必要なパッケージをインストール。

%pip install -U langgraph==0.1.4 langchain==0.2.6 langchain-community==0.2.6 mlflow-skinny[databricks]==2.14.1 pydantic==2.7.4
dbutils.library.restartPython()

Step2. グラフの構築

まずはサンプルに沿って永続化するためのグラフを作成します。
シンプルなReActスタイルのエージェントを構築します。

モデルは以前の記事で作成したDatabricks Model Servingのエンドポイントを流用します。

from typing import Literal

from langchain_community.chat_models import ChatDatabricks
from langchain_core.tools import tool

from langgraph.checkpoint.sqlite import SqliteSaver
from langgraph.graph import MessagesState, StateGraph, START
from langgraph.prebuilt import ToolNode

import mlflow

memory = SqliteSaver.from_conn_string(":memory:")


@tool
@mlflow.trace(span_type="tool")
def search(query: str):
    """Web検索を実行します。"""
    # これは実際の実装のためのプレースホルダーです
    # ただし、LLMにはこのことを知らせないでください 😊
    return ["サンフランシスコは晴れですが、あなたがGeminiなら気をつけてください 😈."]


tools = [search]
tool_node = ToolNode(tools)
endpoint_name = "mistral-7b-instruct-v03-endpoint "
model = ChatDatabricks(endpoint=endpoint_name, temperature=0.1)
# bound_model = model.bind_tools(tools) # ツールバインドは実行できない

@mlflow.trace(span_type="edge")
def should_continue(state: MessagesState) -> Literal["action", "__end__"]:
    """次に実行するノードを返します。"""
    last_message = state["messages"][-1]
    # 関数呼び出しがない場合、終了します
    if not last_message.tool_calls:
        return "__end__"
    # それ以外の場合は続行します
    return "action"


# モデルを呼び出す関数を定義します
@mlflow.trace(span_type="node")
def call_model(state: MessagesState):
    response = model.invoke(state["messages"])
    # 既存のリストに追加されるため、リストを返します
    return {"messages": response}


# 新しいグラフを定義します
workflow = StateGraph(MessagesState)

# サイクルする2つのノードを定義します
workflow.add_node("agent", call_model)
workflow.add_node("action", tool_node)

# エントリーポイントを `agent` に設定します
# これは、このノードが最初に呼び出されることを意味します
workflow.add_edge(START, "agent")

# 条件付きエッジを追加します
workflow.add_conditional_edges(
    # まず、開始ノードを定義します。`agent` を使用します。
    # これは、`agent` ノードが呼び出された後に取られるエッジを意味します。
    "agent",
    # 次に、次に呼び出されるノードを決定する関数を渡します。
    should_continue,
)

# `tools` から `agent` への通常のエッジを追加します。
# これは、`tools` が呼び出された後に `agent` ノードが次に呼び出されることを意味します。
workflow.add_edge("action", "agent")

# 最後に、これをコンパイルします!
# これをLangChain Runnableにコンパイルします。
# つまり、他のランナブルと同様に使用できます
app = workflow.compile(checkpointer=memory)

構築したエージェント処理を実行してみます。

from langchain_core.messages import HumanMessage

# thread_id:2として記憶させる
config = {"configurable": {"thread_id": "2"}}
input_message = HumanMessage(content="hi! I'm bob")
with mlflow.start_span("graph", span_type="AGENT") as span:
    for event in app.stream(
        {"messages": [input_message]}, config, stream_mode="values"
    ):
        event["messages"][-1].pretty_print()


input_message = HumanMessage(content="what's my name?")
with mlflow.start_span("graph", span_type="AGENT") as span:
    for event in app.stream(
        {"messages": [input_message]}, config, stream_mode="values"
    ):
        event["messages"][-1].pretty_print()
出力
================================ Human Message =================================

hi! I'm bob
================================== Ai Message ==================================

 Hello Bob! How can I help you today? Is there something specific you'd like to talk about or ask me questions on? I'm here to help with a wide range of topics, from answering questions to providing information and even engaging in friendly conversation. What's on your mind?
================================ Human Message =================================

what's my name?
================================== Ai Message ==================================

 Your name is Bob, as you mentioned earlier. How can I assist you today, Bob? Is there something specific you'd like to know or discuss? I'm here to help with a wide range of topics, from answering questions to providing information and even engaging in friendly conversation. What's on your mind?

MLflow Tracingの出力は以下のようになります。

image.png

Step3. メッセージのフィルタリング

続けての処理について、公式ドキュメントの内容は以下のように記載されています。

会話履歴が大量に失われるのを防ぐ最も簡単な方法は、LLMに渡される前にメッセージのリストをフィルタリングすることです。
これには、メッセージをフィルター処理する関数の定義と、それをグラフに追加するという 2 つの部分が含まれます。
以下の例では、非常に単純なfilter_messages関数を定義して使用しています。

というわけで、前に定義したグラフに対して、メッセージをフィルタリングする機能をつけて再定義します。

from typing import Literal

from langchain_community.chat_models import ChatDatabricks
from langchain_core.tools import tool

from langgraph.checkpoint.sqlite import SqliteSaver
from langgraph.graph import MessagesState, StateGraph, START
from langgraph.prebuilt import ToolNode

memory = SqliteSaver.from_conn_string(":memory:")


@tool
@mlflow.trace(span_type="tool")
def search(query: str):
    """Web検索を実行します。"""
    # これは実際の実装のためのプレースホルダーです
    # ただし、LLMにはこのことを知らせないでください 😊
    return ["サンフランシスコは晴れですが、あなたがGeminiなら気をつけてください 😈."]


tools = [search]
tool_node = ToolNode(tools)
endpoint_name = "mistral-7b-instruct-v03-endpoint "
model = ChatDatabricks(endpoint=endpoint_name, temperature=0.1)
# bound_model = model.bind_tools(tools) # ツールバインドは実行できない

@mlflow.trace(span_type="edge")
def should_continue(state: MessagesState) -> Literal["action", "__end__"]:
    """次に実行するノードを返します。"""
    last_message = state["messages"][-1]
    # 関数呼び出しがない場合、終了します
    if not last_message.tool_calls:
        return "__end__"
    # それ以外の場合は続行します
    return "action"


def filter_messages(messages: list):
    # これは非常にシンプルなヘルパー関数で、最後の2つのメッセージのみを使用します
    return messages[-1:]


# モデルを呼び出す関数を定義します
@mlflow.trace(span_type="node")
def call_model(state: MessagesState):
    messages = filter_messages(state["messages"])
    response = model.invoke(messages)
    # 既存のリストに追加されるため、リストを返します
    return {"messages": response}


# 新しいグラフを定義します
workflow = StateGraph(MessagesState)

# サイクルする2つのノードを定義します
workflow.add_node("agent", call_model)
workflow.add_node("action", tool_node)

# エントリーポイントを `agent` に設定します
# これは、このノードが最初に呼び出されることを意味します
workflow.add_edge(START, "agent")

# 条件付きエッジを追加します
workflow.add_conditional_edges(
    # まず、開始ノードを定義します。`agent` を使用します。
    # これは、`agent` ノードが呼び出された後に取られるエッジを意味します。
    "agent",
    # 次に、次に呼び出されるノードを決定する関数を渡します。
    should_continue,
)

# `tools` から `agent` への通常のエッジを追加します。
# これは、`tools` が呼び出された後に `agent` ノードが次に呼び出されることを意味します。
workflow.add_edge("action", "agent")

# 最後に、これをコンパイルします!
# これをLangChain Runnableにコンパイルします。
# つまり、他のランナブルと同様に使用できます
app = workflow.compile(checkpointer=memory)

実行。

from langchain_core.messages import HumanMessage

# thread_id:2として記憶させる
config = {"configurable": {"thread_id": "2"}}
input_message = HumanMessage(content="hi! I'm bob")
with mlflow.start_span("graph", span_type="AGENT") as span:
    for event in app.stream(
        {"messages": [input_message]}, config, stream_mode="values"
    ):
        event["messages"][-1].pretty_print()


input_message = HumanMessage(content="what's my name?")
with mlflow.start_span("graph", span_type="AGENT") as span:
    for event in app.stream(
        {"messages": [input_message]}, config, stream_mode="values"
    ):
        event["messages"][-1].pretty_print()
出力
================================ Human Message =================================

hi! I'm bob
================================== Ai Message ==================================

 Hello Bob! How can I help you today? Is there something specific you'd like to talk about or learn more about? I'm here to assist you with any questions you might have.
================================ Human Message =================================

what's my name?
================================== Ai Message ==================================

 I'm sorry for any confusion, but I don't have personal knowledge or access to personal data. I was designed to assist and provide information, not to know individual names. If you'd like, we can discuss other topics instead!

結果からわかるように、会話履歴を2個分しか保持しないため、2回目の会話ではBobという名前の事を忘れていました。

まとめ

グラフ永続性の代表的なユースケースである会話履歴管理の処理をウォークスルーしてみました。
メッセージ履歴にフィルタを入れることで、LLMに渡す履歴量をコントロールすることができます。

LangChainでも同様の機能はありますが、個人的にLangGraphの方がより自然な管理ができるんじゃないかと思います。

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0