
An attempt to define types for the inputs and outputs of LangChain chains and LangGraph graphs

Posted at 2025-04-10


Introduction

I often run into situations where I want the inputs and outputs to be typed, so I recently tried changing how I write these a little.

For LangChain

First, a LangChain example.

chains/chain.py
from typing import AsyncGenerator, TypedDict

from langchain_core.language_models import BaseChatModel
from langchain_core.prompts import (
    ChatPromptTemplate,
)
from langchain_core.runnables import Runnable, chain

Chain = Runnable["ChainInput", "ChainOutput"]


class ChainInput(TypedDict):
    input: str


class ChainOutput(TypedDict):
    output: str


def build_chain(chat_model: BaseChatModel) -> Chain:
    """Builds a chain for translation using the provided chat model."""

    @chain
    async def _build_chain(input: ChainInput) -> AsyncGenerator[ChainOutput, None]:
        prompt = ChatPromptTemplate.from_messages(
            [
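                # System: "You are an excellent translator."
                ("system", "あなたは優秀な翻訳者です。"),
                (
                    "user",
                    # User: "Please translate the following text into English.\n{input}"
                    "次の文章を英訳してください。\n{input}",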
                ),
            ]
        )

        _chain = prompt | chat_model

        async for output in _chain.astream({"input": input["input"]}):
            if isinstance(output.content, str):
                yield {"output": output.content}

    return _build_chain


if __name__ == "__main__":
    import asyncio

    from dotenv import load_dotenv
    from langchain_openai import ChatOpenAI

    load_dotenv()

    # Select the model via "model"; "name" only sets the Runnable's display name.
    chat_model = ChatOpenAI(model="gpt-4o-2024-11-20")
    _chain = build_chain(chat_model)

    async def main():
        # Japanese input meaning "Hello. My name is John."
        async for output in _chain.astream(
            {"input": "こんにちは。私の名前はジョンです。"}
        ):
            print(output)

    asyncio.run(main())

Output:

python -m chain
{'output': ''}
{'output': 'Hello'}
{'output': ','}
{'output': ' my'}
{'output': ' name'}
{'output': ' is'}
{'output': ' John'}
{'output': '.'}
{'output': ''}

The key point is the chain decorator (@chain), which turns an arbitrary function into a chain (a Runnable).

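Because the input and output types are known, the type checker complains at the call site when they are used incorrectly, which makes the chain easier to use:

[Screenshots: type-checker errors reported at the call site]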

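In addition, depending only on LangChain's abstract classes such as BaseChatModel makes the code more general-purpose, since any concrete chat model implementation can be passed in.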

For LangGraph

Next is the LangGraph example, which works the same way as the LangChain one.

chains/graph.py
from typing import AsyncGenerator, TypedDict

from langchain_core.language_models import BaseChatModel
from langchain_core.prompts import (
    ChatPromptTemplate,
)
from langchain_core.runnables import Runnable, chain
from langgraph.graph import END, START, StateGraph

Graph = Runnable["GraphInput", "GraphOutput"]


class GraphInput(TypedDict):
    input: str


class GraphOutput(TypedDict):
    output: dict


class State(TypedDict):
    value: str


def create_first_node(chat_model: BaseChatModel):
    prompt = ChatPromptTemplate.from_messages(
        [
            (
                "user",
                # "Please play shiritori (the Japanese word-chain game).\n{input}"
                "しりとりしてください。\n{input}",
            ),
        ]
    )

    _chain = prompt | chat_model

    def _first_node(state: State):
        res = _chain.invoke({"input": state["value"]})

        return {
            "value": res.content,
        }

    return _first_node


def create_second_node(chat_model: BaseChatModel):
    prompt = ChatPromptTemplate.from_messages(
        [
            (
                "user",
                # "Please play shiritori (the Japanese word-chain game).\n{input}"
                "しりとりしてください。\n{input}",
            ),
        ]
    )

    _chain = prompt | chat_model

    def _second_node(state: State):
        res = _chain.invoke({"input": state["value"]})

        return {
            "value": res.content,
        }

    return _second_node


def build_graph(chat_model: BaseChatModel) -> Graph:
    builder = StateGraph(State)
    builder.add_node(
        "first_node",
        create_first_node(chat_model),
        input=State,
    )
    builder.add_node(
        "second_node",
        create_second_node(chat_model),
        input=State,
    )
    builder.add_edge(START, "first_node")
    builder.add_edge("first_node", "second_node")
    builder.add_edge("second_node", END)

    _graph = builder.compile()

    @chain
    async def _build_graph(input: GraphInput) -> AsyncGenerator[GraphOutput, None]:
        async for res in _graph.astream({"value": input["input"]}):
            if isinstance(res, dict):
                yield {
                    "output": res,
                }

    return _build_graph


if __name__ == "__main__":
    import asyncio

    from dotenv import load_dotenv
    from langchain_openai import ChatOpenAI

    load_dotenv()

    # Select the model via "model"; "name" only sets the Runnable's display name.
    chat_model = ChatOpenAI(model="gpt-4o-2024-11-20")
    _graph = build_graph(chat_model)

    async def main():
        # "I'll start the shiritori. Risu (squirrel)"
        async for output in _graph.astream({"input": "私からしりとりを始めます。リス"}):
            print(output)

    asyncio.run(main())

Output:

python -m graph
{'output': {'first_node': {'value': 'スベル'}}}
{'output': {'second_node': {'value': 'ルーム'}}}

Conclusion

I'll keep using this approach for a while, and if it starts to cause pain, I'll look into other implementations.
