導入
LangGraphのHow-to Guideウォークスルーの9回目です。
今回は、こちらの内容である「過去のグラフの状態を表示および更新する方法」をウォークスルーしてみます。
検証はDatabricks on AWS、DBRは15.3MLを使っています。
グラフ状態履歴の操作とは
以下、公式ドキュメントの序文を邦訳。
過去のグラフの状態を表示および更新する方法
グラフのチェックポイントを開始すると、いつでもエージェントの状態を簡単に取得または更新できます。これにより、いくつかのことが可能になります。
- 割り込み中に状態をユーザーに表示して、ユーザーがアクションを受け入れるようにすることができます。
- グラフを巻き戻して、問題を再現または回避できます。
- 状態を変更して、エージェントをより大きなシステムに埋め込んだり、ユーザーがそのアクションをより適切に制御できるようにしたりできます。
この機能に使用される主なメソッドは次のとおりです。
- get_state: ターゲットコンフィグから値を取得する
- update_state: 指定された値をターゲットの状態に適用します。
注: これには、チェックポイントを渡す必要があります。
グラフの実行状態履歴に対して、タイムトラベルで様々な操作ができる機能のイメージです。
では、公式ドキュメントのコードをウォークスルーしてみましょう。
Step1. パッケージインストール
LangGraphやLangChainなど、必要なパッケージをインストール。
%pip install -U langgraph==0.1.4 langchain==0.2.6 langchain-community==0.2.6 mlflow-skinny[databricks]==2.14.1 pydantic==2.7.4
dbutils.library.restartPython()
Step2. エージェント(グラフ)の構築
シンプルなReActスタイルのエージェントを構築します。
モデルは以前の記事で作成したDatabricks Model Servingのエンドポイントを流用します。
import mlflow
from langchain_community.chat_models import ChatDatabricks
# 状態を設定
from langchain_core.messages import AIMessage, HumanMessage
from langgraph.graph import MessagesState, START
from langgraph.graph import END, StateGraph
from langgraph.checkpoint.memory import MemorySaver
from langchain_core.tools import tool
from langchain_core.pydantic_v1 import BaseModel
@tool
def search(query: str):
"""Web検索を実行します。"""
return "サンフランシスコは晴れです"
tools = [search]
tools_by_name = {tool.name: tool for tool in tools}
# Tool呼び出し用ノード。HumanMessageで結果を返す
@mlflow.trace(span_type="node")
def tool_node(state: dict):
result = []
for tool_call in state["messages"][-1].tool_calls:
tool = tools_by_name[tool_call["name"]]
observation = tool.invoke(tool_call["args"])
result.append(HumanMessage(content=observation, tool_call_id=tool_call["id"]))
return {"messages": result}
# モデルを設定
endpoint_name = "mistral-7b-instruct-v03-endpoint"
model = ChatDatabricks(endpoint=endpoint_name, temperature=0.1)
# 続行するかどうかを決定する関数を定義
@mlflow.trace(span_type="edge")
def should_continue(state):
messages = state["messages"]
last_message = messages[-1]
# 関数呼び出しがない場合、終了します
if not last_message.tool_calls:
return "end"
# それ以外の場合は続行します
else:
return "continue"
# モデルを呼び出す関数を定義
@mlflow.trace(span_type="node")
def call_model(state):
messages = state["messages"]
response = None
# 修正:Tool Call未対応のため、メッセージ件数が少ないときは強制的にツール呼び出しをする
if len(messages) < 2:
response = AIMessage(content="サンフランシスコの天気を調べます")
response.tool_calls = [
{"id": "1111", "name": "search", "args": {"query": "サンフランシスコ"}}
]
else:
response = model.invoke(messages)
# 既存のリストに追加するリストを返します
return {"messages": [response]}
# グラフを構築
# 新しいグラフを定義
workflow = StateGraph(MessagesState)
# サイクルする3つのノードを定義
workflow.add_node("agent", call_model)
workflow.add_node("action", tool_node)
# エントリーポイントを `agent` に設定
# これは、このノードが最初に呼び出されることを意味します
workflow.add_edge(START, "agent")
# 条件付きエッジを追加
workflow.add_conditional_edges(
# まず、開始ノードを定義します。`agent` を使用します。
# これは、`agent` ノードが呼び出された後に取られるエッジを意味します。
"agent",
# 次に、次に呼び出されるノードを決定する関数を渡します。
should_continue,
# 最後にマッピングを渡します。
# キーは文字列で、値は他のノードです。
# ENDはグラフが終了することを示す特別なノードです。
# これにより、`should_continue` が呼び出され、その出力がこのマッピングのキーと一致します。
# 一致したキーに基づいて、そのノードが次に呼び出されます。
{
# `tools` の場合、ツールノードを呼び出します。
"continue": "action",
# それ以外の場合は終了します。
"end": END,
},
)
# `tools` から `agent` への通常のエッジを追加
# これは、`tools` が呼び出された後に `agent` ノードが次に呼び出されることを意味します。
workflow.add_edge("action", "agent")
# メモリを設定
memory = MemorySaver()
# 最後に、これをコンパイルします!
# これをLangChain Runnableにコンパイルします。
# つまり、他のランナブルと同様に使用できます
# `ask_human` ノードが実行される前にブレークポイントを追加します
app = workflow.compile(checkpointer=memory)
グラフを可視化すると以下のようになります。
from IPython.display import Image, display
try:
display(Image(app.get_graph().draw_mermaid_png()))
except Exception:
# これはいくつかの追加の依存関係を必要とし、オプションです
pass
Step3. エージェントとの対話と過去履歴の確認・更新
まずは構築したエージェントを使って対話してみます。
from langchain_core.messages import HumanMessage
config = {"configurable": {"thread_id": "1"}}
input_message = HumanMessage(
content="Use the search tool to look up the weather in SF"
)
with mlflow.start_span("graph", span_type="AGENT") as span:
for event in app.stream({"messages": [input_message]}, config, stream_mode="values"):
event["messages"][-1].pretty_print()
================================ Human Message =================================
Use the search tool to look up the weather in SF
================================== Ai Message ==================================
サンフランシスコの天気を調べます
Tool Calls:
search (1111)
Call ID: 1111
Args:
query: サンフランシスコ
================================ Human Message =================================
サンフランシスコは晴れです
================================== Ai Message ==================================
San Francisco is sunny today.
MLflow Tracingのログは以下のようになります。
では、LangGraphの機能を使って実行した履歴を確認してみましょう。
all_states = []
for state in app.get_state_history(config):
print(state)
all_states.append(state)
print("--")
StateSnapshot(values={'messages': []}, next=('__start__',), config={'configurable': {'thread_id': '1', 'thread_ts': '1ef3c319-9179-6760-bfff-50781c91081e'}}, metadata={'source': 'input', 'step': -1, 'writes': {'messages': [HumanMessage(content='Use the search tool to look up the weather in SF')]}}, created_at='', parent_config=None)
--
StateSnapshot(values={'messages': [HumanMessage(content='Use the search tool to look up the weather in SF', id='aa70a81d-87fb-4049-bf87-4b2b62e49288')]}, next=('agent',), config={'configurable': {'thread_id': '1', 'thread_ts': '1ef3c319-917c-65fe-8000-066b3fb9caeb'}}, metadata={'source': 'loop', 'step': 0, 'writes': None}, created_at='', parent_config=None)
--
StateSnapshot(values={'messages': [HumanMessage(content='Use the search tool to look up the weather in SF', id='aa70a81d-87fb-4049-bf87-4b2b62e49288'), AIMessage(content='サンフランシスコの天気を調べます', id='89e37a2b-df1c-458a-9e53-b6c0e91ea124', tool_calls=[{'name': 'search', 'args': {'query': 'サンフランシスコ'}, 'id': '1111'}])]}, next=('action',), config={'configurable': {'thread_id': '1', 'thread_ts': '1ef3c319-9184-6129-8001-69f4ebd5d134'}}, metadata={'source': 'loop', 'step': 1, 'writes': {'agent': {'messages': [AIMessage(content='サンフランシスコの天気を調べます', id='89e37a2b-df1c-458a-9e53-b6c0e91ea124', tool_calls=[{'name': 'search', 'args': {'query': 'サンフランシスコ'}, 'id': '1111'}])]}}}, created_at='', parent_config=None)
--
StateSnapshot(values={'messages': [HumanMessage(content='Use the search tool to look up the weather in SF', id='aa70a81d-87fb-4049-bf87-4b2b62e49288'), AIMessage(content='サンフランシスコの天気を調べます', id='89e37a2b-df1c-458a-9e53-b6c0e91ea124', tool_calls=[{'name': 'search', 'args': {'query': 'サンフランシスコ'}, 'id': '1111'}]), HumanMessage(content='サンフランシスコは晴れです', id='aa8300f6-702a-4f7f-910e-70b28f3ab17b')]}, next=('agent',), config={'configurable': {'thread_id': '1', 'thread_ts': '1ef3c319-9188-64bb-8002-08e71a2edf8f'}}, metadata={'source': 'loop', 'step': 2, 'writes': {'action': {'messages': [HumanMessage(content='サンフランシスコは晴れです', id='aa8300f6-702a-4f7f-910e-70b28f3ab17b')]}}}, created_at='', parent_config=None)
--
StateSnapshot(values={'messages': [HumanMessage(content='Use the search tool to look up the weather in SF', id='aa70a81d-87fb-4049-bf87-4b2b62e49288'), AIMessage(content='サンフランシスコの天気を調べます', id='89e37a2b-df1c-458a-9e53-b6c0e91ea124', tool_calls=[{'name': 'search', 'args': {'query': 'サンフランシスコ'}, 'id': '1111'}]), HumanMessage(content='サンフランシスコは晴れです', id='aa8300f6-702a-4f7f-910e-70b28f3ab17b'), AIMessage(content=' San Francisco is sunny today.', response_metadata={'prompt_tokens': 53, 'completion_tokens': 8, 'total_tokens': 61}, id='run-abe2d61a-ffa5-4349-966d-14c9799bb075-0')]}, next=(), config={'configurable': {'thread_id': '1', 'thread_ts': '1ef3c319-9b04-6be6-8003-ee86f0b173c8'}}, metadata={'source': 'loop', 'step': 3, 'writes': {'agent': {'messages': [AIMessage(content=' San Francisco is sunny today.', response_metadata={'prompt_tokens': 53, 'completion_tokens': 8, 'total_tokens': 61}, id='run-abe2d61a-ffa5-4349-966d-14c9799bb075-0')]}}}, created_at='', parent_config=None)
--
これらの状態履歴から途中時点に戻り、そこからエージェントを再実行することができます。
search
ツールの呼び出しが実行される直前に戻りってみましょう。
# 最初の2個までの状態を取得して表示
to_replay = all_states[2]
to_replay.values
{'messages': [HumanMessage(content='Use the search tool to look up the weather in SF', id='aa70a81d-87fb-4049-bf87-4b2b62e49288'),
AIMessage(content='サンフランシスコの天気を調べます', id='89e37a2b-df1c-458a-9e53-b6c0e91ea124', tool_calls=[{'name': 'search', 'args': {'query': 'サンフランシスコ'}, 'id': '1111'}])]}
次に実行される(た)ノードの情報も取得できます。
# 次のノードを確認
to_replay.next
('action',)
以前の状態からエージェントを再度実行するには、config
をエージェントに渡して実行します。
with mlflow.start_span("graph", span_type="AGENT") as span:
for event in app.stream(None, to_replay.config):
for v in event.values():
print(v)
{'messages': [HumanMessage(content='サンフランシスコは晴れです', tool_call_id='1111')]}
{'messages': [AIMessage(content=' San Francisco is sunny today.', response_metadata={'prompt_tokens': 53, 'completion_tokens': 8, 'total_tokens': 61}, id='run-0d8f5cf9-c6a7-4139-8650-85b4459f11ae-0')]}
以前の状態から再実行できました。
さらに、LangGraphのチェックポイントを使うと、過去状態からの再生だけでなく、以前の場所を分岐して、別のルートを探索したり、バージョン管理ができます。
これを行うために、特定時点の状態を編集してみます。search
への入力を変更してみましょう。
# 状態の最後のメッセージを取得する
# これは更新したいツールコールを含むメッセージです
last_message = to_replay.values["messages"][-1]
# そのツールコールの引数を更新する
last_message.tool_calls[0]["args"] = {"query": "ほげほげー"}
branch_config = app.update_state(
to_replay.config,
{"messages": [last_message]},
)
状態を変更したbranch_config
を使って実行を再開してみます。
with mlflow.start_span("graph", span_type="AGENT") as span:
for event in app.stream(None, branch_config):
for v in event.values():
print(v)
{'messages': [HumanMessage(content='サンフランシスコは晴れです', tool_call_id='1111')]}
{'messages': [AIMessage(content=' San Francisco is sunny today.', response_metadata={'prompt_tokens': 53, 'completion_tokens': 8, 'total_tokens': 61}, id='run-1ca9fdf0-b782-4a7a-ae59-7b9dd827abda-0')]}
実行結果だけ見ると変化が無いように見えるのですが、MLflow Tracingのログからノードの入力を見ると、変更した状態が反映されていることがわかります。
まとめ
グラフの状態履歴の取得から、状態を変更してからの再実行などを実行してみました。
状態の保存(永続化)も絡めて、履歴の再現などに利用するのかなという印象です。
状態操作はいろいろやりそうなので、覚えておこうと思います。