概要
langgraphの勉強をしていて、ToolNodeが何なのかよくわからなかったので、ソースコードを読んで理解を試みました。
モチベーション
langgraphのチュートリアルで以下のような記述があります。(要点だけ抜粋)。
from langgraph.prebuilt import ToolNode
# Define the tools for the agent to use
tools = [TavilySearchResults(max_results=1)]
tool_node = ToolNode(tools)
# Define a new graph
workflow = StateGraph(MessagesState)
# Define the two nodes we will cycle between
workflow.add_node("tools", tool_node)
workflow(StateGraph)にて、"tools"という名前のnodeにtool_nodeクラスを紐づけています。
ToolNodeはおそらくRunnableを継承しています。
nodeに紐づけられたRunnable継承クラスは実行時invokeが呼ばれるはずです。
tool単体のinvokeはtoolに紐づいた関数を実行するということでなんとなくわかるのですが、ToolNodesはtoolのリストのラッパーなので、toolのリストに対するinvokeがどういう挙動をしているのかわからないので、理解したいと思いました。
わかりたいこと
ToolNodeクラスのinvokeの入力と出力が何かを理解したいです。
想像としては、想像としては、入力はagentクラスの出力(tool_callプラパティが存在するAIMessage?)だと思います。出力は、toolの実行結果なのだと思いますが、複数のtoolが呼び出されている場合は辞書かリストで渡されるのか?など、あまりイメージできていません。
調査
ToolNodeのソースコードを読んでみます。
クラス冒頭の説明コメントに以下のように書かれています。
"""A node that runs the tools requested in the last AIMessage. It can be used
either in StateGraph with a "messages" key or in MessageGraph. If multiple
tool calls are requested, they will be run in parallel. The output will be
a list of ToolMessages, one for each tool call.
The `ToolNode` is roughly analogous to:
```python
tools_by_name = {tool.name: tool for tool in tools}
def tool_node(state: dict):
result = []
for tool_call in state["messages"][-1].tool_calls:
tool = tools_by_name[tool_call["name"]]
observation = tool.invoke(tool_call["args"])
result.append(ToolMessage(content=observation, tool_call_id=tool_call["id"]))
return {"messages": result}
```
Important:
- The state MUST contain a list of messages.
- The last message MUST be an `AIMessage`.
- The `AIMessage` MUST have `tool_calls` populated.
"""
以下の内容が書かれています。
- 最後のAIMessageに含まれるtool_callをrunする
- "message"のkeyをもつStateGraphクラスか、MessageGraphクラスにて使われる
- 複数のtool_callがある場合は、pararellに実行する
- 出力は各tool_callに対応するToolMessageのリストである。
「ざっくりとはこういう処理だよ」というまえがきとともに、tool_node関数の例が書かれています。
関数の中身としては、辞書型のstateを受け取り、messagesバリューの最後のindexのtool_callsを取得し、各tool_callに対応するinvokeを実行し、resultリストにToolMessage型に変換したうえでappendする、みたいな処理です。文章で記述された要点の通りの処理になっています。
重要な点がコメントの最後に書かれています。
- stateはmessageのリストを含む必要がある。
- 最後のmessageはAIMessageである必要がある。
- AIMessageはtool_callsをもつ必要がある。
理解の整理
思いの外、クラスのコメントを読むだけで、スッキリしたので、ここで終わりにします。
ToolNodeの実際の処理やTooleNodeが継承しているRunnableCallableも見る必要性も想定していたのですが、今はそこまで深い理解はいらなさそうです。
もともとの想像に対して実態がどうだったか整理します。
入力については、tool_callsが存在するAIMessageを入力するという想像通りです。
出力はtoolの戻り値のリストかなと雑に想像していましたが、より正確にはToolMessageのリストが返されるようでした。
おわりに
これで、ToolNodeを拒否感なく使えそうです。