はじめに
最近 LangGraph を触って見たので、概念やコードの書き方について備忘録的に書いていきたいと思います。
今回はグラフの作成方法についてです。
ステートグラフ
LangGraph は状態を保持してグラフ上を移動しながら各ノードで処理を行うことでタスクを行います。Wikipediaにあるようにグラフとは
数学のグラフ理論におけるグラフ(英: graph)とは数学的構造の一つ。対象の集合で、対象の一部が相互に何らかの脈絡で「関係している」ようなものをいう。ここで対象とは頂点(節点やノードとも)と呼ばれる抽象物であり、互いに関係のある頂点の対は辺(枝やエッジとも)と呼ばれる[1]。
であり、ノード(頂点)をエッジ(辺)で結んだものになります。
まずはグラフを定義しましょう。グラフを作るにはState
を定義する必要があります。State
はTypedDict
で、以下のコードのように定義されます。
from typing_extensions import Annotated, TypedDict
from langgraph.graph import StateGraph
def reducer(a: list, b: int | None) -> list:
if b is not None:
return a + [b]
return a
class State(TypedDict):
x: Annotated[list, reducer]
message: str
graph = StateGraph(State)
State
は各ノードで更新されます。ノードの戻り値に含まれたキーのみ上書きされますが、x
のようにAnnotated
で reducer 関数を設定することにより、リストを加算する形で更新されます。
チャットボットを作成する場合、人間の入力や LLM の応答をリストとして保存したいので、そのための reducer 関数が用意されているのでそれを使うことが多いでしょう。
from typing_extensions import Annotated, TypedDict
from langgraph.graph import StateGraph
from langgraph.graph.message import add_messages
class State(TypedDict):
messages: Annotated[list[AnyMessage], add_messages]
graph = StateGraph(State)
ノード
ノードではなんらかの処理を行い、State
を更新します。
ノードはState
を受け取り、State
の更新情報のdict
を戻り値とした関数として定義されます。
add_node
でグラフに追加します。
def add_1_node(state: State) -> dict:
next_x = state['x'][-1] + 1
return {"x": next_x}
graph.add_node("add_one", add_1_node)
def div_by_2_node(state: State) -> dict:
next_x = state['x'][-1] / 2
return {
"x": next_x,
"message": "div2!"
}
graph.add_node("div_by_two", div_by_2_node)
開始・終了用のノードが用意されています。
from langgraph.graph import START ,END
エッジ
単純にノードからノードへ結ぶ場合はadd_edge
で設定できます。
graph.add_edge(START, "add_one")
1つの始点ノードから複数の終点ノードへとエッジを結びたい場合はadd_conditional_edges
を用います。
def is_even(state: State):
return state['x'][-1] % 2 == 0
graph.add_conditional_edges("add_one", is_even, {True: "div_by_two", False: END})
グラフのコンパイル
残りのエッジを追加します。
graph.add_edge("div_by_two", END)
グラフへのノードとエッジの追加が終わったら最後にグラフをコンパイルします。
app = graph.compile()
数字に1を足し、偶数だったら2で割り、奇数だったらそのままという処理をするグラフが作成されました。
実行
State
の初期値を与えてinvoke
すれば実行されます。
result1 = app.invoke({"x": 3})
result2 = app.invoke({"x": 4})
print(result1)
# {'x': [3, 4, 2], 'message': 'div2!'}
print(result2)
# {'x': [4, 5]}
期待通りの処理が行われました。
おわりに
今回は LangGraph でグラフを作り、簡単な計算を行いました。次回はノードでLLMを利用する話を扱いたいと思います。
その②を書きました。
参考文献
LangGraph公式