2
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

Tool Calling Agentを作ってみる in Databricks

Last updated at Posted at 2024-10-03

導入

Databricks界隈(?)では、最近LLMによるAIエージェントが大人気です(たぶん)。
ちなみにDatabricksが考えるAIエージェントの定義は以下に記載されています。わかりやすい。

最近だと、御大の方々がQiita上でも記事を出されています。

上の記事のように、DatabricksはMosaic AI Agent Frameworkなどエージェント構築に適したフレームワークを提供しています。また、コアパッケージであるmlflowもエージェント対応が進んできているように思います。

ただ、まだDatabricks東京リージョンでは使えないエージェント関連機能も多く、悔しさに歯噛みしながら眠れない日々を過ごしています(ウソです)。

だからというわけではないのですが、勉強のためにちょっとしたエージェントを実装してみたいと思っていました。
丁度いいところにmlflow 2.17.0rc0でTool Callingが対応されたので、簡易Tool Calling Agentを実装してみます。

OpenAIなどのプロプライエタリAPIをそのまま使うのではなく 任意のローカルLLM(今回はMeta Llama 3.2)を使用したTool Calling なのが マニアックな 頑張ったポイントです。

実装はDatabricks on AWS(東京リージョン)、DBRは15.4MLで行いました。ノートブックのクラスタはサーバレスです。

かなりの簡易&独自突貫実装なので参考にする場合はご注意を。

Tool Calling Agentとは

今回作成するTool Calling Agentとは、いわゆるOpenAIなどがAPIとして提供しているTool(Function) Calling機能をLLMに対して付加するエージェントです。
Tool(Function) Callingは、与えられたツール/関数定義のリストとユーザのクエリを基にどのツール/関数をどういった引数で実行するかを取得する機能となります。(多くはJSON形式で結果を得ることができます)

その結果を基に実際にツール/関数の実行まで担うケースもありますが、今回は実行するべきツールとその引数を取得するエージェントを作成します。

下準備

エージェントから利用するLLMを準備するために、Databricks Mosaic AI Model Servingを使ってLLMエンドポイントを用意します。

Databricks Mosaic AI Model Servingといえばプロビジョニングされたスループット基盤APIが代表的です。が、今回は利用せずに以下記事の方法でMeta Llama 3.2 3B版のエンドポイントを作成して利用することにします。

使ったEXL2量子化済みモデルはこちら。

上記の記事で解説しているので、一旦コードは割愛。

なぜプロビジョニングされたスループット基盤APIを使わないかというと、上のやり方で作成したエンドポイントはlm-format-enforcerを使って構造化出力に対応させたためです。

Tool CallingにおいてはLLMからJSONなどの構造化テキストを高確率で得られるようにすることが非常に大事であり、なるべく楽したかった安定的に構造化出力できる方法を採用しました。

上記のエンドポイントはJSON形式の構造化出力に対応していますが、JSONスキーマの指定の仕方はOpenAIのStructured Outputと異なるやり方(メッセージの中に埋め込む)にしています。
これはMLflowのChatModelがOpenAI互換のStructured Outputインターフェースにまだ対応していないためです。(2.17.0RC0時点)

きっと次のバージョンあたりで対応してくるはず?

エージェントを作る

では、下準備で作成したLLMエンドポイントを利用したTool Calling エージェントを作成してみます。
基本的には下記記事のMLflow ChatModelを使ったエージェントのコードを踏襲しています。

Step1. パッケージインストール

ノートブックを作成し、MLflow 2.17とこの先使うlangchain/langgraphのパッケージをインストールします。

%pip install -q -U langchain-core==0.3.8 langgraph==0.2.34 langchain-databricks==0.1.0
%pip install -q -U typing-extensions
%pip install "mlflow-skinny[databricks]==2.17.0rc0"

dbutils.library.restartPython()

MLflow 2.17.0 RC版を利用しているため、正式リリース版では以下のコードが動作しないかもしれません。

Step2. エージェント用のクラスを定義

MLflowのカスタムChat Modelとして、Tool Callingを実行するエージェントクラスを定義します。
長いので折り畳み。

tool_calling_agent
%%writefile "./tool_calling_agent.py"

import mlflow
from mlflow.types.llm import (
    ChatResponse,
    ChatMessage,
    ChatParams,
    ChatChoice,
    FunctionToolCallArguments,
    ToolCall,
    ToolDefinition,
    FunctionToolDefinition,
    ToolParamsSchema,
)
from mlflow.pyfunc import ChatModel
from mlflow import deployments
from typing import List, Optional, Dict
from mlflow.models import set_model

import json
from langchain_databricks import ChatDatabricks
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers.json import JsonOutputParser
from langchain_core.output_parsers import StrOutputParser
import langchain_core.messages  # 名前衝突を避けるため
from pydantic import BaseModel


class ToolCallingAgent(ChatModel):
    def __init__(self):
        """プレースホルダー値でToolCallingAgentを初期化します。"""
        self.models = {}
        self.models_config = {}

    def load_context(self, context):
        """コネクタとモデル設定を初期化します。"""
        self.models = context.model_config.get("models", {})

    def predict(
        self, context, messages: List[ChatMessage], params: Optional[ChatParams] = None
    ) -> ChatResponse:
        """
        エージェントの会話を処理するための予測メソッド。

        Args:
            context: MLflowのコンテキスト。
            messages (List[ChatMessage]): 処理するメッセージのリスト。
            params (Optional[ChatParams]): 会話の追加パラメータ。

        Returns:
            ChatResponse: 構造化された応答オブジェクト。
        """
        # Fluent APIコンテキストハンドラを使用して、スパンに含まれる内容を制御
        with mlflow.start_span(name="Audit Agent") as root_span:
            # ユーザー入力をルートスパンに追加
            root_span.set_inputs(messages)

            # ルートスパンに属性を追加
            attributes = {**params.to_dict(), **self.models_config, **self.models}
            root_span.set_attributes(attributes)

            if not params.tools:  # Tool指定が無い場合、通常の応答を返す
                normal_endpoint = self._get_model_endpoint("normal")

                # 通常応答を作成
                response = self._build_normal_response(
                    messages, normal_endpoint, params.to_dict()
                )

                # レスポンスを作成
                output = ChatResponse(
                    choices=[
                        ChatChoice(
                            index=0,
                            message=ChatMessage(
                                role="assistant",
                                content=response,
                            ),
                        )
                    ],
                    usage={},
                    model=normal_endpoint,
                )
            else:  # Tool指定がある場合、実行すべきTool情報を引数付で返す
                # クエリに対する関数名を取得
                selector_endpoint = self._get_model_endpoint("selector")
                selector_params = self._get_model_params("selector")
                selector_sys_p = self._get_system_prompt("selector")
                tool_name = self._select_tool(
                    messages,
                    selector_sys_p,
                    params.tools,
                    selector_endpoint,
                    selector_params,
                )

                # ツール(関数)の引数情報を取得
                args_endpoint = self._get_model_endpoint("args")
                args_params = self._get_model_params("args")
                args_sys_p = self._get_system_prompt("args")
                tool_args = self._build_tool_args(
                    messages,
                    args_sys_p,
                    tool_name,
                    params.tools,
                    args_endpoint,
                    args_params,
                )

                # Tool Call情報を作成
                tool_call = FunctionToolCallArguments(
                    name=tool_name, arguments=json.dumps(tool_args)
                ).to_tool_call(
                    "1"
                )  # 手抜きで常にIDは1

                # レスポンスを作成
                output = ChatResponse(
                    choices=[
                        ChatChoice(
                            index=0,
                            message=ChatMessage(
                                role="assistant", content="", tool_calls=[tool_call]
                            ),
                        )
                    ],
                    usage={},
                    model=args_endpoint,
                )

            root_span.set_outputs(output)
        return output

    @mlflow.trace(name="Normal Response")
    def _build_normal_response(self, messages, endpoint: str, params: dict):
        """Build a normal response."""

        prompt = ChatPromptTemplate.from_messages(
            [
                ("placeholder", "{messages}"),
            ]
        )
        lc_messages = self._convert_messages_to_lc_messages(messages)
        chat_model = ChatDatabricks(endpoint=endpoint, **params)
        chain = prompt | chat_model | StrOutputParser()
        return chain.invoke({"messages": lc_messages})

    @mlflow.trace(name="Tool Selection")
    def _select_tool(
        self, messages, system_prompt: str, tools: list, endpoint: str, params: dict
    ):
        """Select the appropriate tool based on the query."""

        class FunctionSelector(BaseModel):
            function_name: str

        prompt = ChatPromptTemplate.from_messages(
            [
                ("system", system_prompt),
                ("placeholder", "{messages}"),
                # JSONスキーマの強制
                langchain_core.messages.ChatMessage(
                    role="json_schema", content=FunctionSelector.schema_json()
                ),
            ]
        )
        functions = [
            f"- Name: {t.function.name}, Description: {t.function.description}"
            for t in tools
        ]
        lc_messages = self._convert_messages_to_lc_messages(messages)

        chat_model = ChatDatabricks(endpoint=endpoint, **params)
        chain = prompt | chat_model | JsonOutputParser()
        function_name_result = chain.invoke(
            {"messages": lc_messages, "functions": functions}
        )

        return function_name_result["function_name"]

    @mlflow.trace(name="Tool Args Building")
    def _build_tool_args(
        self,
        messages,
        system_prompt: str,
        tool_name: str,
        tools: list,
        endpoint: str,
        params: dict,
    ):
        """Build the arguments for the selected tool."""

        selected_tool = [t for t in tools if t.function.name == tool_name][0]
        if not selected_tool:
            raise Exception("Function not found")

        # 選択されたToolから、JSONスキーマを生成
        selected_tool_schema = selected_tool.function.to_dict()
        tool_schema = {
            "title": selected_tool_schema.get("name"),
            "description": selected_tool_schema.get("description"),
            "type": selected_tool_schema.get("parameters").get("type"),
            "properties": selected_tool_schema.get("parameters").get("properties"),
            "required": selected_tool_schema.get("parameters").get("required", []),
        }

        prompt = ChatPromptTemplate.from_messages(
            [
                ("system", system_prompt),
                ("placeholder", "{messages}"),
                # JSONスキーマの強制
                langchain_core.messages.ChatMessage(
                    role="json_schema", content=json.dumps(tool_schema)
                ),
            ]
        ).partial(tool=tool_name, tool_schema=tool_schema)

        chat_model = ChatDatabricks(endpoint=endpoint, **params)
        chain = prompt | chat_model | JsonOutputParser()
        lc_messages = self._convert_messages_to_lc_messages(messages)

        return chain.invoke({"messages": lc_messages})

    def _convert_messages_to_lc_messages(
        self, messages: list[ChatMessage]
    ) -> list[langchain_core.messages.ChatMessage]:
        """Convert the messages to LangChain messages."""

        return [
            langchain_core.messages.ChatMessage(role=m.role, content=m.content)
            for m in messages
        ]

    def _get_system_prompt(self, role: str) -> str:
        """Get the system prompt for the specified role."""
        role_config = self.models.get(role, {})

        return role_config.get("instruction")

    def _get_model_endpoint(self, role: str) -> str:
        """
        指定された役割のモデルEndpointを取得します。

        Args:
            role (str): エージェントの役割(例:"selector" または "args")。

        Returns:
            dict: エージェントのパラメータの辞書。
        """
        role_config = self.models.get(role, {})

        return role_config.get("endpoint")

    def _get_model_params(self, role: str) -> dict:
        """
        指定された役割のモデルパラメータを取得します。

        Args:
            role (str): エージェントの役割(例:"oracle" または "judge")。

        Returns:
            dict: エージェントのパラメータの辞書。
        """
        role_config = self.models.get(role, {})

        return {
            "temperature": role_config.get("temperature", 0.5),
            "max_tokens": role_config.get("max_tokens", 500),
        }


set_model(ToolCallingAgent())

predictメソッド内の以下の部分でTool Callingとして最終的に実行するべきツールとその引数を生成しています。
処理の流れとしては、_select_toolメソッドでクエリに対するツールを一覧から選択、_build_tool_argsメソッドで選択したツールに対するパラメータ情報を作成するという2段階のステップです。
両方のステップ共に下準備で作成したLLMエンドポイントを利用しています。
詳細はそれぞれのメソッドの中を参照ください。

            else:  # Tool指定がある場合、実行すべきTool情報を引数付で返す
                # クエリに対する関数名を取得
                selector_endpoint = self._get_model_endpoint("selector")
                selector_params = self._get_model_params("selector")
                selector_sys_p = self._get_system_prompt("selector")
                tool_name = self._select_tool(
                    messages,
                    selector_sys_p,
                    params.tools,
                    selector_endpoint,
                    selector_params,
                )

                # ツール(関数)の引数情報を取得
                args_endpoint = self._get_model_endpoint("args")
                args_params = self._get_model_params("args")
                args_sys_p = self._get_system_prompt("args")
                tool_args = self._build_tool_args(
                    messages,
                    args_sys_p,
                    tool_name,
                    params.tools,
                    args_endpoint,
                    args_params,
                )

                # Tool Call情報を作成
                tool_call = FunctionToolCallArguments(
                    name=tool_name, arguments=json.dumps(tool_args)
                ).to_tool_call(
                    "1"
                )  # 手抜きで常にIDは1

                # レスポンスを作成
                output = ChatResponse(
                    choices=[
                        ChatChoice(
                            index=0,
                            message=ChatMessage(
                                role="assistant", content="", tool_calls=[tool_call]
                            ),
                        )
                    ],
                    usage={},
                    model=args_endpoint,
                )

Step3. Model Configの作成

モデル設定と入力サンプルを用意します。
Step2のカスタムChatModelは、利用するLLMエンドポイントなどの情報をこのモデル設定から取得するようにしています。
モデル設定として、処理ごとのエンドポイント名やシステムプロンプトテンプレートを格納しています。
ざっとどのようなプロンプトでツール選択やパラメータ作成をしているのかが想像つくのではないでしょうか。
(もっとよいプロンプトがあると思いますが、突貫実装なので適当です(言い訳))

model_config = {
    "models": {
        "normal": {
            "endpoint": "Llama-3_2-3B-Instruct-exl2-endpoint",
        },
        "selector": {
            "endpoint": "Llama-3_2-3B-Instruct-exl2-endpoint",
            "instruction": (
                "あなたはクエリに対して適切な関数の名前を返すアシスタントです。"
                "ユーザからのクエリに回答するために必要な関数を以下から選んで関数名を返してください。\n\n"
                "FUNCTIONS: {functions}"
            ),
            "temperature": 0.0,
            "max_tokens": 2000,
        },
        "args": {
            "endpoint": "Llama-3_2-3B-Instruct-exl2-endpoint",
            "instruction": (
                "あなたはクエリに対してツール{tool}を使用することに決めました。"
                "このツールは関数として次のスキーマで定義された指定が必要です。"
                "SCHEMA:\n{tool_schema}\n\n"
                "このツールの実行に必要な引数をスキーマに合わせて返してください。"
            ),
            "temperature": 0.0,
            "max_tokens": 2000,
        },
    },
}

input_example = {
    "messages": [
        {
            "role": "user",
            "content": "あまりスキルを必要としないスコーンを焼くための良いレシピは何ですか?",
        }
    ]
}

Step4. エージェントのロギング

MLflowを使って作成したエージェントクラスをロギングします。
この後Mosaic AI Model Servingを使うので、合わせてUnity Catalog上にもモデルを登録します。

import mlflow

# Unity Catalogの保管先モデル名
registered_model_name = "training.llm.tool_calling_agent"

# Databricks Unity Catalogを利用してモデル管理
mlflow.set_registry_uri("databricks-uc")

with mlflow.start_run():
    model_info = mlflow.pyfunc.log_model(
        "model",
        python_model="tool_calling_agent.py",
        model_config=model_config,
        input_example=input_example,
        registered_model_name=registered_model_name,
    )

Step5. エージェントのデプロイ

Mosaic AI Model Servingを使ってエンドポイントを作成。
「サービング」メニューからUIを使って登録します。

他のLLMエンドポイントを呼び出しているので、DATABRICKS_HOSTなど必要な環境変数を設定しておきましょう。

image.png

ここまででエージェントの作成・登録・デプロイまで完了しました。

エージェントを使う

作成したエンドポイントを使ってTool Callingを実践してみます。

まず、Tool Callingを試すためのツールを準備。
LangGraphのTool Callingサンプルからコードを拝借しました。
天気の情報を取得するget_weather関数と最もCoolな都市一覧を取得するget_coolest_cities関数が使えるツールとなります。

from langchain_core.messages import AIMessage
from langchain_core.tools import tool

@tool
def get_weather(location: str):
    """Call to get the current weather."""
    if location.lower() in ["sf", "san francisco"]:
        return "It's 60 degrees and foggy."
    else:
        return "It's 90 degrees and sunny."

@tool
def get_coolest_cities():
    """Get a list of coolest cities"""
    return "nyc, sf"

tools = [get_weather, get_coolest_cities]

では、LangChainのChatDatabricksクラスを使って、推論を実施。

import json
from langchain_databricks import ChatDatabricks

endpoint = "test_tool_calling_agent"

# ツール指定無し:通常の処理
chat = ChatDatabricks(endpoint=endpoint)
print(" --- Without Tools ---")
print(chat.invoke(
    [{"role": "user", "content": "今の大阪の天気は?"}],
))

print()

# ツール指定あり:Tool Calling情報を得る
chat = ChatDatabricks(endpoint=endpoint).bind_tools(tools)
print(" --- With Tools ---")
print(chat.invoke(
    [{"role": "user", "content": "今の大阪の天気は?"}],
))
出力
 --- Without Tools ---
content='現在の大阪の天気を知るには、以下の方法を使用できます。\\n\\n1. インターネットで天気予報検索:Google検索やYahoo!天気予報などの検索エンジンで「大阪天気予報」を検索して、現在の天気と天気予報を確認できます。\\n\\n2.  天気予報アプリ:天気予報アプリをダウンロードしてイン' additional_kwargs={} response_metadata={} id='run-d4d6ff02-6bb6-44b1-b672-2615e5ba0d1f-0'

 --- With Tools ---
content='' additional_kwargs={'tool_calls': [{'function': {'name': 'get_weather', 'arguments': '{"location": "\\u5927\\u962a"}'}, 'id': '1', 'type': 'function'}]} response_metadata={} id='run-39eee941-3888-4b8a-a719-274766aae474-0' tool_calls=[{'name': 'get_weather', 'args': {'location': '大阪'}, 'id': '1', 'type': 'tool_call'}]

ツールの指定が無い場合は、通常の推論結果が返ってきています。
一方、ツール指定の場合はadditional_kwargsパラメータの中にtool_callsの設定が含まれており、呼び出すツール(関数)名とその引数の情報が保持されました。
ちゃんと天気を取得するツール(関数)が適切な引数と共に取得できていますね。

違うクエリも試してみます。

chat = ChatDatabricks(endpoint=endpoint).bind_tools(tools)
print(chat.invoke(
    [{"role": "user", "content": "最もクールな都市はどこ?"}],
))
出力
content='' additional_kwargs={'tool_calls': [{'function': {'name': 'get_coolest_cities', 'arguments': '{}'}, 'id': '1', 'type': 'function'}]} response_metadata={} id='run-92f23b66-d056-4fec-9d51-e7182a91ea69-0' tool_calls=[{'name': 'get_coolest_cities', 'args': {}, 'id': '1', 'type': 'tool_call'}]

今度はget_coolest_citiesを使うような内容が取得されました。

今度はlanggraphを使ってツールの実行までさせてみます。

from langgraph.prebuilt import ToolNode

chat = ChatDatabricks(endpoint=endpoint).bind_tools(tools)
message_with_single_tool_call = chat.invoke(
    [{"role": "user", "content": "今の大阪の天気は?"}],
)

# langgraphのToolNode機能を直接呼び出してツールを実行
tool_node = ToolNode(tools)
print(tool_node.invoke({"messages": [message_with_single_tool_call]}))
出力
{'messages': [ToolMessage(content="It's 90 degrees and sunny.", name='get_weather', tool_call_id='1')]}

関数が実行され、実行結果がToolMessageとして取得できました。

Tool Calling、ちゃんと動いてますね!

改善点

  • 今回は単一のツール呼び出し方法を返すだけの仕組になっています。OpenAI等のTool Callingは複数の内容を取得できるので拡張の余地があります
  • 適切なツールが無くても何らかのツール呼び出しが返ってくるようになっています。適切なツールが無い場合の動作を加えるとより便利になります
  • 構造化出力できなかったときなどのエラー処理がほとんど無いのでちゃんと入れる必要があります
  • エラーハンドリング含めていろいろ考えると複雑なフローになりそうなので、LangGraphなどを使って適切なワークフロー化をするべき
  • 中途半端にLangChainをカスタムChatModelの中で使ったためにコードが複雑になってしまいました。mlflow SDKだけで書くと見通しのよい記述にできたハズ

などなど。とにかくテスト不足でもあるので、いろいろ不具合もありそう。
もうちょっと実用的な形にしてリベンジしたい。。。

まとめ

簡易的なTool Callingエージェントを作ってみました。
エージェントと言えばTool Callingを使った処理、というイメージがなんとなくあるのでDatabricks/MLflow+ローカルLLMでこういったものが作れるのは面白いです。

また、Unity Catalog上のツール呼び出しなどもできそうなので、別途やってみたいと思います。

2
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?