導入
LangGraphのHow-to Guideウォークスルーの4回目です。
今回は、こちらの内容である「グラフに永続性(メモリ)を追加する方法」をウォークスルーしてみます。
検証はDatabricks on AWS、DBRは15.3MLを使っています。
永続性の追加について
LangGraphで対話アプリなどを構築する際、一度の会話が終了しても会話履歴を保持して次の会話を再開したいときがあります。
LangChainではMemory機能が提供されていますが、LangGraphではどうでしょうか。
以下、公式ドキュメントの序文を邦訳。
多くの AI アプリケーションでは、複数の対話間でコンテキストを共有するためにメモリが必要です。LangGraphでは、メモリはCheckpointerを介してStateGraphに提供されます。
LangGraphワークフローを作成する際、以下を使用して、その状態を永続化するように設定することができます。
- チェックポイント ツール (AsyncSqliteSaver など)
- グラフのコンパイル時に
compile(checkpointer=my_checkpointer)
を呼び出します。例:
from langgraph.graph import StateGraph from langgraph.checkpoint.aiosqlite import AsyncSqliteSaver builder = StateGraph(....) # ... define the graph memory = AsyncSqliteSaver.from_conn_string(":memory:") graph = builder.compile(checkpointer=memory) ...
これは、StateGraph とそのすべてのサブクラス (MessageGraph など) で機能します。
というわけで、LangGraphはCheckpointという概念を導入しており、それによってMemory機能を提供しています。
では、実際に公式ドキュメントのコードをウォークスルーしてみます。
(一部、簡易化しています)
Step1. パッケージインストール
LangGraphやLangChainなど、必要なパッケージをインストール。
%pip install -U langgraph==0.1.4 langchain==0.2.6 langchain-community==0.2.6 mlflow-skinny[databricks]==2.14.1 pydantic==2.7.4
dbutils.library.restartPython()
Step2. グラフへの永続性追加
サンプルに沿って永続化するためのグラフを作成します。
まず、メッセージのリストを状態として保持するクラスを定義します。
import mlflow
from typing import Annotated
from typing_extensions import TypedDict
from langgraph.graph.message import add_messages
# 状態
class State(TypedDict):
messages: Annotated[list, add_messages]
使用するツールを定義します。
公式ドキュメントの例では、固定文字列のリストを返すダミー処理をツールとして定義しています。
カスタムツールの作成についてはこちらを参照してください。
from langchain_core.tools import tool
from langgraph.prebuilt import ToolNode
@tool
@mlflow.trace(span_type="tool")
def search(query: str):
"""Web検索"""
# 実際には、ここに実装をする。今回はダミー処理とする。
return ["あなたの質問の答えはここにあります。"]
tools = [search]
tool_node = ToolNode(tools)
次に、LLMのエンドポイントを準備します。
以前の記事で作成したDatabricks Model Servingのエンドポイントを流用します。
実際のサンプルでは、bound_model = model.bind_tools(tools)
を実行してLangChainのツールバインド機能を使っていますが、ChatDatabricks
クラスはツールバインドをサポートしていないため、今回は実行しません。
(結果として、上のToolNodeは実行されなくなります)
from langchain_community.chat_models import ChatDatabricks
endpoint_name = "mistral-7b-instruct-v03-endpoint "
model = ChatDatabricks(endpoint=endpoint_name, temperature=0.1)
次にグラフを定義します。
公式ドキュメントの説明内容を邦訳すると、以下のとおり。
次に、グラフでいくつかの異なるノードを定義する必要があります。 LangGraphでは、ノードは関数またはrunnableオブジェクトのいずれかになります。 これには、2つの主要なノードが必要です。
- エージェント:取るべきアクション(ある場合)を決定する責任があります。
- ツールを呼び出す関数: エージェントがアクションを実行することを決定した場合、このノードはそのアクションを実行します。
また、いくつかのエッジを定義する必要があります。 これらのエッジの一部は条件付きである可能性があります。 これらが条件付きである理由は、ノードの出力に基づいて、いくつかのパスのうちの 1 つが採用される可能性があるためです。 たどられたパスは、そのノードが実行されるまで(LLMが決定するまで)わかりません。
- 条件付きエッジ:エージェントが呼び出された後、次のいずれかを行う必要があります。
a エージェントがアクションを実行するように指示した場合は、ツールを呼び出す関数を呼び出す必要があります
b エージェントが終了したと言った場合は、終了する必要があります- 通常のエッジ:ツールが呼び出された後、常にエージェントに戻り、次に何をすべきかを決定する必要があります
ノードと、どの条件付きエッジを取るかを決定する関数を定義しましょう。
では、説明のようにノードと条件付きエッジを定義します。
# 続行するかどうかを決定する関数を定義
from typing import Literal
@mlflow.trace(span_type="edge")
def should_continue(state: State) -> Literal["action", "__end__"]:
"""次に実行するノードを返す"""
last_message = state["messages"][-1]
# 関数呼び出しがない場合、終了する
if not last_message.tool_calls:
return "__end__"
# それ以外の場合は続行する
return "action"
# モデルを呼び出す関数を定義
@mlflow.trace(span_type="node")
def call_model(state: State):
response = model.invoke(state["messages"])
# 既存のリストに追加されるため、リストを返す
return {"messages": response}
それでは、これまで定義した内容を使ってグラフを作成します。
from langgraph.graph import StateGraph
# 新しいグラフを定義
workflow = StateGraph(State)
# サイクルする2つのノードを定義
workflow.add_node("agent", call_model)
workflow.add_node("action", tool_node)
# エントリーポイントを `agent` に設定
# これはこのノードが最初に呼び出されることを意味する
workflow.set_entry_point("agent")
# 条件付きエッジを追加
workflow.add_conditional_edges(
# まず、開始ノードを定義する。ここでは `agent` を使用。
# これは `agent` ノードが呼び出された後に取られるエッジを意味する。
"agent",
# 次に、次に呼び出されるノードを決定する関数を渡す。
should_continue,
)
# `tools` から `agent` への通常のエッジを追加
# これは `tools` が呼び出された後に `agent` ノードが次に呼び出されることを意味する
workflow.add_edge("action", "agent")
さて、今回のポイントである、永続性を追加するためのCheckpointを作成します。
from langgraph.checkpoint.sqlite import SqliteSaver
memory = SqliteSaver.from_conn_string(":memory:")
Checkpointを含めてグラフをコンパイルします。
# 最後に、これをコンパイルします!
# これは LangChain Runnable にコンパイルされます。
# つまり、他のrunnableと同様に使用できます。
app = workflow.compile(checkpointer=memory)
コンパイルしたグラフを可視化します。
from IPython.display import Image, display
try:
display(Image(app.get_graph().draw_mermaid_png()))
except Exception:
# これはいくつかの追加の依存関係を必要とし、オプションです
pass
問題なくグラフが出来ていそうです。
Step3. 実行する
エージェント(グラフ)を実行し、以前のメッセージを記憶していることを確認してみます。
まずは、Bobという名前で自己紹介します。
from langchain_core.messages import HumanMessage
# thread_id:2として記憶させる
config = {"configurable": {"thread_id": "2"}}
input_message = HumanMessage(content="hi! I'm bob")
with mlflow.start_span("graph", span_type="AGENT") as span:
for event in app.stream(
{"messages": [input_message]}, config, stream_mode="values"
):
event["messages"][-1].pretty_print()
================================ Human Message =================================
hi! I'm bob
================================== Ai Message ==================================
Hello Bob! It's nice to meet you again. How can I assist you today?
また、MLflow Tracingも設定しているため、ノードの実行状態をUI上で追えます。
同じスレッドID設定を指定して実行すると、前の会話状態を記憶したまま、次の会話を返します。
input_message = HumanMessage(content="what is my name?")
with mlflow.start_span("graph", span_type="AGENT") as span:
for event in app.stream(
{"messages": [input_message]}, config, stream_mode="values"
):
event["messages"][-1].pretty_print()
================================ Human Message =================================
what is my name?
================================== Ai Message ==================================
Your name is Bob, as you mentioned earlier. How can I help you today, Bob?
スレッドIDを変更すると、新たな会話がスタートします。
input_message = HumanMessage(content="what is my name?")
with mlflow.start_span("graph", span_type="AGENT") as span:
for event in app.stream(
{"messages": [input_message]},
{"configurable": {"thread_id": "3"}},
stream_mode="values",
):
event["messages"][-1].pretty_print()
================================ Human Message =================================
what is my name?
================================== Ai Message ==================================
I apologize for any confusion. As I mentioned earlier, I don't have personal knowledge or access to personal data. I'm just a model and don't have the ability to know or remember specific names. If you'd like, we can create a character or persona for our conversation. What would you like to be called?
Bobという情報を保持していないため、名前を答えれません。
再び前のスレッドIDを使うと、会話を再開できます。
input_message = HumanMessage(content="You forgot??what is my name?")
with mlflow.start_span("graph", span_type="AGENT") as span:
for event in app.stream(
{"messages": [input_message]},
{"configurable": {"thread_id": "2"}},
stream_mode="values",
):
event["messages"][-1].pretty_print()
================================ Human Message =================================
You forgot??what is my name?
================================== Ai Message ==================================
I apologize for the mistake. Your name is Bob, as you mentioned earlier. How can I help you today, Bob?
まとめ
グラフに永続性を追加する処理をウォークスルーしてみました。
永続性のキモとなるCheckpointについては何種類かの実装があるようで、どのように使い分けるかを考える必要がありそうです。
チャットアプリのAPIサーバを構築する際など、状態を永続化する必要があると思うので、このあたり必須機能となりそうですね。