0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

作って理解するMLflow ChatAgetのresource指定

Posted at

導入

MLflowでカスタムモデルをロギングする際にはmlflow.pyfunc.save_modelmlflow.pyfunc.log_modelメソッドを使用しますが、これらのメソッドの引数にはresourcesというオプションパラメータが存在します。

以前から存在は知っていたのですが、正直何に使うものか全然理解していませんでした。
ただ、今更ながら以下のDocumentを読んでようやく理解しました。

個人的備忘録も兼ねて、resourcesパラメータを指定する場合としない場合の挙動を簡単に確認します。

検証はDatabricks on AWS上で行いました。ノートブックの実行はサーバレスクラスタを利用しています。

何に使うの?

上記のDocumentより抜粋。

自動認証パススルーのリソースの指定

AIエージェントは、タスクを完了するために他のリソースに対して認証する必要があることがよくあります。 たとえば、エージェントは、非構造化データをクエリするためにベクトル検索インデックスにアクセスする必要がある場合があります。

依存リソースの認証で説明されているように、モデルサービングは、エージェントをデプロイするときに、 Databricksマネージド リソースと外部リソースの両方に対する認証をサポートします。

最も一般的な Databricks リソースの種類については、Databricks では、ログ記録中にエージェントのリソース依存関係を事前に宣言することをサポートし、推奨しています。 これにより、エージェントをデプロイするときに 自動認証パススルー が有効になり、Databricks は、エージェント エンドポイント内からこれらのリソース依存関係に安全にアクセスするための有効期間の短い資格情報を自動的にプロビジョニング、ローテーション、管理します。

自動認証パススルーを有効にするには、次のコードに示すように、log_model() API の resources パラメーターを使用して依存リソースを指定します。

というわけで、resourcesパラメータで指定したDatabricksのマネージドリソース(ベクトル検索インデックスやモデルサービングエンドポイントなど)は、認証のパススルーが行われます。

概ね理解はできるのですが、より理解を深めるためにDatabricksモデルサービング上にAIエージェントをデプロイし、resourcesを指定するケース/しないケースでの動作を確認してみます。

Step1. 試験用エージェントを作る

検証用に、簡単なReActエージェントをmlflow+langgraphで作成してDatabricks上にサービングします。

まず、ノートブックを作成して必要なパッケージをインストール。

%pip install -U -qqqq databricks-langchain databricks-agents>=0.16.0 mlflow-skinny[databricks] langgraph uv

%restart_python

次にMLflowのカスタムChatAgentインターフェースを使って、エージェント処理を定義します。
今回はLangGraphのReActエージェントをラップしたツールを実行するエージェントを定義します。
(Databricks Documentにも同様のエージェントサンプルがありますのでお好みで使い分けてください)

処理内で利用するLLMには、以下の記事でDatabricks Mosaic AI Model Serving上にデプロイしたエンドポイントを利用しています。

%%writefile agent.py
from typing import Literal, Generator, List, Optional, Any, Dict, Mapping, Union
import uuid

import mlflow
from databricks_langchain import (
    ChatDatabricks,
    UCFunctionToolkit,
)
from langchain_core.tools import tool
from langgraph.graph.state import CompiledStateGraph
from langgraph.prebuilt import create_react_agent

from langchain_core.messages import BaseMessage, convert_to_openai_messages
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from databricks_langchain import ChatDatabricks
from mlflow.pyfunc import ChatAgent
from mlflow.types.agent import (
    ChatAgentChunk,
    ChatAgentMessage,
    ChatAgentResponse,
    ChatContext,
)
from functools import reduce

class LangGraphChatAgent(ChatAgent):
    def __init__(self, agent: CompiledStateGraph):
        self.agent = agent

    def predict(
        self,
        messages: list[ChatAgentMessage],
        context: Optional[ChatContext] = None,
        custom_inputs: Optional[dict[str, Any]] = None,
    ) -> ChatAgentResponse:
        request = {"messages": self._convert_messages_to_dict(messages)}

        messages = []
        usages = []
        for event in self.agent.stream(request, stream_mode="updates"):
            for node_data in event.values():
                messages.extend(
                    self._convert_lc_message_to_chat_message(msg)
                    for msg in node_data.get("messages", [])
                )
                usages.extend(
                    msg.response_metadata for msg in node_data.get("messages", [])
                )

        usage = self._sum_usages(usages)
        return ChatAgentResponse(messages=messages, usage=usage)

    def predict_stream(
        self,
        messages: list[ChatAgentMessage],
        context: Optional[ChatContext] = None,
        custom_inputs: Optional[dict[str, Any]] = None,
    ) -> Generator[ChatAgentChunk, None, None]:
        request = {"messages": self._convert_messages_to_dict(messages)}
        for event in self.agent.stream(request, stream_mode="updates"):
            for node_data in event.values():
                yield from (
                    ChatAgentChunk(
                        **{"delta": self._convert_lc_message_to_chat_message(msg)},
                        usage=msg.response_metadata
                    )
                    for msg in node_data["messages"]
                )

    def _convert_lc_message_to_chat_message(
        self, lc_message: BaseMessage
    ) -> ChatAgentMessage:
        """Convert a LangChain message to a ChatAgentMessage."""
        msg = convert_to_openai_messages(lc_message)
        if not "id" in msg:
            msg.update({"id": str(uuid.uuid4())})

        return ChatAgentMessage(**msg)

    def _sum_usages(self, usages: list[dict]) -> dict:
        """Sum up the usages from the list of usages."""

        def add_usages(a: dict, b: dict) -> dict:
            pt = "prompt_tokens"
            ct = "completion_tokens"
            tt = "total_tokens"
            return {
                pt: a.get(pt, 0) + b.get(pt, 0),
                ct: a.get(ct, 0) + b.get(ct, 0),
                tt: a.get(tt, 0) + b.get(tt, 0),
            }

        return reduce(add_usages, usages, {})

## テスト用のツール
@tool
def get_weather(city: Literal["nyc", "sf"]):
    """Use this to get weather information."""
    if city == "nyc":
        return "It might be cloudy in nyc"
    elif city == "sf":
        return "It's always sunny in sf"
    else:
        raise AssertionError("Unknown city")

# Databricks MosaicAI Model Serving上のエンドポイント名
LLM_ENDPOINT_NAME = "sglang_qwen_bakeneko_32b_v2_awq_endpoint"
llm = ChatDatabricks(model=LLM_ENDPOINT_NAME)

tools = [get_weather]

graph = create_react_agent(llm, tools=tools)
AGENT = LangGraphChatAgent(graph)
mlflow.models.set_model(AGENT)

これで検証に使うエージェントの定義が準備できました。
では、実際にResoucesを指定するケース・指定しないケースごとにエージェントを利用してみます。

Step2. Resoucesを指定せずに使う

まずはresoucesを指定しないケースです。
Step1で作成したカスタムエージェントをロギングします。

import mlflow
from databricks import agents

with mlflow.start_run():
    logged_agent_info = mlflow.pyfunc.log_model(
        artifact_path="agent",
        python_model="agent.py",
        pip_requirements=[
            "mlflow",
            "langgraph==0.3.18",
            "databricks-langchain==0.4.0",
        ],
        # resources=resources, # 指定しない
    )

ロギングしたエージェントをDatabricksのモデルサービングにデプロイします。

mlflow.set_registry_uri("databricks-uc")

catalog = "training"
schema = "llm"
model_name = "chat_agent_sample"
UC_MODEL_NAME = f"{catalog}.{schema}.{model_name}"
run_id = logged_agent_info.run_id

# Unity Catalogへの登録
uc_registered_model_info = mlflow.register_model(
    model_uri=f"runs:/{run_id}/agent", name=UC_MODEL_NAME
)
# モデルサービングへのデプロイ
agents.deploy(
    UC_MODEL_NAME, uc_registered_model_info.version
)

2025年3月現在、東京リージョンではagents.deploy実行時にエラーが出ます。
ただ、モデルサービングへのデプロイ自体は処理されます。

早くMosaic AI Agent Frameworkの対応来て欲しいなあ。

一定時間経過後、デプロイが完了します。(サービングメニューから確認可能)

image.png

このエージェントをPlaygroundから利用してみましょう。

image.png

サンプルの質問を試してみます。

image.png

エラーが出ました。

メッセージを読むとわかるのですが、Credentialが読めない旨の内容です。

今回のエージェントは内部でDatabricks MosaicAI Model Servingのエンドポイントを利用しているのですが、そこに対する利用資格がないために起こっているものと推測します。

例えば環境変数にAPIトークンを設定して利用するなど適切に資格情報を管理すれば良さそうですが、リソースの種類が増えていくとなかなか手動で管理することは大変です。

では、resourcesを指定するケースで試してみます。

Step3. Resoucesを指定して使う

では、再度エージェントのロギングからやり直します。

今度は、resourcesパラメータを設定します。

import mlflow
from agent import tools, LLM_ENDPOINT_NAME
from databricks_langchain import VectorSearchRetrieverTool
from mlflow.models.resources import DatabricksFunction, DatabricksServingEndpoint
from unitycatalog.ai.langchain.toolkit import UnityCatalogTool

# モデルサービングのエンドポイントをリソースとして含める
resources = [DatabricksServingEndpoint(endpoint_name=LLM_ENDPOINT_NAME)]
# ツール内にDatabricksのマネージドリソースが含まれる場合、resourcesに加える(今回は無い)
for tool in tools:
    if isinstance(tool, VectorSearchRetrieverTool):
        resources.extend(tool.resources)
    elif isinstance(tool, UnityCatalogTool):
        resources.append(DatabricksFunction(function_name=tool.uc_function_name))

with mlflow.start_run():
    logged_agent_info = mlflow.pyfunc.log_model(
        artifact_path="agent",
        python_model="agent.py",
        pip_requirements=[
            "mlflow",
            "langgraph==0.3.18",
            "databricks-langchain==0.4.0",
        ],
        resources=resources, # 指定!
    )

今回の対象はモデルサービングのエンドポイントです。
これをDatabricksServingEndpointにエンドポイント名を指定してリソースに加えています。

あとはStep2と同様にデプロイします。

mlflow.set_registry_uri("databricks-uc")

catalog = "training"
schema = "llm"
model_name = "chat_agent_sample"
UC_MODEL_NAME = f"{catalog}.{schema}.{model_name}"
run_id = logged_agent_info.run_id

# register the model to UC
uc_registered_model_info = mlflow.register_model(
    model_uri=f"runs:/{run_id}/agent", name=UC_MODEL_NAME
)
agents.deploy(
    UC_MODEL_NAME, uc_registered_model_info.version
)

デプロイが終わったら同様にPlaygroundで試してみます。

image.png

今度は正常に実行できました。

resoucesに資格情報をパススルーするリソースを指定したため、エージェント利用者の資格情報を使ってその先のLLMエンドポイントへアクセスできたためだと理解しています。

というわけで、MLflowでエージェントを作成し、Databricksにデプロイする際はきちんとresourcesを設定しましょう。多くのユースケースで資格情報パススルーを使うことで利便性が高まりセキュリティも守られると思います。

余談ですがツール呼び出しエージェントとしてもちゃんと動作します。

image.png

まとめ

以前から疑問だったresourcesパラメータの利用について備忘録としてまとめました。
これまでは環境変数でCredintialを連携したりしていたのですが、Databricksリソースについて言えばresourcesを使った資格情報パススルーを積極的に使うべきだなと思います。

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?