ローカルLLM、いろいろ試したいモデルが溜まっている。。。
導入
DeepSeek社がDeepSeek-R1の最新版であるDeepSeek-R1-0528を公開しました。
また、npaka先生のまとめがわかりやすいと思います。
個人的に最近は動きが少ない印象を持っていたのですがさすがのDeepSeek、ベンチマーク結果は「o3」や「Gemini 2.5 Pro」に匹敵する性能です。
ベンチマークが全てではありませんが、この性能のモデルがオープンウェイトで公開するのは毎度凄い。(動く環境を作るのが大変ですが)
また、DeepSeek-R1-0528のchain-of-thoughtを蒸留してQwen3-8B Baseに事後学習適用したモデル、DeepSeek-R1-0528-Qwen3-8Bも公開されています。(この文章、どう日本語訳するべきなんだろう。。。)
こちらもオリジナルのQwen3 8Bの性能をAIME 2024で10%向上させ、Qwen3-235B-thinkingに近い性能を発揮させています。とんでもない。
なかなか面白いモデルだと思いましたので、いつものようにDatabricks上で動かしてみたいと思います。
R1本体を動かすのは環境的にも難しいため、今回はDeepSeek-R1-0528-Qwen3-8Bを動かします。
*とりあえず動かした結果だけ見たい方はStep4以降を参照ください。
検証はDatabricks on AWS、ノートブックはGPU(g5.xlarge)の専用モードクラスタ、DBRは16.4 LTS MLを利用しました。
とはいえ、コード内容は以下の記事とほぼ同じです。細かい解説は下記を参照にしてください。
Step1. 準備・カスタムモデルの定義
モデルファイルを以下からダウンロードし、Unity Catalogのボリュームに保管しておきます。
モデルは以下にあります。
今回は以下のようなコードでモデルをダウンロードしました。
Unity Catalogのtraining.llm
スキーマに存在するボリュームへモデルを保管しています。
# 別セルで以下を実行し、必要なパッケージをインストールする
# %pip install -U huggingface-hub
# dbutils.library.restartPython()
from typing import Optional
import shutil
model_id = "deepseek-ai/DeepSeek-R1-0528-Qwen3-8B"
revision = None
def download_model(model_id:str, revision:Optional[str]=None):
import os
from huggingface_hub import snapshot_download
api_token = dbutils.secrets.get("huggingface", "access_token")
UC_VOLUME = "/Volumes/training/llm/model_snapshots"
rev_dir = ("--" + revision) if revision else ""
local_dir = f"/tmp/{model_id}{rev_dir}"
# local_dirを適切なlinux file pathに変換
local_dir = local_dir.replace(" ", "_").replace(":", "_").replace(".", "_")
uc_dir = f"/models--{model_id.replace('/', '--')}"
snapshot_location = snapshot_download(
repo_id=model_id,
revision=revision,
local_dir=local_dir,
token = api_token,
force_download=True,
)
shutil.copytree(f"{local_dir}", f"{UC_VOLUME}{uc_dir}{rev_dir}", dirs_exist_ok=True)
download_model(model_id, revision)
次にノートブックを作成し、必要なパッケージをインストールします。
%pip install -qq --no-deps flashinfer_python==0.2.5 -i https://flashinfer.ai/whl/cu124/torch2.6 --trusted-host flashinfer.ai
%pip install -qq "sglang[srt]==0.4.6.post5" "openai==1.82.1"
%pip install -qq "mlflow-skinny[databricks]>=2.22.0" loguru uv rich databricks-agents nest-asyncio
%restart_python
次にMLflowのカスタムチャットモデルを定義します。
最終的にこのモデルをDatabricks Mosaic AI Model Servingにデプロイして利用します。
今回もSGLangを推論エンジンに利用しました。
かなり長いので折り畳み。
カスタムモデル定義
%%writefile sglang_chat_model.py
import os
from typing import Generator
import torch
import mlflow
import sglang as sgl
from sglang.utils import (
launch_server_cmd,
wait_for_server,
print_highlight,
terminate_process,
)
from mlflow.pyfunc import ChatModel
from mlflow.types.llm import (
ChatMessage,
ChatCompletionResponse,
ChatChoice,
ChatParams,
ChatCompletionChunk,
)
from mlflow.models import set_model
from loguru import logger
import openai
import requests
import psutil
SERVER_WAIT_TIMEOUT = 60 * 10
ENV_SGLANG_CLI_ARGS = "SGLANG_CLI_ARGS"
mlflow.openai.autolog()
def terminate_all_sglang_processes():
"""
すべての実行中のSGLangプロセスを終了します。
この関数は、すべての実行中のプロセスを反復処理し、SGLangサーバーの起動コマンドに一致する
プロセスを終了します。
"""
sglang_processes = []
for proc in psutil.process_iter():
check_cmd = ["python", "-m", "sglang.launch_server"]
try:
cmd = proc.cmdline()[0:3]
if cmd == check_cmd:
sglang_processes.append(proc)
break
except:
pass # アクセス許可なしの場合など
if sglang_processes:
for proc in sglang_processes:
terminate_process(proc)
def launch_sglang_server(model_path, model_config):
"""
指定されたモデルパスと構成でSGLangサーバーを起動します。
Args:
model_path (str): サーバーで使用するモデルのパス。
model_config (dict): サーバーの構成ディクショナリ。
Returns:
tuple: サーバープロセスとポート番号を含むタプル。
Raises:
ValueError: model_pathが提供されていない場合。
"""
# VRAMの使用量をチェック。First Deviceに十分な容量が無い場合、全てのSGLangプロセスを終了させる
mem_free, mem_total = torch.cuda.mem_get_info(device=0)
if mem_free / mem_total < 0.5: # 50%未満だと全てのSGLangプロセスを終了させる
logger.warning("Terminating all SGLang processes due to insufficient VRAM.")
terminate_all_sglang_processes()
if not model_path:
raise ValueError("model_path is required")
config = model_config.get("server", {})
args_from_env = os.environ.get(
ENV_SGLANG_CLI_ARGS, ""
) # 環境変数から追加引数を取得
args = args_from_env + " " + " ".join(config.get("args", []))
cli_args = (
f"python -m sglang.launch_server "
f"--model-path {model_path} "
f"--host 0.0.0.0 "
f"{args}"
)
logger.info(f"Launching server with args: {cli_args}")
server_process, port = launch_server_cmd(cli_args)
wait_for_server(f"http://localhost:{port}", timeout=SERVER_WAIT_TIMEOUT)
return server_process, port
def health_check(port) -> bool:
"""
指定されたポートで稼働しているサーバーのヘルスチェックを行います。
Args:
port (int): ヘルスチェックを行うサーバーのポート番号。
Returns:
bool: サーバーが正常に稼働している場合はTrue、そうでない場合はFalse。
"""
response = requests.get(f"http://localhost:{port}/health")
return response.ok
def build_messages_for_llm(messages: list[ChatMessage]) -> list[dict]:
"""
チャットメッセージのリストをSGLangサーバー用の辞書型リストに変換します。
Args:
messages (list[ChatMessage]): チャットメッセージのリスト。
Returns:
list[dict]: SGLangサーバー用に変換された辞書型のメッセージリスト。
"""
# contentのNone埋め(sglang利用における制約)
for msg in messages:
if msg.content is None:
msg.content = ""
# list[ChatAgentMessage]のメッセージ入力を辞書型に変換
return [msg.to_dict() for msg in messages]
def build_parameters_for_llm(params: ChatParams, stream: bool = False) -> dict:
"""
推論パラメータを構築します。
Args:
params (ChatParams): 予測パラメータ。
stream (bool): ストリーミングを有効にするかどうか。
Returns:
dict: 構築された推論パラメータの辞書。
Raises:
ValueError: custom_inputs['response_format']またはcustom_inputs['extra_body']が辞書でない場合。
"""
# 固有の標準パラメータを利用
default_params_dict = {
"temperature": 0.7,
"top_p": 0.95,
"max_tokens": 5000,
"n": 1,
}
# default_params_dictに含まれるキーが存在しない場合のみ、そのキーの値を上書き
dict_params = params.to_dict() if params else {}
for key, value in default_params_dict.items():
dict_params.setdefault(key, value)
# logger.info(f"Sampling parameters: {dict_params}")
# サーバに指定可能なパラメータ名の一覧
valid_parameters = [
"temperature",
"max_tokens",
"max_completion_tokens",
"top_p",
"n",
"stop",
"frequency_penalty",
"presence_penalty",
"tools",
"response_format",
"extra_body",
]
# Custom Inputsの利用
if "custom_inputs" in dict_params:
custom_inputs = dict_params.pop("custom_inputs")
dict_params.update(custom_inputs)
# 入力可能なパラメータのみに限定
dict_params = {k: v for k, v in dict_params.items() if k in valid_parameters}
# Streamingの設定
dict_params["stream"] = stream
return dict_params
class SGLangChatModel(ChatModel):
"""
SGLangChatModelは、SGLangサーバーを使用してチャットモデルを提供するMLflowのChatModelクラスです。
Attributes:
server_process (subprocess.Popen): サーバープロセスのインスタンス。
port (int): サーバーがリッスンしているポート番号。
auto_shutdown (bool): オブジェクトの破棄時にサーバーを自動的にシャットダウンするかどうか。
model_name (str): 使用するモデルの名前。
client (openai.Client): OpenAIクライアントのインスタンス。
Methods:
load_context(context): コンテキストからモデルをロードします。
predict(messages, params): メッセージに基づいて予測を行います。
predict_stream(messages, params): メッセージに基づいてストリーム予測を行います。
_shutdown(): サーバープロセスをシャットダウンします。
_build_completion_parameter(params, stream): 予測パラメータを構築します。
"""
def __init__(
self,
server_process=None,
port=None,
model_name: str = "Unknown",
auto_shutdown: bool = True,
):
self.server_process = server_process
self.port = port
self.auto_shutdown = auto_shutdown
self.model_name = model_name
self.client = None
def load_context(self, context):
"""
コンテキストからモデルをロードします。
Args:
context (mlflow.pyfunc.PythonModelContext): MLflowのPythonモデルコンテキスト。
Raises:
ValueError: モデルパスが提供されていない場合。
"""
if self.server_process is None:
logger.info(f"Starting server...")
model_path = context.artifacts["llm-model"]
model_config = context.model_config or {}
self.server_process, self.port = launch_sglang_server(
model_path, model_config
)
self.model_name = model_config.get("model", self.model_name)
self.client = openai.Client(
base_url=f"http://localhost:{self.port}/v1", api_key="None"
)
def predict(self, messages: list[ChatMessage], params: ChatParams = None):
"""
メッセージに基づいて予測を行います。
Args:
messages (list[ChatMessage]): チャットメッセージのリスト。
params (ChatParams, optional): 予測パラメータ。
Returns:
ChatCompletionResponse: 予測結果のレスポンス。
Raises:
ValueError: サーバープロセスが存在しない場合。
"""
# SGLangのサーバプロセスが存在しない場合、ダミーメッセージを返す。
if self.client is None or not health_check(self.port):
return ChatCompletionResponse(
choices=[
{
"index": 0,
"message": {
"role": "asssitant",
"content": "no response from server.",
},
}
]
)
# list[ChatAgentMessage]のメッセージ入力を辞書型に変換
llm_messages = build_messages_for_llm(messages)
# 推論パラメータの構築
dict_params = build_parameters_for_llm(params)
# Chat Completionの実行
response = self.client.chat.completions.create(
model=self.model_name,
messages=llm_messages,
**dict_params,
)
# Reasoning ContentをCustom Outputとして保持
dict_resp = response.to_dict()
message = dict_resp["choices"][0]["message"]
if (
not "content" in message or message["content"] is None
): # contentを必ず含めるようにする
message["content"] = ""
if "reasoning_content" in message:
dict_resp["custom_outputs"] = {
"reasoning_content": message["reasoning_content"]
}
# 結果の返却
return ChatCompletionResponse.from_dict(dict_resp)
def predict_stream(
self, messages: list[ChatMessage], params: ChatParams = None
) -> Generator[ChatCompletionChunk, None, None]:
"""
メッセージに基づいてストリーム予測を行います。
Args:
messages (list[ChatMessage]): チャットメッセージのリスト。
params (ChatParams, optional): 予測パラメータ。
Yields:
ChatCompletionChunk: ストリーム予測結果のチャンク。
Raises:
ValueError: サーバープロセスが存在しない場合。
"""
# SGLangのサーバプロセスが存在しない場合、ダミーメッセージを返す。
if self.client is None or not health_check(self.port):
return ChatCompletionResponse(
choices=[
{
"index": 0,
"message": {
"role": "asssitant",
"content": "no response from server.",
},
}
]
)
# list[ChatAgentMessage]のメッセージ入力を辞書型に変換
llm_messages = build_messages_for_llm(messages)
# 推論パラメータの構築
dict_params = build_parameters_for_llm(params, stream=True)
# Chat Completionの実行
response = self.client.chat.completions.create(
model=self.model_name,
messages=llm_messages,
**dict_params,
)
# ChunkのStream返却
for chunk in response:
dict_chunk = chunk.to_dict()
delta = dict_chunk["choices"][0]["delta"]
if (
not "content" in delta or delta["content"] is None
): # contentを必ず含めるようにする
delta["content"] = ""
if "reasoning_content" in delta:
dict_chunk["custom_outputs"] = {
"reasoning_content": delta["reasoning_content"]
}
yield ChatCompletionChunk.from_dict(dict_chunk)
def _shutdown(self):
"""
サーバープロセスをシャットダウンします。
このメソッドは、サーバープロセスが存在し、かつauto_shutdownがTrueに設定されている場合に、
サーバープロセスを終了させます。終了後、server_processをNoneに設定し、auto_shutdownをFalseに設定します。
"""
if self.server_process and self.auto_shutdown:
logger.info("shutdown sglang server.")
terminate_process(self.server_process)
self.server_process = None
self.auto_shutdown = False
def __del__(self):
self._shutdown()
model = SGLangChatModel()
set_model(model)
Step2. モデルのロギング
Step1で定義したカスタムモデルを使ってMLflowにモデルをロギングします。
モデルの設定関連などでコード量が増えていますが、やっていることはlog_model
でモデルを保管することです。
import mlflow
import os
mlflow.set_registry_uri("databricks-uc")
# 依存関係の設定
extra_pip_requirements = [
"torch==2.6.0 --index-url https://download.pytorch.org/whl/cu124",
"https://github.com/flashinfer-ai/flashinfer/releases/download/v0.2.5/flashinfer_python-0.2.5+cu124torch2.6-cp38-abi3-linux_x86_64.whl",
"sglang[srt]==0.4.6.post5",
"openai==1.82.1",
"loguru",
"psutil",
]
pip_requirements = mlflow.pyfunc.get_default_pip_requirements() + extra_pip_requirements
# 利用する重みファイルの場所
model_path = "/Volumes/training/llm/model_snapshots/models--deepseek-ai--DeepSeek-R1-0528-Qwen3-8B/"
# モデル設定。主にSGLangの起動設定
model_config = {
"model": os.path.basename(os.path.normpath(model_path)),
"server": {
"args": [
"--mem-fraction-static 0.85",
"--grammar-backend llguidance",
"--tool-call-parser qwen25",
"--max-running-requests 1",
"--decode-log-interval 1000",
],
},
}
artifacts = {
"llm-model": model_path,
}
# 入力サンプル
input_example = {
"messages": [
{
"role": "user",
"content": "What is a good recipe for baking scones that doesn't require a lot of skill?",
}
],
"max_tokens":100,
"temperature":0.6,
"top_p":0.8,
"presence_penalty":1.5,
"custom_inputs":{"extra_body":{"chat_template_kwargs": {"enable_thinking": False}}}
}
with mlflow.start_run() as run:
logged_model = mlflow.pyfunc.log_model(
artifact_path="model",
python_model="sglang_chat_model.py",
artifacts=artifacts,
model_config=model_config,
input_example=input_example,
pip_requirements=pip_requirements,
await_registration_for=3000,
)
このままUnity Catalogにも登録します。
mlflow.set_registry_uri("databricks-uc")
run_id = logged_model.run_id
model_uri = f"runs:/{run_id}/model"
registered_model_name = "training.llm.sglang_qwen3"
uc_registered_model_info = mlflow.register_model(
model_uri=model_uri,
name=registered_model_name,
await_registration_for=3000,
)
Step3. Model Servingへデプロイする
Mosaic AI Agent Frameworkを使ってDatabricks Mosaic AI Model Servingにエージェントとしてデプロイします。
from databricks import agents
import mlflow
from mlflow import MlflowClient
# エンドポイントの名前
endpoint_name = "chatmodel_" + registered_model_name.replace(".", "__")
print(endpoint_name)
# Unity Catalog上モデルの最新バージョン取得
client = MlflowClient()
versions = [
mv.version for mv in client.search_model_versions(f"name='{registered_model_name}'")
]
print(versions)
agents.deploy(
registered_model_name,
versions[0],
scale_to_zero=False,
endpoint_name=endpoint_name,
)
このままだとクラスタータイプがCPUとしてデプロイされるため、サービングメニューのUIからクラスタータイプを変更します。
問題なければ、数十分後にはサービングエンドポイントへのデプロイが完了します。
Step4. Playgroundで試す
では、実際に利用してみましょう。
思考モデルが得意であろう質問を投げかけてみます。
実際にはもっと回答が続くのですが、内容的にはなかなか興味深いものでした。
若干日本語がおかしいところはありますので、そこは課題かな。
ちなみに、思考内容はこのようになっています。
最終出力とは変わって、割とカジュアルというか実際的な文言で思考していますね。
また、そこまで思考過程が長くないのも特徴な気がします。
次にコーディングもさせてみましょう。
HTMLを保存して動かすと以下のようなモックアプリができていました。
メモリ制約であまり大した確認までできていませんが、それなりのものを作れる能力がありそうです。
モック作る程度なら必要十分そう。
おわりに
DeepSeek-R1-0528-Qwen3-8BをDatabricks上で試してみました。
8Bという大きくないパラメータサイズの割に、性能の高さを感じました。
簡単に試しただけですが、確かにQwen3 8Bより推論能力が高まっている感覚はあります。
ちょっと残念なのは、Qwen3 8B Baseを基にしたモデルのため、ベースモデルではないQwen3で使える/no_think
が有効ではないこと。Qwen3のそこを再現するモデルではないので、しょうがないですが。
一時期蒸留が急速に話題になりましたが、また再燃して様々なオープンウェイトモデルが出てくると面白くなるなと思います。最近はClaudeやGeminiを使うことが多いのですが、ローカルLLMは多くの可能性があると考えており、特に小パラメータで協力なLLMがこれからも登場することを期待しています。