1
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

Databricks Appsでちょっとリッチなエージェントチャットボットを作る

Posted at

タイトルを正しく読めた方は心が綺麗。

はじめに

Databricks on AWSの東京リージョンにてDatabricks Appsが利用できるようになりました!!
イチユーザとして首を長くして待っておりました。

リリース日には以下のように紹介記事が最速で公開されています。
Appsとは何か、については是非以下の記事をご覧ください。

本記事は、せっかく東京リージョンで使えるようになりましたので、サンプル的な何かを作ってみようという趣旨の内容になります。

というわけで、今回はLLMを利用したチャットボットを作ります。
チャットボットと言えばすでにDatabricks Appsのテンプレートが用意されています。
ただ、かなり簡易的なものなので多少リッチなものを目指して作ってみます。

出来上がりのイメージはこちら。

image.png

正直Databricks Playgroundの焼き直し感はありますが、参考になればありがたいです。

チャットボットの機能仕様

以下の機能を提供します。

  • いくつかのツール呼び出し機能を備えたエージェントとのチャットができる
    - ツールはMCPで呼び出し
    - Pythonコードの実行と指定したサイトをマークダウン形式で取得できるツールを備える
  • チャット上ではツール呼び出しの内容などが確認できる
  • チャット履歴が保管できる(保管したチャット履歴のアップロードは未対応)

突貫開発したので、エラーハンドリングやロギングなどは不十分です。
参考にする際はご注意ください。

それでは、バックエンド→フロントエンドの順番で構築していきます。

Step1: バックエンドを作る

チャットボットのキモであるAIエージェント処理を作成します。
合せてエージェントが利用するツールを定義し、Unity Catalog Functionsへ登録します。

準備

Databricksでノートブックを作成し、必要なパッケージをインストールします。

%pip install -U "mlflow[databriicks]>=3.1.1" markitdown[pdf] pandas
%pip install unitycatalog-ai[databricks] unitycatalog-langchain[databricks] databricks-langchain
%pip install -U databricks-mcp "mcp==1.11.0" databricks-agents langgraph databricks-sdk
%pip install nest-asyncio

%restart_python

カタログ・スキーマの準備

Unity Catalog Functionsやエージェントを保管するためのカタログ・スキーマを準備します。

catalog = "training"
schema = "ai_agent"

# スキーマが無い場合、作成
create_schema_sql = """CREATE SCHEMA IF NOT EXISTS {catalog}.{schema}"""
spark.sql(create_schema_sql.format(catalog = catalog, schema = schema))

ツールの定義

Markitdownを使うツールを定義します。
これを利用することで、指定したサイトやオンライン上のPDFなどをマークダウン形式のデータとして取得し、エージェント内で利用できるようになります。

まずはPythonの関数として処理を記述します。

def convert_to_markdown(url: str) -> dict[str, str]:
    """
    Converts the content of a given URL to markdown format.

    Args:
        url (str): The URL of the content to be converted.

    Returns:
        dict[str, str]: A dictionary containing the title and text content of the converted markdown.
    """
    from markitdown import MarkItDown

    md = MarkItDown(enable_plugins=False)
    result = md.convert(url)

    return {"title": result.title, "text": result.text_content}

Pythonの関数をUnity Catalog Functionsとして登録します。

from unitycatalog.ai.core.databricks import DatabricksFunctionClient

client = DatabricksFunctionClient()

client.create_python_function(
    func=convert_to_markdown,
    catalog=catalog,
    schema=schema,
    replace=True,
    dependencies=["markitdown[pdf]", "pandas>=2.3.0"],
    environment_version="None",
)

実行後はカタログエクスプローラ上で確認できます。

image.png

エージェントを定義する

先ほどのconvert_to_markdownと、DatabricksビルトインのPython実行関数python_execをツールとして実行するエージェントを定義します。ツールはMCPで呼び出すようにしています。

また、predict_streamインターフェースを実装しており、ツール呼び出しなどは途中経過の状態で結果を得られるようにしています。

詳しくは以下の記事で解説しており、今回のコードはその時とほぼ同じものになります。

MCP機能は2025年7月時点でベータ版となります。

※ コードは長いので、折り畳み。

databricks_mcp_agent
%%writefile databricks_mcp_agent.py

from contextlib import asynccontextmanager
import asyncio
from typing import Any, Callable, List, Generator, TypedDict, cast
from pydantic import BaseModel
from functools import reduce

import mlflow
from mlflow.pyfunc import ResponsesAgent
from mlflow.types.responses import (
    ResponsesAgentRequest,
    ResponsesAgentResponse,
    ResponsesAgentStreamEvent,
)
from mlflow.entities import SpanType

from databricks_mcp import DatabricksOAuthClientProvider
from databricks.sdk import WorkspaceClient
from mcp.client.session import ClientSession
from mcp.client.streamable_http import streamablehttp_client
from langchain_core.messages import (
    BaseMessage,
    AIMessage,
    ToolMessage,
    AIMessageChunk,
)
from langchain_core.tools import BaseTool, ToolException, StructuredTool
from langgraph.prebuilt import create_react_agent
from databricks_langchain import ChatDatabricks

from mcp.types import (
    CallToolResult,
    EmbeddedResource,
    ImageContent,
    TextContent,
    Tool as MCPTool,
)

NonTextContent = ImageContent | EmbeddedResource

MAX_ITERATIONS = 1000

# Tracingの有効化
mlflow.langchain.autolog()


class DatabricksConnection(TypedDict):
    """Databricks MCPサーバーへの接続情報を保持する型定義

    Attributes:
        server_url (str): 接続先のDatabricks MCPエンドポイントURL
        workspace_client (WorkspaceClient): Databricksワークスペースクライアント(認証用)
    """

    server_url: str
    workspace_client: WorkspaceClient


@asynccontextmanager
async def _databricks_mcp_session(connection: DatabricksConnection):
    """Databricks MCPサーバーへの非同期セッションを作成するコンテキストマネージャ

    Args:
        connection (DatabricksConnection): MCPサーバーへの接続情報

    Yields:
        ClientSession: 初期化済みのMCPクライアントセッション
    """
    async with streamablehttp_client(
        url=connection.get("server_url"),
        auth=DatabricksOAuthClientProvider(connection.get("workspace_client")),
        timeout=60,
    ) as (reader, writer, _):
        async with ClientSession(reader, writer) as session:
            await session.initialize()
            yield session


async def _list_all_tools(session: ClientSession) -> list[MCPTool]:
    """MCPサーバーから全てのツール情報をページネーションで取得する

    Args:
        session (ClientSession): MCPクライアントセッション

    Returns:
        list[MCPTool]: 取得した全ツールのリスト

    Raises:
        RuntimeError: ページ数が上限を超えた場合
    """
    current_cursor: str | None = None
    all_tools: list[MCPTool] = []

    iterations = 0

    while True:
        iterations += 1
        if iterations > MAX_ITERATIONS:
            raise RuntimeError(
                f"Reached max of {MAX_ITERATIONS} iterations while listing tools."
            )

        list_tools_page_result = await session.list_tools(cursor=current_cursor)

        if list_tools_page_result.tools:
            all_tools.extend(list_tools_page_result.tools)

        if list_tools_page_result.nextCursor is None:
            break

        current_cursor = list_tools_page_result.nextCursor
    return all_tools


def _convert_call_tool_result(
    call_tool_result: CallToolResult,
) -> tuple[str | list[str], list[NonTextContent] | None]:
    """MCPツール呼び出し結果をテキスト・非テキストに分割して返す

    Args:
        call_tool_result (CallToolResult): MCPツールの呼び出し結果

    Returns:
        tuple[str | list[str], list[NonTextContent] | None]: テキスト内容と非テキスト内容

    Raises:
        ToolException: エラーが発生した場合
    """
    text_contents: list[TextContent] = []
    non_text_contents = []
    for content in call_tool_result.content:
        if isinstance(content, TextContent):
            text_contents.append(content)
        else:
            non_text_contents.append(content)

    tool_content: str | list[str] = [content.text for content in text_contents]
    if not text_contents:
        tool_content = ""
    elif len(text_contents) == 1:
        tool_content = tool_content[0]

    if call_tool_result.isError:
        raise ToolException(tool_content)

    return tool_content, non_text_contents or None


def _convert_mcp_tool_to_langchain_tool(
    connection: DatabricksConnection,
    tool: MCPTool,
) -> BaseTool:
    """MCPツール情報をLangChainのStructuredToolに変換する

    Args:
        connection (DatabricksConnection): MCPサーバー接続情報
        tool (MCPTool): MCPツール情報

    Returns:
        BaseTool: LangChain互換のツール
    """
    if connection is None:
        raise ValueError("a connection config must be provided")

    async def call_tool_async(
        **arguments: dict[str, Any],
    ) -> tuple[str | list[str], list[NonTextContent] | None]:
        async with _databricks_mcp_session(connection) as tool_session:
            call_tool_result = await cast(ClientSession, tool_session).call_tool(
                tool.name, arguments
            )
        return _convert_call_tool_result(call_tool_result)

    def call_tool_sync(
        **arguments: dict[str, Any]
    ) -> tuple[str | list[str], list[NonTextContent] | None]:
        return asyncio.run(call_tool_async(**arguments))

    return StructuredTool(
        name=tool.name,
        description=tool.description or "",
        args_schema=tool.inputSchema,
        coroutine=call_tool_async,
        func=call_tool_sync,
        response_format="content_and_artifact",
        metadata=tool.annotations.model_dump() if tool.annotations else None,
    )

def list_databricks_mcp_tools(
    connections: list[DatabricksConnection],
) -> list[BaseTool]:
    """複数のMCPサーバーから全ツールを取得し、LangChainツールリストとして返す

    Args:
        connections (list[DatabricksConnection]): MCPサーバー接続情報のリスト

    Returns:
        list[BaseTool]: LangChain互換の全ツールリスト
    """

    async def _load_databricks_mcp_tools(
        connection: DatabricksConnection,
    ) -> list[BaseTool]:
        if connection is None:
            raise ValueError("connection config must be provided")

        async with _databricks_mcp_session(connection) as tool_session:
            tools = await _list_all_tools(tool_session)

        converted_tools = [
            _convert_mcp_tool_to_langchain_tool(connection, tool) for tool in tools
        ]
        return converted_tools

    async def gather():
        tasks = [_load_databricks_mcp_tools(con) for con in connections]
        return await asyncio.gather(*tasks)

    # 結果をflat化してtoolの単一リストとして返す
    return sum(asyncio.run(gather()), [])


class DatabricksMCPAgent(ResponsesAgent):
    """Databricks MCPサーバーとLangChainエージェントを組み合わせたエージェントクラス"""

    def __init__(self, model_name, mcp_urls: list[str]):
        """DatabricksMCPAgentの初期化

        Args:
            model_name: 使用するLLMモデル(エンドポイント)名
            mcp_urls (list[str]): MCPサーバーのURLリスト
        """
        self.model_name = model_name
        self.mcp_urls = mcp_urls

    @mlflow.trace(span_type=SpanType.AGENT)
    def predict(self, request: ResponsesAgentRequest) -> ResponsesAgentResponse:
        """リクエストに基づいて予測を行い、最終的なレスポンスを返す

        Args:
            request (ResponsesAgentRequest): 予測リクエスト

        Returns:
            ResponsesAgentResponse: 予測結果のレスポンス
        """
        events = [
            event
            for event in self.predict_stream(request)
            if event.type == "response.output_item.done"
        ]
        outputs = [event.item for event in events]
        # usage総量を計算
        usages = [event.usage for event in events]
        total_usage = {
            "input_tokens_details": {"cached_tokens": 0},
            "output_tokens_details": {"reasoning_tokens": 0},
            **reduce(
                lambda x, y: {k: x.get(k, 0) + y.get(k, 0) for k in set(x) | set(y)},
                usages,
            ),
        }

        return ResponsesAgentResponse(output=outputs, usage=total_usage)

    @mlflow.trace(span_type=SpanType.AGENT)
    def predict_stream(
        self, request: ResponsesAgentRequest
    ) -> Generator[ResponsesAgentStreamEvent, None, None]:
        """ストリームモードで予測を行い、逐次的にレスポンスイベントを生成する

        Args:
            request (ResponsesAgentRequest): 予測リクエスト

        Yields:
            ResponsesAgentStreamEvent: ストリームイベント
        """

        # MCP Serverに接続するコネクション情報を作成
        ws = WorkspaceClient()
        host = ws.config.host
        connections = [
            DatabricksConnection(server_url=url.format(host=host), workspace_client=ws)
            for url in self.mcp_urls
        ]

        # Databricks MCP Serverからツール情報を取得
        tools = list_databricks_mcp_tools(connections)

        # リクエストをLangGraphで利用できるように変換
        messages, params = self._convert_request_to_lc_request(request)
        base_params = {
            k: v for k, v in params.items() if k in ["temperature", "max_tokens"]
        }
        llm = ChatDatabricks(model=self.model_name, **base_params).bind(**params)

        # ReActエージェントを使って実行
        react_agent = create_react_agent(llm, tools=tools)
        for chunk in react_agent.stream({"messages": messages}, stream_mode="updates"):
            for value in chunk.values():
                messages = value.get("messages", [])
                responses = self._convert_lc_messages_to_response(messages)
                for response in responses:
                    yield response

    @mlflow.trace(span_type=SpanType.PARSER)
    def _convert_request_to_lc_request(
        self, request: ResponsesAgentRequest
    ) -> (list[BaseMessage], dict[str, Any]):
        """MLflowのリクエストをLangChainのメッセージ・パラメータ形式に変換する

        Args:
            request (ResponsesAgentRequest): 変換するリクエスト

        Returns:
            tuple: メッセージリスト、パラメータ辞書
        """

        lc_request = request.model_dump_compat(exclude_none=True)
        custom_inputs = lc_request.pop("custom_inputs", {})

        # custom_inputsは通常のパラメータとして展開
        lc_request.update(custom_inputs)
        messages = lc_request.pop("input")

        # LangChainで有効なパラメータのみに限定
        valid_params = [
            "temperature",
            "max_output_tokens",
            "top_p",
            "top_k",
        ]
        params = {k: v for k, v in lc_request.items() if k in valid_params}
        if "max_output_tokens" in params:
            params["max_tokens"] = params.pop("max_output_tokens")

        return messages, params

    @mlflow.trace(span_type=SpanType.PARSER)
    def _convert_lc_messages_to_response(
        self, messages: list[BaseMessage]
    ) -> list[ResponsesAgentStreamEvent]:
        """LangChainメッセージをMLflowのストリームレスポンス形式に変換する

        Args:
            messages (list[BaseMessage]): 変換するメッセージリスト

        Returns:
            list[ResponsesAgentStreamEvent]: レスポンス出力のリスト

        Raises:
            ValueError: 未知のメッセージ型の場合
        """

        def _create_response_agent_stream_event(
            item, usage, metadata
        ) -> ResponsesAgentStreamEvent:
            """ストリームイベント(ResponsesAgentStreamEvent)を生成する内部関数"""
            return ResponsesAgentStreamEvent(
                type="response.output_item.done",
                item=item,
                usage=_convert_lc_usage_to_openai_usage(usage),
                metadata=metadata,
            )

        def _convert_lc_usage_to_openai_usage(usage: dict[str, int]) -> dict[str, int]:
            """LangChainのusage情報をOpenAI Response API互換形式に変換"""
            return {
                "input_tokens": usage.get("prompt_tokens", 0),
                "output_tokens": usage.get("completion_tokens", 0),
                "total_tokens": usage.get("total_tokens", 0),
            }

        outputs = []
        for message in messages:
            if isinstance(message, ToolMessage):
                item = self.create_function_call_output_item(
                    output=message.content,
                    call_id=message.tool_call_id,
                )
                metadata = message.response_metadata
                usage = metadata.pop("usage", {})
                outputs.append(
                    _create_response_agent_stream_event(item, usage, metadata)
                )
            elif (
                isinstance(message, (AIMessage, AIMessageChunk)) and message.tool_calls
            ):
                metadata = message.response_metadata
                usage = metadata.pop("usage", {})
                for tool_call in message.tool_calls:
                    item = self.create_function_call_item(
                        id=message.id,
                        call_id=tool_call.get("id"),
                        name=tool_call.get("name"),
                        arguments=str(tool_call.get("args")),
                    )
                    outputs.append(
                        _create_response_agent_stream_event(item, usage, metadata)
                    )
                    # 1件目のみusageを設定
                    usage = {}
            elif isinstance(message, (AIMessage, AIMessageChunk)):
                item = self.create_text_output_item(
                    text=message.content,
                    id=message.id,
                )
                metadata = message.response_metadata
                usage = metadata.pop("usage", {})
                outputs.append(
                    _create_response_agent_stream_event(item, usage, metadata)
                )
            else:
                raise ValueError(f"Unknown message: {message}")
        return outputs


# LLMとして利用するエンドポイント名
# LLM_ENDPOINT_NAME = "databricks-meta-llama-3-3-70b-instruct"
LLM_ENDPOINT_NAME = "databricks-claude-sonnet-4"

# MCPで利用したいUC FUNCTION(カタログ+スキーマ)名
TOOL_CATALOG = "training"
TOOL_SCHEMA = "ai_agent"
MCP_UC_FUNCTIONS = ["system.ai", f"{TOOL_CATALOG}.{TOOL_SCHEMA}"]
MCP_SERVER_URLS_TEMPLATE = [
    f"{{host}}/api/2.0/mcp/functions/{function.replace('.', '/') }"
    for function in MCP_UC_FUNCTIONS
]

# mlflowにエージェントを設定
agent = DatabricksMCPAgent(model_name=LLM_ENDPOINT_NAME, mcp_urls=MCP_SERVER_URLS_TEMPLATE)
mlflow.models.set_model(agent)

if __name__ == "__main__":
    # 簡単なテスト
    # input = {
    #     "input": [{"role": "user", "content": r"what is 4*3?"}],
    #     "context": {"conversation_id": "123", "user_id": "456"},
    #     "top_p": 0.9,
    # }
    # for event in agent.predict_stream(ResponsesAgentRequest(**input)):
    #     print(event)

    # 簡単なテスト2
    input = {
        "input": [{"role": "user", "content": r"次のサイトをマークダウン形式に変換して100文字程度で日本語で要約して:https://qiita.com/isanakamishiro2/items/5a303055e3760d9000e0"}],
        "context": {"conversation_id": "123", "user_id": "456"},
        "top_p": 0.9,
    }
    for event in agent.predict_stream(ResponsesAgentRequest(**input)):
        print(event)


モデルをロギング/デプロイする

定義したエージェントをUnity Catalogにロギングし、Databricks Model Serving上にデプロイします。

まずはロギング。

import nest_asyncio
import mlflow
from mlflow.models.resources import (
    DatabricksServingEndpoint,
    DatabricksFunction,
)
from databricks_mcp_agent import TOOL_CATALOG, TOOL_SCHEMA, LLM_ENDPOINT_NAME
from databricks_langchain import UCFunctionToolkit

# ノートブック上でasyncを使えるようにする
nest_asyncio.apply()

# 利用するUnity Catalog Functionsの一覧からリソースを作成
func_name = f"{TOOL_CATALOG}.{TOOL_SCHEMA}.*"
toolkit = UCFunctionToolkit(function_names=[func_name])
function_resources = [DatabricksFunction(function_name=func_name) for func_name in toolkit.tools_dict.keys()]

# エージェントが利用するリソース定義
resources = [
    DatabricksServingEndpoint(endpoint_name=LLM_ENDPOINT_NAME),
    DatabricksFunction(function_name="system.ai.python_exec"),
] + function_resources

with mlflow.start_run():
    logged_agent_info = mlflow.pyfunc.log_model(
        python_model="databricks_mcp_agent.py",
        name="simple_mcp_agent",
        pip_requirements=[
            "mlflow>=3.1.1",
            "mcp==1.11.0",
            "databricks-mcp==0.2.0",
            "langgraph==0.5.2",
            "databricks-langchain==0.6.0",
            "unitycatalog-langchain==0.2.0",
            "unitycatalog-ai==0.3.1",
        ],
        resources=resources,
        registered_model_name=f"{TOOL_CATALOG}.{TOOL_SCHEMA}.simple_mcp_agent"
    )

デプロイ前には、以下のようなコードでロギングしたエージェントの動作を確認できます。

import mlflow
from pprint import pprint

model_uri = logged_agent_info.model_uri
agent = mlflow.pyfunc.load_model(model_uri)

input = {
    "input": [
        {"role": "user", "content": "what is the weather in Tokyo?"},
    ],
    "max_output_tokens": 1000,
    "top_p": 0.8,
    "temperature": 0.1,
}

for event in agent.predict_stream(input):
    pprint(event.get("item"))
    print("-----------------")

動作が問題なければ、次に、デプロイ。
今回はsimple_databricks_mcp_agentという名前のエンドポイントとしてデプロイしました。

from databricks import agents
from mlflow import MlflowClient

client = MlflowClient()
uc_model_name = f"{TOOL_CATALOG}.{TOOL_SCHEMA}.simple_mcp_agent"
versions = [
    mv.version for mv in client.search_model_versions(f"name='{uc_model_name}'")
]

deployment = agents.deploy(
    uc_model_name,
    versions[0],
    endpoint_name='simple_databricks_mcp_agent',
    scale_to_zero=True,
)
deployment.query_endpoint

ここまででバックエンドの準備が完了です。
次からDatabricks Appsを使うフロント部分を作っていきます。

Step2: フロントエンドを作る

Databricks Appsからアプリを追加

まず、サイドバーメニューの「新規」から「アプリ」を選択します。

image.png

次に、アプリをどのように作るかを選択します。
今回はStreamlitを使ったChatbotのテンプレートから作ることにします。

image.png

最初にリソースとしてサービングエンドポイントを指定できますので、バックエンド用に作成したエージェントsimple_databricks_mcp_agentを指定します。

image.png

認証設定はデフォルト設定とし、最後にアプリ名などを設定するとアプリのデプロイが開始されます。

image.png

数分待つと、アプリのデプロイが完了し、またデプロイメント元のソースコードを編集できるようになります。

image.png

アプリを変更する

このままだとUIがシンプルすぎる&MLflow3のResponsesAgentに対するクエリに対応してないため、コードを変更します。
それぞれかなり長いので、以下に折り畳みで掲載します。
ちなみに、ほとんどのUI部分はClaude4で作ってもらいました。UI構築におけるLLMの活用は本当に便利です。

各種ファイルの変更内容
app.py
import logging
import streamlit as st
from datetime import datetime
import json
import os
from typing import List, Dict, Optional, Any
from pydantic import BaseModel
import math
import random
from model_serving_utils import query_endpoint

# ロギングの設定
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# 環境変数のチェック
assert os.getenv("SERVING_ENDPOINT"), "SERVING_ENDPOINT must be set in app.yaml."

# ページ設定
st.set_page_config(
    page_title="AI Chat Assistant with Tools & Structured Responses",
    page_icon="🛠",
    layout="wide",
    initial_sidebar_state="expanded",
)

# カスタムCSS
st.markdown(
    """
<style>
    .main-header {
        font-size: 2.5rem;
        font-weight: bold;
        text-align: center;
        background: linear-gradient(90deg, #667eea 0%, #764ba2 100%);
        -webkit-background-clip: text;
        -webkit-text-fill-color: transparent;
        margin-bottom: 2rem;
    }
    .chat-message {
        padding: 1rem;
        border-radius: 10px;
        margin: 1rem 0;
        box-shadow: 0 2px 4px rgba(0,0,0,0.1);
    }

    .user-message {
        background: linear-gradient(135deg, #e0f2fe 0%, #bae6fd 100%);
        color: #0c4a6e;
        margin-left: 20%;
    }

    .assistant-message {
        background: linear-gradient(135deg, #f0f9ff 0%, #dbeafe 100%);
        color: #1e40af;
        margin-right: 20%;
    }

    .tool-call-message {
        background: linear-gradient(135deg, #fef3c7 0%, #fde68a 100%);
        color: #92400e;
        margin-right: 15%;
        border-left: 4px solid #f59e0b;
    }

    .tool-response-message {
        background: linear-gradient(135deg, #ecfdf5 0%, #d1fae5 100%);
        color: #065f46;
        margin-right: 10%;
        border-left: 4px solid #10b981;
    }
        
    .confidence-high { border-left-color: #28a745; }
    .confidence-medium { border-left-color: #ffc107; }
    .confidence-low { border-left-color: #dc3545; }
    
    .system-message {
        background: linear-gradient(135deg, #4facfe 0%, #00f2fe 100%);
        color: white;
        text-align: center;
        font-style: italic;
    }
    
    .tool-details {
        background: #f8f9fa;
        padding: 0.8rem;
        border-radius: 5px;
        margin-top: 0.5rem;
        font-family: 'Courier New', monospace;
        font-size: 0.85rem;
        border: 1px solid #dee2e6;
    }
    
    .tool-result {
        background: #e8f5e8;
        padding: 0.8rem;
        border-radius: 5px;
        margin-top: 0.5rem;
        border: 1px solid #c3e6c3;
    }
    
    .tool-error {
        background: #f8d7da;
        padding: 0.8rem;
        border-radius: 5px;
        margin-top: 0.5rem;
        border: 1px solid #f5c6cb;
    }
    
    .response-metadata {
        background: #f8f9fa;
        padding: 0.5rem;
        border-radius: 5px;
        margin-top: 0.5rem;
        font-size: 0.8rem;
        color: #666;
    }
    
    .keyword-tag {
        display: inline-block;
        background: #e9ecef;
        color: #495057;
        padding: 0.2rem 0.5rem;
        border-radius: 15px;
        margin: 0.1rem;
        font-size: 0.7rem;
    }
    
    .follow-up-question {
        background: #fff3cd;
        border: 1px solid #ffeaa7;
        border-radius: 5px;
        padding: 0.5rem;
        margin: 0.2rem 0;
        cursor: pointer;
        transition: background-color 0.2s;
    }
    
    .follow-up-question:hover {
        background: #fff1b3;
    }
    
    .tool-icon {
        font-size: 1.2rem;
        margin-right: 0.5rem;
    }
</style>
""",
    unsafe_allow_html=True,
)


# セッション状態の初期化
def initialize_session_state():
    if "responses" not in st.session_state:
        st.session_state.responses = []
    if "temperature" not in st.session_state:
        st.session_state.temperature = 0.7
    if "max_tokens" not in st.session_state:
        st.session_state.max_tokens = 1000
    if "system_prompt" not in st.session_state:
        st.session_state.system_prompt = "あなたは親切で知識豊富なAIアシスタントです。文脈に応じて必要なときだけツールを使用してしてください。"


initialize_session_state()


# チャット履歴の保存
def save_chat_history():
    """チャット履歴をJSONファイルに保存する"""
    if st.session_state.responses:
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        filename = f"chat_with_tools_{timestamp}.json"

        export_data = {
            "metadata": {
                "export_time": datetime.now().isoformat(),
                "total_responses": len(st.session_state.responses),
                "settings": {
                    "temperature": st.session_state.temperature,
                    "max_tokens": st.session_state.max_tokens,
                    "system_prompt": st.session_state.system_prompt,
                },
            },
            "responses": st.session_state.responses,
        }

        try:
            with open(filename, "w", encoding="utf-8") as f:
                json.dump(export_data, f, ensure_ascii=False, indent=2)

            st.success(f"✅ チャット履歴を {filename} に保存しました")

            # ファイルサイズも表示
            file_size = os.path.getsize(filename)
            st.info(f"📁 ファイルサイズ: {file_size:,} bytes")

        except Exception as e:
            st.error(f"❌ 保存エラー: {str(e)}")
    else:
        st.warning("⚠ 保存する履歴がありません")


# サイドバー設定
def setup_sidebar():
    with st.sidebar:
        st.markdown("## ⚙ 設定")

        # モデル設定
        with st.expander("🤖 モデル設定"):
            st.session_state.temperature = st.slider(
                "Temperature",
                min_value=0.0,
                max_value=1.0,
                value=st.session_state.temperature,
                step=0.1,
                help="応答の創造性を調整します",
            )

            st.session_state.max_tokens = st.slider(
                "Max Tokens",
                min_value=100,
                max_value=4000,
                value=st.session_state.max_tokens,
                step=100,
                help="応答の最大長を設定します",
            )

        # システムプロンプト
        with st.expander("📝 システムプロンプト"):
            st.session_state.system_prompt = st.text_area(
                "システムプロンプト",
                value=st.session_state.system_prompt,
                height=100,
                help="AIの役割や振る舞いを定義します",
            )

        # チャット履歴管理
        with st.expander("💾 チャット管理"):
            col1, col2 = st.columns(2)
            with col1:
                if st.button("🗑 履歴クリア", use_container_width=True):
                    st.session_state.responses = []
                    st.rerun()

            with col2:
                if st.button("💾 履歴保存", use_container_width=True):
                    save_chat_history()

            # 履歴のダウンロード
            if st.session_state.responses:
                export_data = {
                    "metadata": {
                        "export_time": datetime.now().isoformat(),
                        "total_responses": len(st.session_state.responses),
                    },
                    "responses": st.session_state.responses,
                }
                chat_json = json.dumps(export_data, ensure_ascii=False, indent=2)
                st.download_button(
                    "📥 JSON形式でダウンロード",
                    data=chat_json,
                    file_name=f"chat_with_tools_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json",
                    mime="application/json",
                    use_container_width=True,
                )

        # 統計情報
        display_chat_stats()


# チャット統計の表示
def display_chat_stats():
    if st.session_state.responses:
        st.markdown("## 📊 チャット統計")

        total_responses = len(st.session_state.responses)
        messages = [m["item"] for m in st.session_state.responses if "item" in m]
        user_messages = len(
            [m for m in messages if m["type"] == "message" and m["role"] == "user"]
        )
        assistant_messages = len(
            [m for m in messages if m["type"] == "message" and m["role"] == "assistant"]
        )
        function_calls = len([m for m in messages if m["type"] == "function_call"])

        col1, col2 = st.columns(2)
        with col1:
            st.metric("総メッセージ", total_responses)
            st.metric("ユーザー", user_messages)
        with col2:
            st.metric("AI", assistant_messages)
            st.metric("ツール使用", function_calls)


# ツールコールメッセージの表示
def display_tool_call_message(tool_call_data: Dict, metadata: Dict):
    function_name = tool_call_data["name"]
    arguments = tool_call_data["arguments"]
    timestamp = datetime.fromtimestamp(
        int(metadata.get("created", datetime.now().timestamp()))
    )

    icon = "🛠"
    st.markdown(
        f"""
    <div class="chat-message tool-call-message">
        <strong><span class="tool-icon">{icon}</span>ツール呼び出し</strong> <small>{timestamp}</small><br>
        <strong>関数:</strong> {function_name}<br>
        <div class="tool-details">
            <strong>引数:</strong><br>
            {arguments}
        </div>
    </div>
    """,
        unsafe_allow_html=True,
    )


# ツールレスポンスメッセージの表示
def display_tool_response_message(tool_response_data: Dict, metadata: Dict):
    output = tool_response_data["output"]
    timestamp = datetime.fromtimestamp(
        int(metadata.get("created", datetime.now().timestamp()))
    )

    icon = "🗨️"
    st.markdown(
        f"""
    <div class="chat-message tool-response-message">
        <strong><span class="tool-icon">{icon}</span>ツール結果</strong> <small>{timestamp}</small><br>
        <div class="tool-result">
            <strong>結果:</strong><br>
            {output}
        </div>
    </div>
    """,
        unsafe_allow_html=True,
    )


# メッセージの表示
def display_message(item: Dict, metadata: Dict):
    role = item["role"]
    content = item["content"][0]["text"]
    timestamp = datetime.fromtimestamp(int(metadata.get("created", 0)))

    if role == "user":
        st.markdown(
            f"""
        <div class="chat-message user-message">
            <strong>👤 あなた</strong> <small>{timestamp}</small><br>
            {content}
        </div>
        """,
            unsafe_allow_html=True,
        )
    elif role == "assistant":
        st.markdown(
            f"""
        <div class="chat-message assistant-message">
            <strong>🤖 AI Assistant</strong> <small>{timestamp}</small><br>
            {content}
        </div>
        """,
            unsafe_allow_html=True,
        )


# レスポンスの処理振り分け
def process_response(response: Dict):
    if "item" not in response:
        display_message({"role": "assistant", "content": [{"text": "Error."}]}, {})
    elif response["item"].get("type") == "message":
        display_message(response.get("item"), response.get("metadata", {}))
    elif response["item"].get("type") == "function_call":
        display_tool_call_message(response.get("item"), response.get("metadata", {}))
    elif response["item"].get("type") == "function_call_output":
        display_tool_response_message(
            response.get("item"), response.get("metadata", {})
        )


# レスポンスからメッセージを作成
def create_message(response: Dict):
    if response.get("item", {}).get("type", "") == "message":
        return {
            "role": response["item"].get("role"),
            "content": response["item"]["content"][0]["text"],
        }

    return None


# ユーザー入力の処理
def process_user_input(user_input: str, chat_container):
    """ユーザー入力を処理してAIレスポンスを生成"""

    with chat_container:
        # ユーザーメッセージを追加
        user_message = {
            "item": {
                "type": "message",
                "role": "user",
                "content": [{"text": user_input}],
            },
            "metadata": {"created": datetime.now().timestamp()},
        }
        st.session_state.responses.append(user_message)
        process_response(user_message)

        # AIレスポンスを生成

        with st.spinner("🤔 考え中..."):
            # メッセージ履歴を準備
            api_messages = []
            system_prompt = st.session_state.system_prompt
            if system_prompt:
                system_resp = {
                    "item": {
                        "type": "message",
                        "role": "system",
                        "content": [{"text": system_prompt}],
                    },
                    "metadata": {"created": datetime.now().timestamp()},
                }
                api_messages.append(create_message(system_resp))
            for resp in st.session_state.responses:
                msg = create_message(resp)
                if msg and msg["role"] in ["user", "asssitant", "system"]:
                    api_messages.append(msg)

            # エージェントへのストリーミングクエリ実行
            for response in query_endpoint(
                endpoint_name=os.getenv("SERVING_ENDPOINT"),
                messages=api_messages,
                max_tokens=st.session_state.max_tokens,
                temperature=st.session_state.temperature,
            ):
                st.session_state.responses.append(response)
                process_response(response)

    st.rerun()


# メイン画面
def main():
    # ヘッダー
    st.markdown(
        '<h1 class="main-header">🛠 AI エージェントチャット</h1>', unsafe_allow_html=True
    )

    # チャット履歴の表示
    chat_container = st.container()
    with chat_container:
        for i, response in enumerate(st.session_state.responses):
            process_response(response)

    # チャット入力
    st.markdown("---")

    # プリセット質問(レスポンスタイプ別)
    st.markdown("### 💡 クイックスタート")

    st.markdown("#### 💬 一般的な質問例")
    col1, col2, col3, col4 = st.columns(4)
    with col1:
        if st.button("📝 文章作成", use_container_width=True):
            process_user_input("ビジネスメールの書き方を教えてください", chat_container)
    with col2:
        if st.button("💡 アイデア", use_container_width=True):
            process_user_input(
                "新しいビジネスアイデアを5つ提案してください", chat_container
            )
    with col3:
        if st.button("🔍 説明", use_container_width=True):
            process_user_input(
                "機械学習について分かりやすく説明してください", chat_container
            )
    with col4:
        if st.button("🎯 アドバイス", use_container_width=True):
            process_user_input("時間管理のコツを教えてください", chat_container)

    # メッセージ入力
    user_input = st.chat_input("メッセージを入力してください...")

    if user_input:
        process_user_input(user_input, chat_container)


# サイドバーとメイン画面の設定
setup_sidebar()
main()

# フッター
st.markdown("---")
st.markdown(
    f"""
    <div style='text-align: center; color: #666; padding: 1rem;'>
        🛠 AIエージェントチャット<br>
        <small>Powered by Databricks Apps</small><br>
    </div>
    """,
    unsafe_allow_html=True,
)
model_serving_utils.py
from mlflow.deployments import get_deploy_client
from typing import Generator


def query_endpoint(
    endpoint_name: str,
    messages: list[dict[str, str]],
    max_tokens,
    temperature,
) -> Generator:
    """
    チャット補完またはエージェントサービングエンドポイントをクエリします。
    複数のメッセージを返すエージェントサービングエンドポイントをクエリする場合、
    このメソッドは最後のメッセージを返します。
    """
    return get_deploy_client("databricks").predict_stream(
        endpoint=endpoint_name,
        inputs={
            "input": messages,
            "max_output_tokens": max_tokens,
            "temperature": temperature,
        },
    )
requirements.txt
mlflow>=3.1.1
streamlit==1.46.1

これで完了です。
Databricks Appsのコンソール上からデプロイを実行することで変更内容を反映してデプロイしましょう。

image.png

Step3: 使う

アプリを開くと以下のような画面が表示されます。
サイドバーから二つのパラメータとシステムプロンプト、チャット履歴の管理ができます。

image.png

チャット欄からクエリを入力することで対話的に実行できます。普通のチャットボットですね。

image.png

事前に定義した2種だけですが、ツール呼び出しも対応しています。
例えば、特定のサイトの内容を取り出して処理したりもできます。
(Unity Catalog Functionsの制約もあって、文章量が大きいサイトなどはうまく利用できません)

image.png

DatabricksのMCPサーバ機能に対応したエージェントが利用できるため、Genieスペースをツールに利用するなどするとビジネスに利用しやすいチャットボットが容易に作れるのではないかと思います。

まとめ

機能的にはまだまだ貧弱ですが、Databricks Appsを使ってエージェントとの対話型チャットボットアプリを作ってみました。Databricks単体でバックエンド/フロントエンド両方を作成してデプロイできるのは本当に便利です。また、当然本番環境グレードでの運用も可能。すごい。

また、こういったアプリはDatabricks Marketplaceでも公開されています。

例えばDatabricks社から公開されている以下のアプリはGenie APIを利用する対話的なチャットアプリです。デザインがDatabricksにマッチしてイケてる。

image.png

他にもいろいろ出ており、github上にコードも公開されているので勉強がてらいろいろ見てみようと思います。

Appsを使った個人的所感として、認証・認可関連はまだ理解不足なところがあり、このあたり種々実験したいと考えています。

今回の反省点は・・・素のstreamlitを使うとどうしてもstreamlit感が出るので、もうちょっとクールなデザインにしたかったところですね。センスのあるデザイン能力が欲しい。。。

とはいえ、Databricksを通じてエンドユーザ向けのアプリを提供できるようになったのは本当にありがたいです。この調子で東京リージョンでもAgentBricksなどもどんどん利用できるようになることを願っています!

1
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?