はじめに
少し前にDatabricksでMCP(Model Context Protocol)がサポートされました。
Databricksのベクトル検索などのサービスをマネージドなMCPサーバから利用したり、自分でMCPサーバをホストすることができます。
2025/6/15現在でベータ版です。
詳細な機能や注意点は上記公式ドキュメントか、いつものように@taka_yayoiさんの記事を読むとよく理解できます。
その中で、マネージドMCPサーバを使ったエージェント構築についても公式ドキュメントに記述されていますが、サンプルはシングルターンのエージェントです。
どうせならマルチホップに対応したエージェントが欲しかったので、勉強がてらLangChain/LangGraphとの連携を伴うDatabricks MCP エージェントのサンプルコードを作成しデプロイまで行ってみました。
開発・デプロイはDatabricks on AWS(非Free版)で行っています。
Databricks Free Editionでもトライしたのですが、MCPツールの実行でタイムアウトしてしまったため断念しました。対応してないんでしたっけ。。。?
Step1. 準備
ノートブックを作成し、必要なパッケージをインストールします。
%pip install -U -q mlflow[databricks]>=3.1 databricks-mcp "mcp>=1.9" databricks-agents databricks-langchain langgraph databricks-sdk
%pip install nest-asyncio
%restart_python
databricks-mcp
が初出のパッケージですね。
合わせて公式のmcp
パッケージもインストールが必要です。
また今回は(も)、LangChain/LangGraphを利用してエージェントを構築します。
langchainにはlangchain-mcp-adapters
というMCP対応のためのパッケージがあるのですが、今回は利用していません。
執筆時点ではlangchain-mcp-adapters
はMCP Authorizationに対応していなかったためです。
ただ、今回のコードはlangchain-mcp-adapters
の内容を大いに参考にさせていただきました。
開発者の方々には本当に感謝を伝えたいです。
Step2. カスタムResponsesAgentを定義
最重要ポイントです。
まずはコードを全て掲載します(非常に長いので折りたたんでいます)。
解説はコードの後に記載しています。
databricks_mcp_agent
%%writefile agents/databricks_mcp_agent.py
from contextlib import asynccontextmanager
import asyncio
from typing import Any, Callable, List, Generator, TypedDict, cast
from pydantic import BaseModel
from functools import reduce
import mlflow
from mlflow.pyfunc import ResponsesAgent
from mlflow.types.responses import (
ResponsesAgentRequest,
ResponsesAgentResponse,
ResponsesAgentStreamEvent,
)
from mlflow.entities import SpanType
from databricks_mcp import DatabricksOAuthClientProvider
from databricks.sdk import WorkspaceClient
from mcp.client.session import ClientSession
from mcp.client.streamable_http import streamablehttp_client
from langchain_core.messages import (
BaseMessage,
AIMessage,
ToolMessage,
AIMessageChunk,
)
from langchain_core.tools import BaseTool, ToolException, StructuredTool
from langgraph.prebuilt import create_react_agent
from databricks_langchain import ChatDatabricks
from mcp.types import (
CallToolResult,
EmbeddedResource,
ImageContent,
TextContent,
Tool as MCPTool,
)
NonTextContent = ImageContent | EmbeddedResource
MAX_ITERATIONS = 1000
# Tracingの有効化
mlflow.langchain.autolog()
class DatabricksConnection(TypedDict):
"""Databricks MCPサーバーへの接続情報を保持する型定義
Attributes:
server_url (str): 接続先のDatabricks MCPエンドポイントURL
workspace_client (WorkspaceClient): Databricksワークスペースクライアント(認証用)
"""
server_url: str
workspace_client: WorkspaceClient
@asynccontextmanager
async def _databricks_mcp_session(connection: DatabricksConnection):
"""Databricks MCPサーバーへの非同期セッションを作成するコンテキストマネージャ
Args:
connection (DatabricksConnection): MCPサーバーへの接続情報
Yields:
ClientSession: 初期化済みのMCPクライアントセッション
"""
async with streamablehttp_client(
url=connection.get("server_url"),
auth=DatabricksOAuthClientProvider(connection.get("workspace_client")),
) as (reader, writer, _):
async with ClientSession(reader, writer) as session:
await session.initialize()
yield session
async def _list_all_tools(session: ClientSession) -> list[MCPTool]:
"""MCPサーバーから全てのツール情報をページネーションで取得する
Args:
session (ClientSession): MCPクライアントセッション
Returns:
list[MCPTool]: 取得した全ツールのリスト
Raises:
RuntimeError: ページ数が上限を超えた場合
"""
current_cursor: str | None = None
all_tools: list[MCPTool] = []
iterations = 0
while True:
iterations += 1
if iterations > MAX_ITERATIONS:
raise RuntimeError(
f"Reached max of {MAX_ITERATIONS} iterations while listing tools."
)
list_tools_page_result = await session.list_tools(cursor=current_cursor)
if list_tools_page_result.tools:
all_tools.extend(list_tools_page_result.tools)
if list_tools_page_result.nextCursor is None:
break
current_cursor = list_tools_page_result.nextCursor
return all_tools
def _convert_call_tool_result(
call_tool_result: CallToolResult,
) -> tuple[str | list[str], list[NonTextContent] | None]:
"""MCPツール呼び出し結果をテキスト・非テキストに分割して返す
Args:
call_tool_result (CallToolResult): MCPツールの呼び出し結果
Returns:
tuple[str | list[str], list[NonTextContent] | None]: テキスト内容と非テキスト内容
Raises:
ToolException: エラーが発生した場合
"""
text_contents: list[TextContent] = []
non_text_contents = []
for content in call_tool_result.content:
if isinstance(content, TextContent):
text_contents.append(content)
else:
non_text_contents.append(content)
tool_content: str | list[str] = [content.text for content in text_contents]
if not text_contents:
tool_content = ""
elif len(text_contents) == 1:
tool_content = tool_content[0]
if call_tool_result.isError:
raise ToolException(tool_content)
return tool_content, non_text_contents or None
def _convert_mcp_tool_to_langchain_tool(
connection: DatabricksConnection,
tool: MCPTool,
) -> BaseTool:
"""MCPツール情報をLangChainのStructuredToolに変換する
Args:
connection (DatabricksConnection): MCPサーバー接続情報
tool (MCPTool): MCPツール情報
Returns:
BaseTool: LangChain互換のツール
"""
if connection is None:
raise ValueError("a connection config must be provided")
async def call_tool_async(
**arguments: dict[str, Any],
) -> tuple[str | list[str], list[NonTextContent] | None]:
async with _databricks_mcp_session(connection) as tool_session:
call_tool_result = await cast(ClientSession, tool_session).call_tool(
tool.name, arguments
)
return _convert_call_tool_result(call_tool_result)
def call_tool_sync(
**arguments: dict[str, Any]
) -> tuple[str | list[str], list[NonTextContent] | None]:
return asyncio.run(call_tool_async(**arguments))
return StructuredTool(
name=tool.name,
description=tool.description or "",
args_schema=tool.inputSchema,
coroutine=call_tool_async,
func=call_tool_sync,
response_format="content_and_artifact",
metadata=tool.annotations.model_dump() if tool.annotations else None,
)
def list_databricks_mcp_tools(
connections: list[DatabricksConnection],
) -> list[BaseTool]:
"""複数のMCPサーバーから全ツールを取得し、LangChainツールリストとして返す
Args:
connections (list[DatabricksConnection]): MCPサーバー接続情報のリスト
Returns:
list[BaseTool]: LangChain互換の全ツールリスト
"""
async def _load_databricks_mcp_tools(
connection: DatabricksConnection,
) -> list[BaseTool]:
if connection is None:
raise ValueError("connection config must be provided")
async with _databricks_mcp_session(connection) as tool_session:
tools = await _list_all_tools(tool_session)
converted_tools = [
_convert_mcp_tool_to_langchain_tool(connection, tool) for tool in tools
]
return converted_tools
async def gather():
tasks = [_load_databricks_mcp_tools(con) for con in connections]
return await asyncio.gather(*tasks)
# 結果をflat化してtoolの単一リストとして返す
return sum(asyncio.run(gather()), [])
class DatabricksMCPAgent(ResponsesAgent):
"""Databricks MCPサーバーとLangChainエージェントを組み合わせたエージェントクラス"""
def __init__(self, model_name, mcp_urls: list[str]):
"""DatabricksMCPAgentの初期化
Args:
model_name: 使用するLLMモデル(エンドポイント)名
mcp_urls (list[str]): MCPサーバーのURLリスト
"""
self.model_name = model_name
self.mcp_urls = mcp_urls
@mlflow.trace(span_type=SpanType.AGENT)
def predict(self, request: ResponsesAgentRequest) -> ResponsesAgentResponse:
"""リクエストに基づいて予測を行い、最終的なレスポンスを返す
Args:
request (ResponsesAgentRequest): 予測リクエスト
Returns:
ResponsesAgentResponse: 予測結果のレスポンス
"""
events = [
event
for event in self.predict_stream(request)
if event.type == "response.output_item.done"
]
outputs = [event.item for event in events]
# usage総量を計算
usages = [event.usage for event in events]
total_usage = {
"input_tokens_details": {"cached_tokens": 0},
"output_tokens_details": {"reasoning_tokens": 0},
**reduce(
lambda x, y: {k: x.get(k, 0) + y.get(k, 0) for k in set(x) | set(y)},
usages,
),
}
return ResponsesAgentResponse(output=outputs, usage=total_usage)
@mlflow.trace(span_type=SpanType.AGENT)
def predict_stream(
self, request: ResponsesAgentRequest
) -> Generator[ResponsesAgentStreamEvent, None, None]:
"""ストリームモードで予測を行い、逐次的にレスポンスイベントを生成する
Args:
request (ResponsesAgentRequest): 予測リクエスト
Yields:
ResponsesAgentStreamEvent: ストリームイベント
"""
# MCP Serverに接続するコネクション情報を作成
ws = WorkspaceClient()
host = ws.config.host
connections = [
DatabricksConnection(server_url=url.format(host=host), workspace_client=ws)
for url in self.mcp_urls
]
# Databricks MCP Serverからツール情報を取得
tools = list_databricks_mcp_tools(connections)
# リクエストをLangGraphで利用できるように変換
messages, params = self._convert_request_to_lc_request(request)
base_params = {
k: v for k, v in params.items() if k in ["temperature", "max_tokens"]
}
llm = ChatDatabricks(model=self.model_name, **base_params).bind(**params)
# ReActエージェントを使って実行
react_agent = create_react_agent(llm, tools=tools)
for chunk in react_agent.stream({"messages": messages}, stream_mode="updates"):
for value in chunk.values():
messages = value.get("messages", [])
responses = self._convert_lc_messages_to_response(messages)
for response in responses:
yield response
@mlflow.trace(span_type=SpanType.PARSER)
def _convert_request_to_lc_request(
self, request: ResponsesAgentRequest
) -> (list[BaseMessage], dict[str, Any]):
"""MLflowのリクエストをLangChainのメッセージ・パラメータ形式に変換する
Args:
request (ResponsesAgentRequest): 変換するリクエスト
Returns:
tuple: メッセージリスト、パラメータ辞書
"""
lc_request = request.model_dump_compat(exclude_none=True)
custom_inputs = lc_request.pop("custom_inputs", {})
# custom_inputsは通常のパラメータとして展開
lc_request.update(custom_inputs)
messages = lc_request.pop("input")
# LangChainで有効なパラメータのみに限定
valid_params = [
"temperature",
"max_output_tokens",
"top_p",
"top_k",
]
params = {k: v for k, v in lc_request.items() if k in valid_params}
if "max_output_tokens" in params:
params["max_tokens"] = params.pop("max_output_tokens")
return messages, params
@mlflow.trace(span_type=SpanType.PARSER)
def _convert_lc_messages_to_response(
self, messages: list[BaseMessage]
) -> list[ResponsesAgentStreamEvent]:
"""LangChainメッセージをMLflowのストリームレスポンス形式に変換する
Args:
messages (list[BaseMessage]): 変換するメッセージリスト
Returns:
list[ResponsesAgentStreamEvent]: レスポンス出力のリスト
Raises:
ValueError: 未知のメッセージ型の場合
"""
def _create_response_agent_stream_event(
item, usage, metadata
) -> ResponsesAgentStreamEvent:
"""ストリームイベント(ResponsesAgentStreamEvent)を生成する内部関数"""
return ResponsesAgentStreamEvent(
type="response.output_item.done",
item=item,
usage=_convert_lc_usage_to_openai_usage(usage),
metadata=metadata,
)
def _convert_lc_usage_to_openai_usage(usage: dict[str, int]) -> dict[str, int]:
"""LangChainのusage情報をOpenAI Response API互換形式に変換"""
return {
"input_tokens": usage.get("prompt_tokens", 0),
"output_tokens": usage.get("completion_tokens", 0),
"total_tokens": usage.get("total_tokens", 0),
}
outputs = []
for message in messages:
if isinstance(message, ToolMessage):
item = self.create_function_call_output_item(
output=message.content,
call_id=message.tool_call_id,
)
metadata = message.response_metadata
usage = metadata.pop("usage", {})
outputs.append(
_create_response_agent_stream_event(item, usage, metadata)
)
elif (
isinstance(message, (AIMessage, AIMessageChunk)) and message.tool_calls
):
metadata = message.response_metadata
usage = metadata.pop("usage", {})
for tool_call in message.tool_calls:
item = self.create_function_call_item(
id=message.id,
call_id=tool_call.get("id"),
name=tool_call.get("name"),
arguments=str(tool_call.get("args")),
)
outputs.append(
_create_response_agent_stream_event(item, usage, metadata)
)
# 1件目のみusageを設定
usage = {}
elif isinstance(message, (AIMessage, AIMessageChunk)):
item = self.create_text_output_item(
text=message.content,
id=message.id,
)
metadata = message.response_metadata
usage = metadata.pop("usage", {})
outputs.append(
_create_response_agent_stream_event(item, usage, metadata)
)
else:
raise ValueError(f"Unknown message: {message}")
return outputs
# LLMとして利用するエンドポイント名
# LLM_ENDPOINT_NAME = "databricks-meta-llama-3-3-70b-instruct"
LLM_ENDPOINT_NAME = "databricks-claude-sonnet-4"
# MCPで利用したいGENIEのID。任意で変更ください。
MCP_GENIE_IDS = ["xxxxxxxxxxxxxxxxxxxxxxxxxx"] # MCPで利用したいGENIEのID
# MCPで利用したいUC FUNCTION(カタログ+スキーマ)名
MCP_UC_FUNCTIONS = ["system.ai"]
MCP_SERVER_URLS_TEMPLATE = [
f"{{host}}/api/2.0/mcp/genie/{genie_id}" for genie_id in MCP_GENIE_IDS
] + [
f"{{host}}/api/2.0/mcp/functions/{function.replace('.', '/') }"
for function in MCP_UC_FUNCTIONS
]
# mlflowにエージェントを設定
agent = DatabricksMCPAgent(model_name=LLM_ENDPOINT_NAME, mcp_urls=MCP_SERVER_URLS_TEMPLATE)
mlflow.models.set_model(agent)
if __name__ == "__main__":
# 簡単なテスト
input = {
"input": [{"role": "user", "content": r"what is 4*3?"}],
"context": {"conversation_id": "123", "user_id": "456"},
"top_p": 0.9,
}
for event in agent.predict_stream(ResponsesAgentRequest(**input)):
print(event)
前回取り上げた、mlflow3で新しく登場したクラスResponsesAgent
を使って実装しています。
ポイントはlist_databricks_mcp_tools
関数で、ここでMCPサーバと連携してツール情報を取得し、LangCahinのツール(BaseTool)一覧を取得しています。
あとはLangGraphのcreate_react_agent
で上記ツールを利用するようにReActエージェントを作成・利用しています。
処理場の注意点(というか改善ポイント)として、以下のようなものがあります。
- リクエスト受信時に毎回MCPサーバと通信してツールの一覧を取得しているため、効率が悪い
- ケースによりますが、エージェント作成時に取得・保持しておく方がよいかも
- Tool実行時にも毎回セッションを作成しているため、効率が悪い
試用する程度では問題になりませんが、念のため注意ください。
Step3. エージェントのロギング
エージェントをMLflowにロギングします。
注意点は、特にGenieを利用する際のresources
指定です。
Genieで利用するSQLウェアハウスや関連テーブルを登録しておかないと利用権限が得られません。
権限に関しては代理認証を使うという手もあります。(ただし、まだベータ版)
import os
import sys
import nest_asyncio
sys.path.append(os.path.join(os.getcwd(), "agents"))
nest_asyncio.apply()
import mlflow
from mlflow.models.resources import (
DatabricksServingEndpoint,
DatabricksFunction,
DatabricksGenieSpace,
DatabricksSQLWarehouse,
DatabricksTable,
)
# resoucesとして登録するLLMエンドポイント名やGenie IDの一覧などをエージェントのモジュールから取得
from databricks_mcp_agent import LLM_ENDPOINT_NAME, MCP_GENIE_IDS, MCP_UC_FUNCTIONS
# Genieで利用するテーブルの一覧
tables = [
"xxx_catalog.xxxx_schema.xxx_table1",
"xxx_catalog.xxxx_schema.xxx_table2",
]
resources = (
[
DatabricksServingEndpoint(endpoint_name=LLM_ENDPOINT_NAME),
DatabricksFunction(function_name="system.ai.python_exec"),
DatabricksSQLWarehouse(warehouse_id="xxxxxxxxxxx"), # Genieで利用するSQL WHのID
]
+ [DatabricksGenieSpace(genie_space_id=genie_id) for genie_id in MCP_GENIE_IDS]
+ [DatabricksTable(table_name=t) for t in tables]
)
with mlflow.start_run():
logged_agent_info = mlflow.pyfunc.log_model(
python_model="agents/databricks_mcp_agent.py",
name="databricks_mcp_agent",
pip_requirements=[
"mlflow>=3.1.0",
"mcp==1.9.4",
"databricks-mcp==0.1.0",
"databricks-sdk==0.55.0",
"langgraph==0.4.8",
"databricks-langchain==0.5.1",
],
resources=resources,
)
ロギングしたエージェントを利用してみます。
system.ai.python_exec
が使われるような内容にしてみました。
import mlflow
from pprint import pprint
model_uri = logged_agent_info.model_uri
agent = mlflow.pyfunc.load_model(model_uri)
input = {
"input": [
{"role": "user", "content": "1から100までの数の中で、素数はいくつある?"},
],
}
for event in agent.predict_stream(input):
pprint(event.get("item"))
print("-----------------")
{'arguments': '{\'code\': \'def is_prime(n):\\n """素数かどうかを判定する関数"""\\n '
'if n < 2:\\n return False\\n if n == 2:\\n '
'return True\\n if n % 2 == 0:\\n return False\\n '
'\\n # 3から√nまでの奇数で割り切れるかチェック\\n for i in range(3, '
'int(n**0.5) + 1, 2):\\n if n % i == 0:\\n '
'return False\\n return True\\n\\n# 1から100までの素数を見つける\\nprimes '
'= []\\nfor num in range(1, 101):\\n if '
'is_prime(num):\\n '
'primes.append(num)\\n\\nprint(f"1から100までの素数: '
'{primes}")\\nprint(f"\\\\n素数の個数: {len(primes)}個")\'}',
'call_id': 'toolu_bdrk_01LZtJzFYeqHVFE83gJ1UjyE',
'id': 'run--45c3b918-9bd0-4c0e-92e8-6cdbfef14bfe-0',
'name': 'system__ai__python_exec',
'type': 'function_call'}
-----------------
{'call_id': 'toolu_bdrk_01LZtJzFYeqHVFE83gJ1UjyE',
'output': '{"is_truncated":false,"columns":["output"],"rows":[["1から100までの素数: '
'[2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53, 59, '
'61, 67, 71, 73, 79, 83, 89, 97]\\n\\n素数の個数: 25個\\n"]]}',
'type': 'function_call_output'}
-----------------
{'content': [{'text': '1から100までの数の中で、**素数は25個**あります。\n'
'\n'
'具体的には以下の25個の数が素数です:\n'
'2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, '
'53, 59, 61, 67, 71, 73, 79, 83, 89, 97\n'
'\n'
'素数とは、1と自分自身以外に約数を持たない1より大きい自然数のことです。最小の素数は2(唯一の偶数の素数)で、最大の素数は97となっています。',
'type': 'output_text'}],
'id': 'run--ee300d9b-4398-4c49-adb5-9f2e73721228-0',
'role': 'assistant',
'type': 'message'}
-----------------
Tracing結果はこちら。
想定通り、素数を求めるpythonコードを作成し、system.ai.python_exec
ツールで実行して結果を得ています。
Step4. エージェントのデプロイ
それでは、Databricks Mosaic AI Model Servingへのデプロイをします。
まずは、Unity Catalogに登録。
import mlflow
catalog = "training"
schema = "llm"
model_name = f"{catalog}.{schema}.databricks_mcp_agent"
mlflow.set_registry_uri("databricks-uc")
registered_model = mlflow.register_model(model_uri=logged_agent_info.model_uri, name=model_name)
その後、Mosaic AI Agent Frameworkを使ってModel Servingへデプロイ。
from databricks import agents
deployment = agents.deploy(
registered_model.name,
registered_model.version,
endpoint_name='databricks_mcp_agent_endpoint',
scale_to_zero=True,
)
deployment.query_endpoint
問題なければ、数分後にデプロイが完了します。
Step5. エージェントを使ってみる
Playground上で試してみます。
python_exec
を無理やり複数回使うような質問にしてみます。
複数回ツールを呼びだして処理が進みます。
最終結果も出力されました。
トレース結果を見ても複数回LLMとツール呼び出しが行われていることがわかります。
ちゃんとMCPサーバを利用してツールを取得し、(微妙な例ですが)マルチホップでエージェントが自律的に処理を判断して回答を得るエージェントが作れました。
Genieとの連携も試していますが、Genieから得られた結果をさらに別のGenieスペースへ連携したいり、python処理にかけたりとかなり幅広く処理してくれそうです。
何気にGenieスペースをツールとして簡単に利用するのは大変だったので、公式がMCPで提供してくれるのは本当に便利ですね。
まとめ
DatabricksのマネージドMCPサーバと連携するエージェントを試作しました。
公式ドキュメントの実装方法と異なるやり方で行っていますが、個人的にはLangGraph使いたいのでこちらのやり方の方が性にあっていたりします。(公式のものも高い水準のサンプルで感心しています)
MCP熱がちょっと落ち着いてきたところだったんですが、また再燃しそうです。楽しい