0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

生成AIに関する記事を書こう!
Qiita Engineer Festa20242024年7月17日まで開催中!

DatabricksとLangGraphで学ぶAgenticアプローチ: 並行実行のためのMap-Reduceブランチ

Posted at

導入

LangGraphのHow-to Guideウォークスルーの3回目です。
今回は、下記の内容である「並列実行用のmap-reduceブランチを作成」をしてみます。

検証はDatabricks on AWS、DBRは15.3MLを使っています。

並列実行用の map-reduce ブランチとは

以下、公式ドキュメントの序文を邦訳。

並列実行用の map-reduce ブランチを作成する方法

エージェントの一般的なパターンは、オブジェクトのリストを生成し、それらの各オブジェクトに対して何らかの作業を行い、その結果を結合することです。
これは、一般的な map-reduce 操作と非常によく似ています。これは、いくつかの理由で注意が必要です。
第 1 に、オブジェクトのリストの長さが不明な場合があるため、構造化されたグラフを事前に定義するのは難しい場合があります。
第二に、このmap-reduceを行うには、複数のバージョンの状態が存在する必要があります。
...しかし、グラフは共通の共有状態を共有しているので、これはどのようになるのでしょうか?

LangGraphはSend APIを介してこれをサポートしています。
これを使用して、条件付きエッジを複数のノードに対して複数の異なる状態を送ることができます。
送信される状態は、コアグラフの状態とは異なる場合があります。

これが実際にどのようなものか見てみましょう。
単語のリストを生成し、各単語についてジョークを書き、最高のジョークを判断するおもちゃの例をまとめます。

なかなか表現が難解ですね。。。
ただ、記載されているようにエージェントは複数の出力をまず生成し、それらに対して処理を実行した上で結果を集約する流れが多いと思います。
これをMapReduceに似ていると表現しており、その流れをLangGraphで実現するための方法を解説しようとしています。

主にはLangGraphのSend APIを使って条件付きエッジを構成するやり方となるようです。

実際にサンプルコードを動かして、理解を進めてみましょう。

Step1. パッケージインストール

LangGraphやLangChainなど、必要なパッケージをインストール。

%pip install -U langgraph==0.1.4 langchain==0.2.6 langchain-community==0.2.6 mlflow-skinny[databricks]==2.14.1 pydantic==2.7.4
dbutils.library.restartPython()

Step2. Map-Reduceブランチの作成

では、今回の骨子であるMap-Reduceのように動作するグラフを構築します。

まず、グラフの中で利用するLLMを準備します。

公式ドキュメントのサンプルではOpenAIを使ったJSON出力を実装していましたが、
今回はDatabricks Model Serving上で構築した構造化出力可能なエンドポイントを用いることにします。
(もちろんJSONモード等に対応したOpenAIやAnthropicのAPIを使っても問題ありません)

詳しくは、以下の記事を参照ください。

import mlflow.deployments
import pandas as pd
import json

client = mlflow.deployments.get_deploy_client("databricks")
endpoint_name = "bulk-inference-endpoint" # Databricks Model Servingで構築した構造化出力用エンドポイント

def llm_invoke(prompt, baseclass):
    """ LLMエンドポイントに対してプロンプトを発行し、結果を指定したPydantic Baseクラスのスキーマを使って構造化出力する """

    system_prompt = f"You are a helpful assistant, reply in Japanese.You MUST answer using the following json schema:{baseclass.schema_json()}"
    inputs = pd.DataFrame(
        [
            {"prompt": prompt},
        ]
    )
    response = client.predict(
        endpoint=endpoint_name,
        inputs={
            "inputs": inputs.to_dict(orient="records"),
            "params": {
                "temperature": 0.1,
                "top_k": 0,
                "top_p": 0.6,
                "max_tokens": 512,
                "system_prompt": system_prompt,
                "json_schema": baseclass.schema_json(),
            },
        },
    )

    json_resp = json.loads(response["predictions"][0]["output"])
    return baseclass.parse_obj(json_resp)

次に、公式ドキュメントの内容をベースに、グラフを構築します。
(オリジナルに比べて、プロンプトやコメント文を邦訳しています)

import operator
from typing import Annotated, TypedDict

from langchain_core.pydantic_v1 import BaseModel
from langgraph.constants import Send
from langgraph.graph import END, StateGraph

import mlflow

# モデルとプロンプト
# 使用するモデルとプロンプトを定義します
subjects_prompt = """2つから5つの{topic}のリストを生成してください。"""
joke_prompt = """{subject}についてのジョークを生成してください。"""
best_joke_prompt = """以下は{topic}に関するいくつかのジョークです。最も良いジョークを選んでください!最も良いジョークのIDを返します。IDは0から始まる連番です。

{jokes}"""


class Subjects(BaseModel):
    subjects: list[str]


class Joke(BaseModel):
    joke: str


class BestJoke(BaseModel):
    id: int


# グラフのコンポーネント:グラフを構成するコンポーネントを定義します


# これはメイングラフの全体の状態になります。
# ユーザーがトピックを提供することを期待して、
# サブジェクトのリストを生成し、それぞれのサブジェクトに対してジョークを生成します
class OverallState(TypedDict):
    topic: str
    subjects: list
    # ここではoperator.addを使用しています
    # 個々のノードから生成されたジョークをすべて1つのリストに結合したいためです。
    # これは基本的に「reduce」の部分です
    jokes: Annotated[list, operator.add]
    best_selected_joke: str


# ジョークを生成するためにサブジェクトを「マップ」するノードの状態です
class JokeState(TypedDict):
    subject: str


# ジョークのサブジェクトを生成するための関数です
@mlflow.trace(span_type="node")
def generate_topics(state: OverallState):
    prompt = subjects_prompt.format(topic=state["topic"])
    response = llm_invoke(prompt, Subjects)
    return {"subjects": response.subjects}


# サブジェクトを指定してジョークを生成します
@mlflow.trace(span_type="node")
def generate_joke(state: JokeState):
    prompt = joke_prompt.format(subject=state["subject"])
    response = llm_invoke(prompt, Joke)
    return {"jokes": [response.joke]}


# 生成されたサブジェクトに対してジョークを続けるためのロジックを定義します
@mlflow.trace(span_type="node")
def continue_to_jokes(state: OverallState):
    # `Send`オブジェクトのリストを返します
    # 各`Send`オブジェクトは、グラフ内のノードの名前とそのノードに送信する状態で構成されています
    return [Send("generate_joke", {"subject": s}) for s in state["subjects"]]


# 最も良いジョークを判断します
@mlflow.trace(span_type="node")
def best_joke(state: OverallState):
    jokes = "\n\n".format()
    prompt = best_joke_prompt.format(topic=state["topic"], jokes=jokes)
    response = llm_invoke(prompt, BestJoke)
    return {"best_selected_joke": state["jokes"][response.id]}


# グラフを構築します:ここですべてを組み合わせてグラフを構築します
graph = StateGraph(OverallState)
graph.add_node("generate_topics", generate_topics)
graph.add_node("generate_joke", generate_joke)
graph.add_node("best_joke", best_joke)
graph.set_entry_point("generate_topics")
graph.add_conditional_edges("generate_topics", continue_to_jokes)
graph.add_edge("generate_joke", "best_joke")
graph.add_edge("best_joke", END)
app = graph.compile()

出来たグラフを可視化すると以下の様になります。

from IPython.display import Image, display

display(Image(app.get_graph().draw_mermaid_png()))

image.png

generate_topicsノードから各ノードに対して分岐可能なようにグラフが可視化されます。

※ 実際にはgenerate_jokeにのみ処理を実行させる動作となります。

では、実行させてみましょう。

# グラフを呼び出します:ここでジョークのリストを生成するために呼び出します
with mlflow.start_span("graph", span_type="AGENT") as span:
    for s in app.stream({"topic": "animals"}):
        print(s)
出力
{'generate_topics': {'subjects': ['lion', 'tiger', 'elephant', 'giraffe', 'zebra']}}
{'generate_joke': {'jokes': ['Why did the elephant cross the road? To prove to the other elephants that it could be done.']}}
{'generate_joke': {'jokes': ["Why did the giraffe cross the road? To prove that he wasn't chicken!"]}}
{'generate_joke': {'jokes': ['ライオンが王様だと思っていたら、実はライオンキングだった。']}}
{'generate_joke': {'jokes': ['ズーラシャンプー']}}
{'generate_joke': {'jokes': ['Why did the tiger cross the road? To get to the other side of the jungle.']}}
{'best_joke': {'best_selected_joke': 'Why did the elephant cross the road? To prove to the other elephants that it could be done.'}}

処理の流れを追ってみます。

まず、generate_topicsノードにて、ジョークを生成するためのトピックを2~5個生成します。
今回は、Animalをお題としてトピックが5個生成されました。

次に生成されたトピックごとに、generate_jokeを呼び出します。
これを担っているのがcontinue_to_jokes関数の処理で、条件分岐エッジとして登録されています。

continue_to_jokes関数は内部でSend APIを呼び出しており、各トピックの内容を入力としてgenerate_jokeノードをSend APIでラップした結果を返しています。
これによって、generate_jokeノードが5回並列に実行され、各結果が状態に保管されます。

最後にbest_jokeノードが実行され、generate_jokeノードが生成したジョークの中から一つのジョークを選択します。

このように、ぱっと見ノードはそれぞれ1回実行されるグラフとして構成されていますが、Send APIを使うことで入力に応じて複数回ノードを実行させることができるようです。

まとめ

LangGraphの並列実行用の Map-Reduce ブランチ作成について、公式ドキュメントの内容をウォークスルーしてみました。

ノードの実行結果や状態を基に、Send APIによって特定のノードに処理を実行させてその結果を集約するという Map-Reduce の処理が実行できます。

便利だと思うのですが、複雑性が上がりそうで使用は注意が必要かもしれません。
どういったユースケースでの利用がよいのか、引き続き学んでいきたいと思います。

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?