結構はまりました。
導入
以前の記事でMLflowのChatAgentを調べていました。
その後、MLflow ver 2.20.2がリリースされ一部ChatAgentインターフェースの変更が行われました。
また、ドキュメントも整備され始めています。
これで当分破壊的変更は来ないんじゃないかなと思うので、ChatAgentインターフェースを使ってカスタムモデルを定義し、さらにDatabricks Mosaic AI Model Serving上でサーブしてみたいと思います。
MLflowのバージョンは2.20.2を利用しました。
また、将来的なDatabricks Mosaic AI Model Servingの変更によって下記コードは動かなくなる可能性があります。
開発・検証はDatabricks on AWSで行いました。
クラスタのDBRは16.2、インスタンスタイプはg5.xlarge(GPU:A10)です。
やること
ChatAgentインターフェースというAgentを想定したインターフェースではあるのですが、基本に戻って(?)ローカルLLMのサービングに利用してみます。
今回は最近Rinna社が公開したQwen2.5-BakenekoシリーズをSGLangを使って動かします。
モデルは以下を事前にUnity Catalog Volumes上へダウンロードして利用しました。
今回のハマりポイント
- 2025/2/16時点では、
ChatAgent
インターフェースのモデルをDatabricks Mosaic AI Model Servingでサーブする再は、metadata
を変更する必要がある- 一時的なものだと思いますが、ハマりました
- SGLangとDatabricks Mosaic AI Model Servingの組み合わせ関連
- ひとまずCUDA Graphは無効化しておく(CUDA_HOME環境変数がデフォルトでは設定されないため)
- MLflow ロギング時の
FlashInfer
依存関係指定は、whlを直接指定しよう
Step1. ChatAgentインターフェースのカスタムモデルを定義
まずはノートブックを作成し、mlflow 2.20.2およびSGLangに必要なパッケージをインストールします。
%pip install flashinfer_python>=0.2.1.post1 --no-deps -i https://flashinfer.ai/whl/cu124/torch2.5
%pip install sgl-kernel --force-reinstall --no-deps
%pip install "sglang[srt]==0.4.3" openai
%pip install "mlflow-skinny[databricks]==2.20.2" loguru
%restart_python
次にChatAgentカスタムモデルを定義。
SGLangを使ってLLMの応答結果を返すシンプルな内容です。
コードは少々長いのですが、概ね以下の事をしています。
-
%%writefile
コマンドを使って以下処理を.pyファイルとして出力。これはmlflowの[models-from-code] (https://mlflow.org/docs/latest/model/models-from-code.html) としてカスタムモデルを使うため。 -
load_context
内でSGLangのサーバを起動し、predict
/predict_stream
内でそのサーバへクエリを発行・結果を返しています。-
predict
メソッドではChatAgentResponse
オブジェクトを、predict_stream
ではChatAgentChunk
オブジェクトを返す必要があり、使い分けが必要です。
-
- 細かい話ですが、簡易実装のためSGLangのTool Callingには対応させていません。
%%writefile "./sglang_online_chat_model.py"
from mlflow.pyfunc.model import ChatAgent
from mlflow.types.agent import (
ChatAgentMessage,
ChatAgentResponse,
ChatContext,
ChatAgentChunk,
)
from mlflow.types.chat import ChatUsage
from mlflow.models import set_model
from typing import Optional, Any, Generator
import openai
from loguru import logger
from sglang.utils import (
execute_shell_command,
wait_for_server,
terminate_process,
)
from sglang.srt.server_args import ServerArgs
DEFAULT_HOST_ADDRESS = "127.0.0.1"
class SGLangChatAgent(ChatAgent):
def __init__(self, server_process=None, port=30000):
self.server_process = server_process
self.port = port
self.model_name = "Unknown"
self.client = None
def load_context(self, context):
"""Load the model from the context."""
model_path = context.artifacts["llm-model"]
model_config = (
context.model_config.get("server", {}) if context.model_config else {}
)
if not self.server_process:
self.server_process, self.port = self._launch_sglang_server(
model_path, model_config
)
self.client = openai.Client(
base_url=f"http://{DEFAULT_HOST_ADDRESS}:{self.port}/v1", api_key="None"
)
logger.info(f"Server started at http://{DEFAULT_HOST_ADDRESS}:{self.port}")
def _launch_sglang_server(self, model_path, model_config):
"""Start the server process."""
if not model_path:
raise ValueError("model_path is required")
default_args = ServerArgs(model_path=model_path)
self.model_name = model_config.get("model_name", "Unknown")
port = model_config.get("port", default_args.port)
mem_fraction_static = model_config.get(
"mem-fraction-static", default_args.mem_fraction_static
)
additional_args = model_config.get("additional_args", "")
cli_args = (
f"python -m sglang.launch_server "
f"--model-path {model_path} "
f"--host {DEFAULT_HOST_ADDRESS} "
f"--port {port} "
f"--mem-fraction-static {mem_fraction_static} "
f"--disable-cuda-graph "
f"{additional_args}"
)
logger.info(f"Launching server with args: {cli_args}")
server_process = execute_shell_command(cli_args)
wait_for_server(f"http://{DEFAULT_HOST_ADDRESS}:{port}")
return server_process, port
def predict(
self,
messages: list[ChatAgentMessage],
context: Optional[ChatContext] = None,
custom_inputs: Optional[dict] = None,
) -> ChatAgentResponse:
if not self.client:
return ChatAgentResponse(
**{"messages": [{"role": "assistant", "content": "no response."}]}
)
# list[ChatAgentMessage]のメッセージ入力を辞書型に変換
llm_messages = self._convert_messages_to_dict(messages)
logger.debug(f"Messages: {llm_messages}")
custom_inputs = custom_inputs or {}
response = self.client.chat.completions.create(
model=self.model_name,
messages=llm_messages,
stream=False,
**custom_inputs,
)
logger.debug(f"Response: {response}")
response_dict = response.to_dict()
output_message = response_dict.get("choices")[0].get("message")
usage = response_dict.get("usage", {})
return ChatAgentResponse(**{"messages": [output_message], "usage": usage})
def predict_stream(
self,
messages: list[ChatAgentMessage],
context: Optional[ChatContext] = None,
custom_inputs: Optional[dict[str, Any]] = None,
) -> Generator[ChatAgentChunk, None, None]:
if not self.client:
return
# list[ChatAgentMessage]のメッセージ入力を辞書型に変換
llm_messages = self._convert_messages_to_dict(messages)
logger.debug(f"Messages: {llm_messages}")
stream_response = self.client.chat.completions.create(
model=self.model_name,
messages=llm_messages,
stream=True,
**custom_inputs,
)
for chunk in stream_response:
if chunk.choices[0].delta.content:
yield ChatAgentChunk(
delta=ChatAgentMessage(
role="asssitant", content=chunk.choices[0].delta.content
),
finish_reason=chunk.choices[0].finish_reason,
custom_outputs=None,
usage=None,
)
def shutdown(self):
if self.server_process:
terminate_process(self.server_process)
self.server_process = None
logger.info("Shutdown LLM")
model = SGLangChatAgent()
set_model(model)
Step2. モデルのロギング
依存関係設定やモデル設定を準備し、mlflowにロギングします。
あわせてDatabricks Mosaic AI Model Servingで利用するために、Databricks Unity Catalog上に登録しています。(今回はtraining.llm
というカタログ・スキーマにモデルを登録)
import mlflow
import os
mlflow.set_registry_uri("databricks-uc")
extra_pip_requirements = [
"torch==2.5.1 --index-url https://download.pytorch.org/whl/cu124",
"https://github.com/flashinfer-ai/flashinfer/releases/download/v0.2.1.post1/flashinfer_python-0.2.1.post1+cu124torch2.5-cp38-abi3-linux_x86_64.whl", # 事前ビルド済みのWheelファイルを直接指定
"threadpoolctl==3.5.0",
"sgl-kernel",
"sglang[srt]==0.4.3",
"loguru",
"openai",
]
pip_requirements = mlflow.pyfunc.get_default_pip_requirements() + extra_pip_requirements
model_path = (
"/Volumes/training/llm/model_snapshots/models--rinna--qwen2.5-bakeneko-32b-instruct-awq/"
)
model_config = {
"server": {
# "model_name": "rinna--qwen2.5-bakeneko-32b-instruct-awq",
"model_name": "SakanaAI",
"port": 20000,
"mem-fraction-static": 0.88,
"additional_args": "--log-level error",
}
}
artifacts = {
"llm-model": model_path,
}
# Unity Catalogへの登録先
registered_model_name = "training.llm.test_sglang_model"
with mlflow.start_run() as run:
_ = mlflow.pyfunc.log_model(
artifact_path="model",
python_model="sglang_online_chat_model.py",
artifacts=artifacts,
model_config=model_config,
pip_requirements=pip_requirements,
metadata={"task": "llm/v1/chat"},
await_registration_for=3600,
registered_model_name=registered_model_name,
)
2024/2/16時点では、ロギング時にmetadata
を指定してtask
を上書きすることを忘れないでください。
ChatAgentを用いて作成したカスタムクラスはtask
へ自動的にagent/v2/chat
を指定します。
一方、この指定がある場合、Databricks Mosaic AI Model Servingでエンドポイントを作成する際に以下のようなエラーが出ます。
Missing feedback served model. To serve an agent model, please deploy the model through the databricks-agents SDK.
用途的にdatabricks-agents SDKを使わずにエンドポイント構築をしたいため、ここではあらかじめmetadata
を異なるものに変更しています。
Step3. モデルのサービング
登録したモデルをDatabricks Mosaic AI Model Servingでサーブします。
以下のコードを実行してtest_sglang_model_endpoint
という名前のエンドポイントを作成。
import requests
import json
import mlflow
from mlflow import MlflowClient
API_ROOT = dbutils.notebook.entry_point.getDbutils().notebook().getContext().apiUrl().get()
API_TOKEN = dbutils.notebook.entry_point.getDbutils().notebook().getContext().apiToken().get()
# 登録モデルの最新バージョンを取得
mlflow.set_registry_uri("databricks-uc")
client = MlflowClient()
versions = [
mv.version for mv in client.search_model_versions(f"name='{registered_model_name}'")
]
# エンドピントの設定情報を定義
data = {
"name": f"test_sglang_model_endpoint",
"config": {
"served_entities": [
{
"entity_name": registered_model_name,
"entity_version": versions[0],
"workload_type": "GPU_MEDIUM",
"workload_size": "Small",
"scale_to_zero_enabled": True,
}
]
},
}
# APIを呼び出してエンドポイントを作成
headers = {"Context-Type": "text/json", "Authorization": f"Bearer {API_TOKEN}"}
response = requests.post(
url=f"{API_ROOT}/api/2.0/serving-endpoints", json=data, headers=headers
)
print(json.dumps(response.json(), indent=4))
Step4. 実行する
MLflowを使ってモデルをロードするケースと、サービングしたエンドポイントを使うケースの2種で推論してみます。
まずは、MLflowのload_model
でモデルをロードして使います。
import mlflow
mlflow.set_registry_uri("databricks-uc")
# Unity Catalog登録モデルのURL
model_url = f"models:/{registered_model_name}/{versions[0]}"
# モデルのロード
loaded_model = mlflow.pyfunc.load_model(model_url)
ようやく推論。
各種パラメータはcustom_inputs
として渡すことになります。
# 推論
messages = [{"role": "user", "content": "Databricksについて解説して"}]
custom_inputs = {"temperature": 0.1, "max_tokens": 500}
loaded_model.predict({"messages": messages, "custom_inputs": custom_inputs})
{'messages': [{'role': 'assistant',
'content': 'Databricksは、Apache Sparkの主要な貢献者であるMichael Stonebraker、Ion Stoica、および若干の他の研究者によって設立されたデータ分析プラットフォームの会社です。このプラットフォームは、データエンジニアリングとデータサイエンスの両方のニーズを満たすように設計されています。\n\nDatabricksは、クラウドベースの統合データ分析プラットフォームを提供しており、データの準備、共有、および分析を一元化します。これにより、データサイエンティストやエンジニアは、データの準備からモデルの構築、デプロイまでの一連の作業を効率的に進めることができます。\n\n主な特徴は以下の通りです:\n\n1. **Apache Sparkの最適化**:DatabricksはApache Sparkの専門家であり、そのプラットフォームはSparkのパフォーマンスを最大限に引き出すように設計されています。これにより、大量のデータを高速に処理することが可能になります。\n\n2. **統合されたワークスペース**:Databricksは、データの準備、共有、分析を一元化するワークスペースを提供します。これにより、チームはデータの準備からモデルの構築、デプロイまでの一連の作業を効率的に進めることができます。\n\n3. **機械学習のサポート**:Databricksは、機械学習のための高度な機能を提供します。これには、モデルの構築、訓練、評価、デプロイが含まれます。また、MLflowというオープンソースの機械学習ライフサイクル管理ツールも提供しています。\n\n4. **クラウドネイティブ**:Databricksは、AWS、Azure、Google Cloudなどの主要なクラウドプロバイダーと統合されています。これにより、ユーザーはクラウドのスケーラビリティと柔軟性を活用することができます。\n\n5. **コラボレーションと共有**:Databricksは、チームが共同で作業し、データとコードを共有するための機能を提供します。これにより、チームはデータの準備から分析までの一連の作業を効率的に進めることができます。\n\nDatabricksは、データエン',
'id': '67391102-c023-48c9-b4b5-3d05ac67c7de'}],
'usage': {'prompt_tokens': 37, 'completion_tokens': 500, 'total_tokens': 537}}
ストリーミング出力も試してみます。
for c in loaded_model.predict_stream({"messages": messages, "custom_inputs": custom_inputs}):
print(c["delta"]["content"], flush=True, end="")
2025-02-16 08:50:09.243 | DEBUG | code_model_7868a8497e6c4fc68e17ccb0cd023e9a:predict_stream:124 - Messages: [{'role': 'user', 'content': 'Databricksについて解説して', 'id': '0060ffe9-e7ba-4046-bd6a-440cf819f74d'}]
Databricksは、Apache Sparkの主要な貢献者であるMichael Stonebraker、Ion Stoica、および他の研究者たちによって設立された企業が運営するクラウドベースのデータ処理プラットフォームです。このプラットフォームは、データエンジニアリングとデータサイエンスの作業を簡素化し、効率化することを目指しています。
### 主な特徴
1. **Apache Sparkの最適化**
- DatabricksはApache Sparkの専門家によって作られ、Sparkのパフォーマンスと機能を最大限に引き出すように設計されています。これにより、大量のデータを高速に処理することが可能になります。
2. **統合されたワークスペース**
- Databricksは、データの準備、分析、機械学習のための統合されたワークスペースを提供します。これにより、データエンジニア、データサイエンティスト、データアナリストが協力し、共同作業を行うことができます。
3. **クラウドネイティブ**
- DatabricksはAWS、Azure、Google Cloudなどの主要なクラウドプロバイダー上で動作します。これにより、柔軟性とスケーラビリティが確保され、クラウドのリソースを効率的に利用することができます。
4. **機械学習のサポート**
- Databricksは、機械学習のための高度な機能を提供します。MLflowというオープンソースのフレームワークを使用して、モデルの開発、トレーニング、デプロイメントを管理することができます。
5. **データ共有とコラボレーション**
- Databricksは、チーム間でのデータ共有とコラボレーションを容易にする機能を提供します。ノートブック、ダッシュボード、データセットを共有し、共同作業を行うことができます。
6. **セキュリティとコンプライアンス**
- Databricksは、データの保護とコンプライアンスの要件を満たすためのセキュリティ機能を提供します。これには、認証、アクセス制御、暗号化などが含まれます。
### 用途
- **データエンジニアリング**
- データの準備
次にDatabricks Mosaic AI Model Servingのエンドポイントにクエリを投げてみます。
import mlflow.deployments
from pprint import pprint
client = mlflow.deployments.get_deploy_client("databricks")
# 推論
messages = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "生成AIの今後の動向について説明してください。"},
]
custom_inputs = {"temperature": 0.1, "max_tokens": 500}
response = client.predict(
endpoint="test_sglang_model_endpoint",
inputs={
"messages": messages,
"custom_inputs": custom_inputs,
},
)
pprint(response)
{'messages': [{'content': '生成AIは、テキスト、画像、音楽など、あらゆる形式のコンテンツを生成するためのAI技術です。今後の動向について以下に説明します。\n'
'\n'
'1. 進化する技術: '
'生成AIは、より高度な学習アルゴリズムと大量のデータにより、ますます進化しています。これにより、生成されたコンテンツの質と精度が向上し、より自然でリアルな結果が得られるようになります。\n'
'\n'
'2. 多様な用途: '
'生成AIは、コンテンツ作成、デザイン、音楽作成、ゲーム開発など、さまざまな分野で利用されています。今後は、教育、医療、製造など、さらに多くの業界で利用される可能性があります。\n'
'\n'
'3. 人間との協働: '
'生成AIは、人間の創造性を補完し、新たなアイデアや解決策を提供するツールとして利用される可能性があります。これにより、人間とAIが協力して、より複雑で革新的なプロジェクトを実現することが可能になります。\n'
'\n'
'4. 法的・倫理的問題: '
'生成AIの利用は、著作権、プライバシー、フェイクニュースなどの法的・倫理的問題を引き起こす可能性があります。これらの問題に対処するための法的枠組みやガイドラインの開発が求められています。\n'
'\n'
'これらの動向は、生成AIの進化と利用の可能性を示していますが、同時に、その影響と課題に対処するための継続的な研究と対話も必要です。',
'id': 'a814b97a-8f14-4fd5-862b-61142cd16ca5',
'role': 'assistant'}],
'usage': {'completion_tokens': 349, 'prompt_tokens': 32, 'total_tokens': 381}}
ストリーミング出力はどうでしょうか。
response = client.predict_stream(
endpoint="test_sglang_model_endpoint",
inputs={
"messages": messages,
"custom_inputs": custom_inputs,
},
)
for c in response:
pprint(c)
HTTPError: 400 Client Error: Encountered an unexpected error while parsing the input data. Error 'This endpoint does not support streaming.' for url: https://xxx/serving-endpoints/test_sglang_model_endpoint/invocations. Response text: {"error_code": "BAD_REQUEST", "message": "Encountered an unexpected error while parsing the input data. Error 'This endpoint does not support streaming.'"}
残念ながら未対応のようです。
MLflowのIssueにはあがってるのですが、期待して待ちたいですね。
まとめ
MLflow ChatAgentを使ったカスタムモデルを作成しました。
エージェント用のインターフェースということもあり、正直LLMのサービング用としては逆に冗長感がありますね。temperature
などのパラメータをcustom_inputs
で渡す必要があったりなど。
今回のユースケースだと従来のChatModel
を利用する方がOpenAI APIと互換性のあるインターフェースが組めることもあって適していそうです。
ChatModelについては丁寧に解説されている記事がありますのでご一読を。
次回はちゃんとエージェントとしての役割でChatAgentを利用してみたいと思います。