35
18

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

Streamlit × LangGraphでHuman-in-the-loop Agentの実装:広告コピー文生成アプリケーション

Last updated at Posted at 2024-12-06

はじめに

株式会社NTTデータ デジタルサクセスコンサルティング事業部の@yamato0811です。

データサイエンス・AI開発の現場において、Streamlitはその手軽さからプロトタイピングやデモンストレーションに広く活用されています。また、LangGraphはLLMを活用したAgentフローを容易に構築するフレームワークとして注目を集めています。

この度、StreamlitとLangGraphを利用してHuman-in-the-loopアプリケーションを実装する機会がありました。実装を進める中で、Agentフローの構築およびStreamlitとの連携が難しく、工夫が必要となる場面に遭遇しました。特に、StreamlitとLangGraphによるHuman-in-the-loopの実装は事例が少なく、参考にできる文献がほとんど存在しません。

本記事では、広告コピー文を生成する簡単なAgentアプリケーションの作成例を紹介しています。具体的な実装例を共有することで、同様のAgent開発やアプリケーション開発に挑戦するエンジニアの皆様にとって、少しでもお役に立てれば幸いです。

リポジトリは、以下のリンクからご確認いただけます。

想定読者

  • Streamlitを用いたアプリケーション開発経験がある方
  • LangGraphの概要を理解している方
  • StreamlitとLangGraphを連携させたアプリケーション開発に興味がある方
  • LangGraphでのHuman-in-the-loop実装に興味がある方

Human-in-the-loopとは

Human-in-the-loopは、エージェントシステムにおいて人間の介入を通じてユーザー体験を向上させる手法です。この手法では、ユーザーがエージェントの動作を監視し、必要に応じて介入することで、より正確で効果的な結果を得ることができます。

LangGraphでのHuman-in-the-loopの実装

LangGraphでは、breakpointを用いてグラフの実行を特定のステップで一時停止し、人間の承認や意見を求めることができます。以下には特に承認を求める際の流れを示します。

  1. ブレークポイント: LangGraphのbreakpointは、エージェントが特定のノードの実行前に一時停止し、人間の承認を待つことを可能にする
  2. 承認を待つ:グラフ状態を保存するチェックポイントを作成し、人間からの承認を待つ。チェックポイントはスレッドに保存され、あとからグラフを再開することが可能
  3. グラフの再開:人間からの承認が得られたら、グラフの実行を再開する

langgraph_HITL.png

Human-in-the-loopのサンプルコードを以下に示します。

# node2の前で一時停止を行うブレイクポイントを設定してグラフをコンパイル
graph = builder.compile(checkpointer=checkpoitner, interrupt_before=["node_2"])

# ブレークポイントまでグラフを実行する
for event in graph.stream(inputs, thread, stream_mode="values"):
    print(event)

# ... 人間の承認を得る ...

# 承認されたら、最後に保存されたチェックポイントからグラフの実行を再開する
for event in graph.stream(None, thread, stream_mode="values"):
    print(event)

builder.compileメソッドのinterrupt_before引数に、ブレークポイントを設定したいノードのID(この場合はnode_2)を指定します。これにより、グラフの実行時に、指定したノードの前で自動的に一時停止するようになります。
グラフの実行は、graph.streamメソッドを使用します。最初の実行時には、入力データinputsとスレッド情報threadを渡します。
再開時には、graph.stream(None, thread, stream_mode="values")のように、最初の引数にNoneを渡すことで、保存されたチェックポイントから実行が再開されます。

図およびサンプルコードはLangGraphのHuman-in-the-loop How-toページより引用しています。

アプリケーションの概要

商品情報を入力すると、広告コピー文を自動生成するAgentアプリケーションを開発しました。 本アプリケーションはHuman-in-the-loopを実装しており、生成されたコピーにユーザーが満足しなければ、フローの途中でAgentにコピーの再検討(再生成)を指示できます。加えて、Agentは改善に必要な追加情報をユーザーへインタラクティブに尋ねることができ、より精度の高いコピー文を生成できます。

output.gif

python実装は以下に公開しています。Agentとして利用するLLMには、Amazon BedrockのClaude 3.5 Sonnet v2 および、Azure OpenAI gpt-4o を選択可能です。

LangGraphのグラフ構造

以下に、今回作成したAgentアプリケーションのグラフ構造を以下に示します。

アプリ実行の流れ

  1. キャッチコピー生成の開始:
    アプリを開始すると、最初のノード(__start__)に到達する。このノードは、アプリの開始を表す。
  2. キャッチコピー生成:
    generate_copyノードに到達すると、AgentはLLM(今回はOpenAI GPT-4o)を用いてキャッチコピーを生成する。
  3. ユーザーの選択:
    user_select_copyノードに到達すると、ユーザーに生成されたキャッチコピーを表示し、ユーザーに選択を促す。
    もし、ユーザーが「再検討」を選択した場合は、reflect_copyノードに遷移する。
  4. キャッチコピーの改善:
    reflect_copyノードに到達すると、Agentは生成したキャッチコピーの改善点と改善に必要な追加情報を思考する。
  5. ユーザーの追加情報の入力:
    user_input_additioal_info_copyノードに到達すると、ユーザーに追加情報の入力を促す。
  6. キャッチコピーの再生成:
    ユーザーが追加情報を入力すると、generate_copyノードに戻り、Agentはキャッチコピーの再生成を行う。
  7. キャッチコピー生成の終了:
    ユーザーが生成されたキャッチコピーを承認すると、dummy_endノードに到達し、キャッチコピー生成が終了する。

※ 2~6の手順は、ユーザーが生成されたキャッチコピーを承認するまで繰り返されます(いわゆるAgent Loop)。

LangGraphでAgentグラフの実装

本章では、LangGraphを使用したAgentグラフの実装方法について、具体的なコード例を交えて詳しく説明します。Stateの定義方法、各Nodeの役割とその実装手順、およびAgentグラフの効果的な構築・管理方法について解説します。さらに、Structured Outputを利用して出力形式を保証する方法についても説明します。

LangGraphのAgentグラフの実装はagentディレクトリ配下にまとめています。以下では、実装の詳細について説明します。
※ プログラムの一部のみ抜粋していますので、適宜githubコードを参照ください。

Stateの定義

グラフのノード間を遷移するState情報はagent/state.pyに以下のように定義しています。また、DisplayMessageDictは、Streamlitに表示するメッセージのタイトル、アイコン、本文を定義しています。

agent/state.py
from typing import Annotated, List

from typing_extensions import TypedDict


class DisplayMessageDict(TypedDict):
    title: Annotated[str, "表示用のタイトル"]
    icon: Annotated[str, "アイコン"]
    message_text: Annotated[str, "表示用のメッセージ"]


class State(TypedDict):
    # ================
    # Input
    # ================
    # Initial
    product_info: Annotated[str, "商品情報"]

    # Stremlit上でのuser input
    additional_info_input: Annotated[str, "入力された追加情報"]
    selected_copy: Annotated[dict, "選択されたキャッチコピー"]

    # ================
    # Output
    # ================
    copies: Annotated[list[dict], "キャッチコピーのリスト"]
    additional_info: Annotated[str, "必要な追加情報"]
    is_rethink: Annotated[bool, "再検討を行うか"]

    # ================
    # 処理用State
    # ================
    iteration_count: Annotated[int, "現在の反復回数"]
    is_finish: Annotated[bool, "終了判定フラグ"]
    display_message_dict: Annotated[DisplayMessageDict, "表示用のメッセージ"]

    # 履歴管理用
    messages: Annotated[list, "会話履歴のリスト"]

Nodeの定義

以下に、agent/node.pyの主要な部分の説明を行います。

まず、LangGraphのノード管理を簡素化するため、独自にデータクラスNodeTypeを定義しています。これと後述するGraphBuilderクラスと組み合わせて使用することで、ノード名のハードコードを避け、可読性と保守性の向上が期待できます。

agent/node.py
@dataclass
class NodeType:
    name: str  # ノードの名前
    func: Callable  # ノードで実行する関数

Nodeクラスでは、LangGraphの各ノードの処理を定義します。
各ノードはNodeTypeを使って定義し、ノード名と対応する処理関数を結び付けます。例えば、self.generate_copy"generate_copy"という名前を持ち、self._generate_copy関数を実行するノードとして定義します。

agent/node.py
class Node:
    def __init__(
        self,
        llm: LLM,
        prompt: Dict[str, Dict[str, str]],
    ) -> None:

        self.llm = llm
        self.prompt = prompt

        # ================
        # Define Node
        # ================
        self.generate_copy = NodeType("generate_copy", self._generate_copy)
        self.user_select_copy = NodeType("user_select_copy", self._user_input)
        self.reflect_copy = NodeType("reflect_copy", self._reflect_copy)
        self.user_input_additioal_info_copy = NodeType(
            "user_input_additioal_info_copy", self._user_input
        )
        self.end = NodeType("dummy_end", self._end_node)

genarate_copyノード

genarate_copyノードの処理関数は以下のように記載しています。

初回のコンテンツ生成(state["iteration_count"] == 0)と2回目以降(elseブロック)で処理が異なります。初回は商品情報を入力したプロンプトを、2回目以降はユーザーの追加情報を考慮したユーザープロンプトを使用してLLMを実行します。

また、今回は所望の形式で出力を得るため、LLM実行にlangchainのStructured Outputを用いています。(具体的な使用については後述します。)

LLMの出力結果は整形してStateに格納します。display_message_dictには、Streamlitに表示するためのメッセージを格納します。

agent/node.py
    def _generate_copy(self, state: State) -> State:
        print("Node: generate_copy")

        product_info = state["product_info"]

        # 初回コンテンツ生成
        if state["iteration_count"] == 0:
            system_prompt = SystemMessage(
                content=self.prompt["generate_copy"]["system"]
            )
            human_prompt = HumanMessagePromptTemplate.from_template(
                self.prompt["generate_copy"]["user_first"]
            ).format(
                product_info=product_info,
                output_format_instruction=get_output_format_instructions(Copies),
            )

            state["messages"] = [system_prompt, human_prompt]
        # 2回目以降のコンテンツ生成
        else:
            human_prompt = HumanMessagePromptTemplate.from_template(
                self.prompt["generate_copy"]["user_second"]
            ).format(
                product_info=product_info,
                additional_info=state["additional_info"],
                additional_info_input=state["additional_info_input"],
                output_format_instruction=get_output_format_instructions(Copies),
                state=state,
            )

        # 履歴にユーザーの入力を追加
        state["messages"].append(human_prompt)
        # invoke
        ai_message = self.llm((state["messages"]), Copies)
        # 履歴にAIの出力を追加
        state["messages"].append(AIMessage(ai_message.model_dump_json()))

        # AIの出力をリストに変換
        output_list = ai_message.model_dump()["copies"]

        # streamlit表示用のメッセージ
        message_text = ""
        for output in output_list:
            # avoid to break markdown format
            output["copy_text"] = output["copy_text"].replace("\n", "")
            # markdown改行のため空白スペース(\u0020)が2つ必要
            message_text += f"""
        **【{output["title"]}】**\u0020\u0020
        **キャッチコピー**:{output["copy_text"]}\u0020\u0020
        **理由**:{output["reason"]}
        """
        display_message_dict = {
            "title": f"**キャッチコピーの作成** {state['iteration_count'] + 1}回目",
            "icon": "📝",
            "message_text": message_text,
        }

        # 'reason'キーのみを削除した新しいリストを生成
        filtered_list = filter_key_from_list(output_list, "reason")

        # 状態の更新
        state["copies"] = filtered_list
        state["display_message_dict"] = display_message_dict

        return state

genarate_copyノードで使用してるプロンプトは別ファイルで管理しています。
今回はXMLタグを使ってプロンプトの構造化を行っています。

agent/prompt/prompt_templates.yaml
generate_copy:
  system: |
    あなたはプロのコピーライターです。
  user_first: |
    <instruction>
    以下のproductタグ内の情報を基に、商品のキャッチコピーを3つ生成して下さい。
    なぜその出力にしたかの理由も考えて下さい。
    </instruction>
    <product>
    {product_info}
    </product>
    <output>
    {output_format_instruction}
    </output>
  user_second: |
    <instruction>
    以下のadditional_infoタグ内のユーザーからの追加情報を考慮した上で、productタグ内の情報を基に、商品のキャッチコピー3つ生成して下さい。
    なぜその出力にしたかの理由も考えて下さい。
    </instruction>
    <additional_info>
    {additional_info}: {additional_info_input}
    </additional_info>
    <output>
    {output_format_instruction}
    </output>

reflect_copyノード

reflect_copyノードも同様にして定義しています。
genarate_copyノードで初回実行かどうかを判定するため、処理の最後でカウントアップ(state["iteration_count"] += 1)している点には注意してください。

agent/node.py
    def _reflect_copy(self, state: State) -> State:
        print("Node: reflect_copy")

        copies = state["copies"]

        human_prompt = HumanMessagePromptTemplate.from_template(
            self.prompt["reflect_copy"]["user"]
        ).format(
            copies=copies,
            output_format_instruction=get_output_format_instructions(ReflectDetails),
        )

        # 履歴にユーザーの入力を追加
        state["messages"].append(human_prompt)
        # invoke
        ai_message = self.llm((state["messages"]), ReflectDetails)
        # 履歴にAIの出力を追加
        state["messages"].append(AIMessage(ai_message.model_dump_json()))

        # 文字列をPythonの辞書に変換
        data = ai_message.model_dump()

        display_message_dict = {
            "title": f"**キャッチコピーの改善** {state['iteration_count'] + 1}回目",
            "icon": "🔄",
            "message_text": f"""
            **改善点**:{data["improvement_point"]}\u0020\u0020
            **必要な追加情報**:{data["additional_info"]}\u0020\u0020
            **理由**:{data["reason"]}
            """,
        }

        # 状態の更新
        state["additional_info"] = data["additional_info"]
        state["display_message_dict"] = display_message_dict

        # カウントアップ
        state["iteration_count"] += 1

        return state

reflect_copyノードで使用しているプロンプトは以下です。

agent/prompt/prompt_templates.yaml
reflect_copy:
  user: |
    <instruction>
    copyタグ内の複数のキャッチコピーを評価し、改善点を考えて下さい。
    また、改善点を実現するために必要な「ユーザーからの追加情報」をとても簡潔に1つだけ考えて下さい。
    なぜその出力にしたかの理由も考えて下さい。
    </instruction>
    <copy>
    {copies}
    </copy>
    <output>
    {output_format_instruction}
    </output>

その他のノード

user_inputノードは、一時停止を行うためのダミーノードなので内部処理は記載していません。また、Streamlit実装で必要となるダミーのend_nodeも定義しておきます。

agent/node.py
    def _user_input(self, state: State):
        pass

    def _end_node(self, state: State):
        print("Node: end_node")
        return {"is_finish": True, "display_message_dict": None}

分岐関数

Conditional edge(条件分岐)を実現するための分岐関数も定義しています。
stateのis_rethinkがTrueの場合に"reflect"へ、Falseの場合に"end"へルーティングします。

agent/node.py
    def should_rethink(self, state: State) -> Literal["reflect", "end"]:
        if state["is_rethink"]:
            return "reflect"
        else:
            return "end"

(補足)Structured Output

LLMに事前定義した形式で出力させるように、LangChainのStructured Outputを使用しています。

出力のスキーマ形式は、agent/output_structure.pyに定義してあります。

agent/output_structure.py
from typing import Optional

from pydantic import BaseModel, Field


# ================
# Copy Generation
# ================
class Copy(BaseModel):
    """コピー"""

    title: str = Field(description="タイトル(例:案1, 案2, ..)")
    reason: str = Field(description="回答の理由")
    copy_text: str = Field(description="キャッチコピー")


class Copies(BaseModel):
    """コピーの出力形式"""

    copies: list[Copy] = Field(description="キャッチコピーのリスト")


# ================
# Reflect
# ================
class ReflectDetails(BaseModel):
    """ユーザーからのフィードバック情報"""

    reason: str = Field(description="回答の理由")
    improvement_point: str = Field(description="改善点")
    additional_info: Optional[str] = Field(
        default=None,
        description="ユーザーに求める追加情報の内容(体言止め)",
    )

LLMを呼び出す際に、以下のようにpydanticで定義したスキーマを渡すことで、スキーマに従った出力を取得できます。

# invoke
ai_message = self.llm((state["messages"]), Copies)
json_data = ai_message.model_dump()

このときの出力例は以下です。

[
  {
    "title": "案1",
    "reason": "『保湿』という商品の機能を直接的にアピールしつつ、つるつるの肌を想像させるため。",
    "copy_text": "しっとり肌へ、うるおいの魔法を。"
  },
  {
    "title": "案2",
    "reason": "季節や環境に左右されず保湿できることを強調し、毎日使える定番アイテムであることを伝えています。",
    "copy_text": "毎日の潤いチャージ、これ一本で。"
  },
  {
    "title": "案3",
    "reason": "肌に優しいイメージを持たせ、安心して使える保湿クリームであることを伝えています。",
    "copy_text": "肌に優しさ、保湿の贈り物。"
  }
]

また、LangChainのOutputParserを使うと、LLM出力のフォーマット指示プロンプトを取得することも可能です。

utils/node_util.py
def get_output_format_instructions(model: BaseModel) -> str:
    """出力フォーマットの指示を取得する"""
    parser = PydanticOutputParser(pydantic_object=model)
    output_format_instruction = parser.get_format_instructions()
    return output_format_instruction

例えば、上記の関数を使用してCopiesのフォーマット指示を取得すると、以下のような指示が得られます。このフォーマット指示を、出力形式を指定したいプロンプトに付加することで、出力形式の指示を与えることができます。

The output should be formatted as a JSON instance that conforms to the JSON schema below.

As an example, for the schema {"properties": {"foo": {"title": "Foo", "description": "a list of strings", "type": "array", "items": {"type": "string"}}}, "required": ["foo"]}
the object {"foo": ["bar", "baz"]} is a well-formatted instance of the schema. The object {"properties": {"foo": ["bar", "baz"]}} is not well-formatted.

Here is the output schema:
{"$defs": {"Copy": {"description": "コピー", "properties": {"title": {"description": "タイトル(例:案1, 案2, ..)", "title": "Title", "type": "string"}, "reason": {"description": "回答の理由", "title": "Reason", "type": "string"}, "copy_text": {"description": "キャッチコピー", "title": "Copy Text", "type": "string"}}, "required": ["title", "reason", "copy_text"], "title": "Copy", "type": "object"}}, "description": "コピーの出力形式", "properties": {"copies": {"description": "キャッチコピーのリスト", "items": {"$ref": "#/$defs/Copy"}, "title": "Copies", "type": "array"}}, "required": ["copies"]}

この仕組みを利用することで、プロンプトを意識することなく、スキーマ形式を変更するだけで出力形式を変更することが可能となります。

GraphBuliderの定義

LangGraph のノード管理を簡素化するため、グラフ処理のラッパークラスGraphBuilderを独自に作成しています。このクラスを利用することで、ノード名をハードコードすることなく、グラフ構造を定義できるため、コードの保守性と再利用性が向上します。

agent/graph.py
class GraphBuilder:
    def __init__(self, state: State) -> None:
        self.work_flow: StateGraph = StateGraph(state)

    def add_node(self, node: NodeType) -> None:
        self.work_flow.add_node(node.name, node.func)

    def add_edge(self, from_node: NodeType, to_node: NodeType) -> None:
        self.work_flow.add_edge(from_node.name, to_node.name)

    def set_finish_point(self, end_node: NodeType) -> None:
        self.work_flow.set_finish_point(end_node.name)

    def set_entry_point(self, node: NodeType) -> None:
        self.work_flow.set_entry_point(node.name)

    def add_conditional_edges(
        self,
        from_node: NodeType,
        condition_func: Callable,
        path_map: Dict[str, Union[str, Any]],
    ) -> None:
        self.work_flow.add_conditional_edges(from_node.name, condition_func, path_map)

    def compile_flow(self) -> CompiledStateGraph:
        return self.work_flow.compile()

Graphの定義

agent/agent.pyでグラフ構造を構築しています。

graph_builderを使用して、ノードとエッジをグラフに追加していきます。add_nodeでノードを、add_edgeでノード間の接続を定義します。また、add_conditional_edgesでは、条件分岐の接続を定義しています。
set_entry_pointset_finish_pointで開始ノードと終了ノードを指定します。

その後、interrupt_beforeに一時停止を行うノードを指定し、グラフのコンパイルを行っています。

agent/agent.py
class Agent:
    def __init__(
        self,
        llm: LLM,
        prompt: Dict[str, Dict[str, str]],
    ) -> None:
        # ================
        # Init
        # ================
        graph_builder = GraphBuilder(State)
        self.node = Node(llm, prompt)

        # ================
        # Build Graph
        # ================
        # Add nodes
        graph_builder.add_node(self.node.generate_copy)
        graph_builder.add_node(self.node.user_select_copy)
        graph_builder.add_node(self.node.reflect_copy)
        graph_builder.add_node(self.node.user_input_additioal_info_copy)
        graph_builder.add_node(self.node.end)

        # Add edges
        graph_builder.add_edge(self.node.generate_copy, self.node.user_select_copy)
        graph_builder.add_conditional_edges(
            self.node.user_select_copy,
            self.node.should_rethink,
            {
                "reflect": self.node.reflect_copy.name,
                "end": self.node.end.name,
            },
        )
        graph_builder.add_edge(
            self.node.reflect_copy, self.node.user_input_additioal_info_copy
        )
        graph_builder.add_edge(
            self.node.user_input_additioal_info_copy, self.node.generate_copy
        )

        # Set entry and finish point
        graph_builder.set_entry_point(self.node.generate_copy)
        graph_builder.set_finish_point(self.node.end)

        # Set up memory
        self.memory = MemorySaver()

        self.graph = graph_builder.work_flow.compile(
            checkpointer=self.memory,
            interrupt_before=[
                self.node.user_select_copy.name,
                self.node.user_input_additioal_info_copy.name,
            ],
        )

interrupt_beforeについての補足
LangGraphでは、グラフコンパイル時の引数interrupt_beforeで、グラフの実行を一時停止するブレークポイントを設定することができます。ブレークポイントとは、指定したノードの実行前、または実行後にグラフの実行を一時停止する機能であり、停止のタイミングはグラフコンパイル時の引数interrupt_beforeまたは interrupt_afterで制御できます。
ブレークポイントを使用する場合は、必ずチェックポインターを利用し、グラフの状態を永続化しておく必要があります。グラフの実行再開時には、以下のようにグラフの入力にNoneを指定するだけで、停止直前のグラフの状態を引き継いで一時停止したノードから実行を再開できます。

agent.graph.stream(None, thread, stream_mode="values")

https://langchain-ai.github.io/langgraph/how-tos/human_in_the_loop/breakpoints/#simple-usage

StreamlitでHuman-in-the-loopの実装

本章では、StreamlitとLangGraphを連携させる具体的な実装方法について説明します。LangGraphと連携するためには、グラフの状態を取得しつつ、グラフの実行(ブレイクポイントからのグラフ再実行)を行うことが重要です。そこで、グラフの状態を取得しながら各ノードに応じた適切な処理を実装する方法について、具体的なコードを用いながら詳しく解説します。

※ プログラムの一部のみ抜粋していますので、適宜githubコードを参照ください。

メイン処理実行準備

まず、LangGraphフロー実行に必要なStreamlitの諸々の設定を行います。

以下のコードには、次の処理が含まれています。

  • ページconfigの設定
  • LLMとプロンプトの初期化
  • セッション管理の初期化
    • Streamlit Session Stateは別ファイルに切り出して管理しています
  • ユーザー入力フォームの表示
    • 初期表示される入力UIも別ファイルに切り出しています
  • 過去のメッセージ履歴を表示
    • Streamlitではユーザーによる入力やボタンのクリックなどのイベントが発生すると、アプリケーション全体が再実行され、変数がリセットされます
    • そのため、Session Stateに保存したメッセージを表示する必要があります
app.py
MODEL = "claude-3-5"  # specify "gpt-4o" or "claude-3-5"
THREAD_ID = "1"
PROMPT_PATH = "agent/prompt/prompt_templates.yaml"
TEMPERATURE = 1.0

def main() -> None:
    # ================
    # Page Config
    # ================
    st.set_page_config(
        page_title="Streamlit×LangGraph コピー生成",
        page_icon="🤖",
        initial_sidebar_state="auto",
    )
    st.title("Streamlit×LangGraph コピー生成")

    # ================
    # Init Actor
    # ================
    llm = LLM(MODEL, TEMPERATURE)
    prompt: Dict[str, Dict[str, str]] = load_yaml(PROMPT_PATH)

    # ================
    # Streamlit Session State
    # ================
    session_manager = SessionManager(llm=llm, prompt=prompt)
    agent = session_manager.get_agent()

    # ================
    # Input
    # ================
    product_info = input_form()

    # ================
    # Display
    # ================
    display_history(session_manager.get_messages())

グラフ実行部分

次に、LangGraphのグラフ実行部分について説明します。

まず、グラフ実行に必要なthreadinitial_inputを定義し、これらを用いてグラフの実行準備をします。initial_inputにはStateの初期値を設定します。

app.py
    thread = {"configurable": {"thread_id": THREAD_ID}}
    initial_input = {
        "product_info": product_info,
        "iteration_count": 0,
        "is_finish": False,
        "display_message_dict": None,
        "messages": [],
        "additional_info_input": "",
    }

while TrueのループがLang Graphのグラフ実行部分です。
グラフが終了ノードに到達するまではグラフの実行を続ける(break pointで止まったとしても、ユーザー入力があればグラフの実行を続ける)必要があるため、最終ノードでのみbreakするループを作成します。

ループの中では、グラフの状態(位置)によって処理を分けます。

条件 処理
開始ノードの場合 initial_inputを設定してグラフ実行
次のノードが存在する場合
(開始ノード以外で)
次のノード名に応じて処理を分岐
終了ノードの場合 生成されたコピーを表示し、while Trueループを抜ける
app.py
    while True:
        # 開始ノードの場合
        if agent.is_start_node(thread):
            # グラフの実行
            stream_graph(agent, initial_input, thread, session_manager)

        # 次のノードがある場合
        next_graph: tuple[str, ...] | Any = agent.get_next_node(thread)
        if next_graph:
            if next_graph[0] == agent.node.user_select_copy.name:
                select_item(
                    agent=agent,
                    thread=thread,
                    state_key="copies",
                    selectbox_message="お気に入りのキャッチコピーを選択してください",
                    state_update_key="selected_copy",
                    as_node=next_graph[0],
                )
            elif next_graph[0] == agent.node.user_input_additioal_info_copy.name:
                input_additional_info(
                    agent=agent,
                    thread=thread,
                    as_node=next_graph[0],
                )

            # グラフの実行
            stream_graph(agent, None, thread, session_manager)

        # 終了ノードの場合
        if agent.is_end_node(thread):
            selected_copy = agent.get_state_value(thread, "selected_copy")
            st.success(f"生成したコピー: {selected_copy["copy_text"]}")
            break

ここで、開始・終了ノードの判定(agent.is_start_node(), agent.is_end_node())や、次のノード名取得(agent.get_next_node())は以下のHelper関数を作成し、使用しています。また、get_state_valueはStateを取得する関数です。

agent/agent.py
    def is_start_node(self, thread: dict) -> bool:
        return self.graph.get_state(thread).created_at is None

    def is_end_node(self, thread: dict) -> bool:
        return self.get_state_value(thread, "is_finish")

    def get_next_node(self, thread: dict) -> tuple[str, ...]:
        return self.graph.get_state(thread).next
    
    def get_state_value(
        self, thread: dict, name: str
    ) -> Union[dict[str, Any], Any, None]:
        state = self.graph.get_state(thread)
        if state and name in state.values:
            return state.values.get(name)
        return None

dammy_endノードを作成していたのは、アプリケーションの終了判定を行うためです。
dammy_endノード内で、is_finish StateをTrueに更新しています。

グラフの実行、再実行には事前適宜した関数stream_graphを使用しています。

utils/app_util.py
def stream_graph(
    agent: Agent,
    input: Dict | None,
    thread: Dict,
    session_manager: SessionManager,
) -> None:
    """
    グラフのストリーミングを行う
    """
    events = agent.graph.stream(input, thread, stream_mode="values")
    for event in events:  # event is state in each node.
        # get synchronized result. You should not get state from thread before update completely.
        if display_message_dict := event.get("display_message_dict"):
            # 表示
            _display_message(display_message_dict)
            # messageの保存
            session_manager.save_message_to_session_state(display_message_dict)

グラフのstream実行の際に、agent.get_state_valueによってStateを取得すると、Stateの一部の要素が空になって返ってくることがありました。これは恐らく、Iteratorから取得したevent(State)と、threadから取得するStateの同期に遅延が発生しているためです。

events = agent.graph.stream(input, thread, stream_mode="values")
for event in events:
    if display_message_dict := agent.get_state_value(thread, 'display_message_dict')

そこで、完全に更新されたグラフ結果を得るため、event.get()を利用するようにしました。この修正により、Stateが空になる問題が解消され、各ステップの実行結果を正確に取得できるようになりました。

events = agent.graph.stream(input, thread, stream_mode="values")
for event in events:  # event is state in each node.
    # get synchronized result. You should not get state from thread before update completely.
    if display_message_dict := event.get("display_message_dict"):

各ノードに応じた処理分岐

上述の通り、次のノードが存在する場合は、以下のように次のノード名に応じて処理を分岐させます。

if next_graph[0] == agent.node.user_select_copy.name:

コードの肥大化を防ぐために、処理は別ファイルに切り出しています。
ここでは、次のノードがuser_input_additioal_info_copyの場合の処理を説明します。

app.py
elif next_graph[0] == agent.node.user_input_additioal_info_copy.name:
    input_additional_info(
        agent=agent,
        thread=thread,
        as_node=next_graph[0],
    )

input_additional_info関数では、ユーザーに追加情報を求めるための入力フィールドを表示し、その情報をStateに保存します。

utils/app_user_input_logic.py
def input_additional_info(agent: Agent, thread: dict, as_node: str) -> None:
    """
    追加情報入力関数。

    Args:
        agent (Agent): エージェントオブジェクト
        thread (dict): 現在のスレッドの辞書
        as_node (str): ステート更新ノード名
    """
    additional_info = agent.get_state_value(thread, "additional_info")

    # ユーザー情報入力
    additional_info_input = st.text_input(f"{additional_info}」を入力してください")

    if not st.button(
        "次へ",
        disabled=not bool(additional_info_input),
        key=as_node,
    ):
        print("User Input Stop")
        st.stop()  # この時点で処理が停止

    print("Entered User Input: ", additional_info_input)

    agent.graph.update_state(
        thread,
        {
            "additional_info_input": additional_info_input,
            "display_message_dict": None,
        },
        as_node=as_node,
    )

「次へ」ボタンを押下されていない場合には次の処理を実行しないように、st.stop()を使用しています。st.stop()の使い方はStreamlitのドキュメントを参照してください。

まとめ

本記事では、StreamlitとLangGraphを連携させてHuman-in-the-loopな広告コピー生成アプリケーションを実装する方法を解説しました。本記事を参考にぜひ、Streamlit×LangGraphのアプリケーションを実装してみてください!

謝辞

本実装は、同じチームの@ren8kさんと共に取り組みました。
また、アプリケーションに関しては、同チームの藤田さんから多くのアドバイスをいただきました。
この場を借りて、心より感謝申し上げます。

仲間募集

NTTデータ テクノロジーコンサルティング事業本部 では、以下の職種を募集しています。

1. クラウド技術を活用したデータ分析プラットフォームの開発・構築(ITアーキテクト/クラウドエンジニア)

クラウド/プラットフォーム技術の知見に基づき、DWH、BI、ETL領域におけるソリューション開発を推進します。
https://enterprise-aiiot.nttdata.com/recruitment/career_sp/cloud_engineer

2. データサイエンス領域(データサイエンティスト/データアナリスト)

データ活用/情報処理/AI/BI/統計学などの情報科学を活用し、よりデータサイエンスの観点から、データ分析プロジェクトのリーダーとしてお客様のDX/デジタルサクセスを推進します。
https://enterprise-aiiot.nttdata.com/recruitment/career_sp/datascientist

3.お客様のAI活用の成功を推進するAIサクセスマネージャー

DataRobotをはじめとしたAIソリューションやサービスを使って、
お客様のAIプロジェクトを成功させ、ビジネス価値を創出するための活動を実施し、
お客様内でのAI活用を拡大、NTTデータが提供するAIソリューションの利用継続を推進していただく人材を募集しています。
https://nttdata.jposting.net/u/job.phtml?job_code=804

4.DX/デジタルサクセスを推進するデータサイエンティスト《管理職/管理職候補》 データ分析プロジェクトのリーダとして、正確な課題の把握、適切な評価指標の設定、分析計画策定や適切な分析手法や技術の評価・選定といったデータ活用の具現化、高度化を行い分析結果の見える化・お客様の納得感醸成を行うことで、ビジネス成果・価値を出すアクションへとつなげることができるデータサイエンティスト人材を募集しています。

https://nttdata.jposting.net/u/job.phtml?job_code=898

ソリューション紹介

Trusted Data Foundationについて

~データ資産を分析活用するための環境をオールインワンで提供するソリューション~
https://www.nttdata.com/jp/ja/lineup/tdf/
最新のクラウド技術を採用して弊社が独自に設計したリファレンスアーキテクチャ(Datalake+DWH+AI/BI)を顧客要件に合わせてカスタマイズして提供します。
可視化、機械学習、DeepLearningなどデータ資産を分析活用するための環境がオールインワンで用意されており、これまでとは別次元の量と質のデータを用いてアジリティ高くDX推進を実現できます。

TDFⓇ-AM(Trusted Data Foundation - Analytics Managed Service)について

~データ活用基盤の段階的な拡張支援(Quick Start) と保守運用のマネジメント(Analytics Managed)をご提供することでお客様のDXを成功に導く、データ活用プラットフォームサービス~
https://www.nttdata.com/jp/ja/lineup/tdf_am/
TDFⓇ-AMは、データ活用をQuickに始めることができ、データ活用の成熟度に応じて段階的に環境を拡張します。プラットフォームの保守運用はNTTデータが一括で実施し、お客様は成果創出に専念することが可能です。また、日々最新のテクノロジーをキャッチアップし、常に活用しやすい環境を提供します。なお、ご要望に応じて上流のコンサルティングフェーズからAI/BIなどのデータ活用支援に至るまで、End to Endで課題解決に向けて伴走することも可能です。

NTTデータとDatabricksについて NTTデータは、お客様企業のデジタル変革・DXの成功に向けて、「databricks」のソリューションの提供に加え、情報活用戦略の立案から、AI技術の活用も含めたアナリティクス、分析基盤構築・運用、分析業務のアウトソースまで、ワンストップの支援を提供いたします。

https://www.nttdata.com/jp/ja/lineup/databricks/

NTTデータとTableauについて

ビジュアル分析プラットフォームのTableauと2014年にパートナー契約を締結し、自社の経営ダッシュボード基盤への採用や独自のコンピテンシーセンターの設置などの取り組みを進めてきました。さらに2019年度にはSalesforceとワンストップでのサービスを提供開始するなど、積極的にビジネスを展開しています。

これまでPartner of the Year, Japanを4年連続で受賞しており、2021年にはアジア太平洋地域で最もビジネスに貢献したパートナーとして表彰されました。
また、2020年度からは、Tableauを活用したデータ活用促進のコンサルティングや導入サービスの他、AI活用やデータマネジメント整備など、お客さまの企業全体のデータ活用民主化を成功させるためのノウハウ・方法論を体系化した「デジタルサクセス」プログラムを提供開始しています。
https://www.nttdata.com/jp/ja/lineup/tableau/

NTTデータとAlteryxについて
Alteryxは、業務ユーザーからIT部門まで誰でも使えるセルフサービス分析プラットフォームです。

Alteryx導入の豊富な実績を持つNTTデータは、最高位にあたるAlteryx Premiumパートナーとしてお客さまをご支援します。

導入時のプロフェッショナル支援など独自メニューを整備し、特定の業種によらない多くのお客さまに、Alteryxを活用したサービスの強化・拡充を提供します。

https://www.nttdata.com/jp/ja/lineup/alteryx/

NTTデータとDataRobotについて
DataRobotは、包括的なAIライフサイクルプラットフォームです。

NTTデータはDataRobot社と戦略的資本業務提携を行い、経験豊富なデータサイエンティストがAI・データ活用を起点にお客様のビジネスにおける価値創出をご支援します。

https://www.nttdata.com/jp/ja/lineup/datarobot/

NTTデータとInformaticaについて

データ連携や処理方式を専門領域として10年以上取り組んできたプロ集団であるNTTデータは、データマネジメント領域でグローバルでの高い評価を得ているInformatica社とパートナーシップを結び、サービス強化を推進しています。
https://www.nttdata.com/jp/ja/lineup/informatica/

NTTデータとSnowflakeについて
NTTデータでは、Snowflake Inc.とソリューションパートナー契約を締結し、クラウド・データプラットフォーム「Snowflake」の導入・構築、および活用支援を開始しています。

NTTデータではこれまでも、独自ノウハウに基づき、ビッグデータ・AIなど領域に係る市場競争力のあるさまざまなソリューションパートナーとともにエコシステムを形成し、お客さまのビジネス変革を導いてきました。
Snowflakeは、これら先端テクノロジーとのエコシステムの形成に強みがあり、NTTデータはこれらを組み合わせることでお客さまに最適なインテグレーションをご提供いたします。

https://www.nttdata.com/jp/ja/lineup/snowflake/

35
18
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
35
18

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?