目次
はじめに
入出力に型が欲しいと思う場面が多く、最近ちょっと書き方を変えてみた。
LangChainの場合
まずはLangChainの例。
chains/chain.py
from typing import AsyncGenerator, TypedDict
from langchain_core.language_models import BaseChatModel
from langchain_core.prompts import (
ChatPromptTemplate,
)
from langchain_core.runnables import Runnable, chain
Chain = Runnable["ChainInput", "ChainOutput"]
class ChainInput(TypedDict):
input: str
class ChainOutput(TypedDict):
output: str
def build_chain(chat_model: BaseChatModel) -> Chain:
"""Builds a chain for translation using the provided chat model."""
@chain
async def _build_chain(input: ChainInput) -> AsyncGenerator[ChainOutput, None]:
prompt = ChatPromptTemplate.from_messages(
[
("system", "あなたは優秀な翻訳者です。"),
(
"user",
"次の文章を英訳してください。\n{input}",
),
]
)
_chain = prompt | chat_model
async for output in _chain.astream({"input": input["input"]}):
if isinstance(output.content, str):
yield {"output": output.content}
return _build_chain
if __name__ == "__main__":
import asyncio
from dotenv import load_dotenv
from langchain_openai import ChatOpenAI
load_dotenv()
chat_model = ChatOpenAI(name="gpt-4o-2024-11-20")
_chain = build_chain(chat_model)
async def main():
async for output in _chain.astream(
{"input": "こんにちは。私の名前はジョンです。"}
):
print(output)
asyncio.run(main())
出力結果↓
python -m chain
{'output': ''}
{'output': 'Hello'}
{'output': ','}
{'output': ' my'}
{'output': ' name'}
{'output': ' is'}
{'output': ' John'}
{'output': '.'}
{'output': ''}
ポイントとして、任意の関数をチェーンに変換できるchain
デコレータを利用している。
入出力の型が分かり、呼び出し元で怒られてくれるので利用しやすくなる↓
加えてBaseChatModel
等、LangChainの抽象的なクラスにのみ依存すれば汎用性も高まる。
LangGraphの場合
続いてLangGraphの例だが、これもLangChainと勝手は変わらない。
chains/graph.py
from typing import AsyncGenerator, TypedDict
from langchain_core.language_models import BaseChatModel
from langchain_core.prompts import (
ChatPromptTemplate,
)
from langchain_core.runnables import Runnable, chain
from langgraph.graph import END, START, StateGraph
Graph = Runnable["GraphInput", "GraphOutput"]
class GraphInput(TypedDict):
input: str
class GraphOutput(TypedDict):
output: dict
class State(TypedDict):
value: str
def create_first_node(chat_model: BaseChatModel):
prompt = ChatPromptTemplate.from_messages(
[
(
"user",
"しりとりしてください。\n{input}",
),
]
)
_chain = prompt | chat_model
def _first_node(state: State):
res = _chain.invoke({"input": state["value"]})
return {
"value": res.content,
}
return _first_node
def create_second_node(chat_model: BaseChatModel):
prompt = ChatPromptTemplate.from_messages(
[
(
"user",
"しりとりしてください。\n{input}",
),
]
)
_chain = prompt | chat_model
def _second_node(state: State):
res = _chain.invoke({"input": state["value"]})
return {
"value": res.content,
}
return _second_node
def build_graph(chat_model: BaseChatModel) -> Graph:
builder = StateGraph(State)
builder.add_node(
"first_node",
create_first_node(chat_model),
input=State,
)
builder.add_node(
"second_node",
create_second_node(chat_model),
input=State,
)
builder.add_edge(START, "first_node")
builder.add_edge("first_node", "second_node")
builder.add_edge("second_node", END)
_graph = builder.compile()
@chain
async def _build_graph(input: GraphInput) -> AsyncGenerator[GraphOutput, None]:
async for res in _graph.astream({"value": input["input"]}):
if isinstance(res, dict):
yield {
"output": res,
}
return _build_graph
if __name__ == "__main__":
import asyncio
from dotenv import load_dotenv
from langchain_openai import ChatOpenAI
load_dotenv()
chat_model = ChatOpenAI(name="gpt-4o-2024-11-20")
_graph = build_graph(chat_model)
async def main():
async for output in _graph.astream({"input": "私からしりとりを始めます。リス"}):
print(output)
asyncio.run(main())
実行結果↓
python -m graph
{'output': {'first_node': {'value': 'スベル'}}}
{'output': {'second_node': {'value': 'ルーム'}}}
おわりに
もう少しこれで運用してみて痛みが出てきたら他の実装も検討することにする。