1
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

N番煎じでDeepSeek-R1-0528-Qwen3-8BをDatabricksで試してみる

Posted at

ローカルLLM、いろいろ試したいモデルが溜まっている。。。

導入

DeepSeek社がDeepSeek-R1の最新版であるDeepSeek-R1-0528を公開しました。

また、npaka先生のまとめがわかりやすいと思います。

個人的に最近は動きが少ない印象を持っていたのですがさすがのDeepSeek、ベンチマーク結果は「o3」や「Gemini 2.5 Pro」に匹敵する性能です。
ベンチマークが全てではありませんが、この性能のモデルがオープンウェイトで公開するのは毎度凄い。(動く環境を作るのが大変ですが)

image.png

また、DeepSeek-R1-0528のchain-of-thoughtを蒸留してQwen3-8B Baseに事後学習適用したモデル、DeepSeek-R1-0528-Qwen3-8Bも公開されています。(この文章、どう日本語訳するべきなんだろう。。。)
こちらもオリジナルのQwen3 8Bの性能をAIME 2024で10%向上させ、Qwen3-235B-thinkingに近い性能を発揮させています。とんでもない。

なかなか面白いモデルだと思いましたので、いつものようにDatabricks上で動かしてみたいと思います。
R1本体を動かすのは環境的にも難しいため、今回はDeepSeek-R1-0528-Qwen3-8Bを動かします。
*とりあえず動かした結果だけ見たい方はStep4以降を参照ください。

検証はDatabricks on AWS、ノートブックはGPU(g5.xlarge)の専用モードクラスタ、DBRは16.4 LTS MLを利用しました。

とはいえ、コード内容は以下の記事とほぼ同じです。細かい解説は下記を参照にしてください。

Step1. 準備・カスタムモデルの定義

モデルファイルを以下からダウンロードし、Unity Catalogのボリュームに保管しておきます。
モデルは以下にあります。

今回は以下のようなコードでモデルをダウンロードしました。
Unity Catalogのtraining.llmスキーマに存在するボリュームへモデルを保管しています。


# 別セルで以下を実行し、必要なパッケージをインストールする
# %pip install -U huggingface-hub
# dbutils.library.restartPython()

from typing import Optional
import shutil

model_id = "deepseek-ai/DeepSeek-R1-0528-Qwen3-8B"
revision = None

def download_model(model_id:str, revision:Optional[str]=None):
    import os
    from huggingface_hub import snapshot_download

    api_token = dbutils.secrets.get("huggingface", "access_token")

    UC_VOLUME = "/Volumes/training/llm/model_snapshots"

    rev_dir = ("--" + revision) if revision else ""
    local_dir = f"/tmp/{model_id}{rev_dir}"
    # local_dirを適切なlinux file pathに変換
    local_dir = local_dir.replace(" ", "_").replace(":", "_").replace(".", "_")
    
    uc_dir = f"/models--{model_id.replace('/', '--')}"
    
    snapshot_location = snapshot_download(
        repo_id=model_id,
        revision=revision,
        local_dir=local_dir,
        token = api_token,
        force_download=True,
    )

    shutil.copytree(f"{local_dir}", f"{UC_VOLUME}{uc_dir}{rev_dir}", dirs_exist_ok=True)

download_model(model_id, revision)

次にノートブックを作成し、必要なパッケージをインストールします。

%pip install -qq --no-deps flashinfer_python==0.2.5 -i https://flashinfer.ai/whl/cu124/torch2.6 --trusted-host flashinfer.ai
%pip install -qq "sglang[srt]==0.4.6.post5" "openai==1.82.1"
%pip install -qq "mlflow-skinny[databricks]>=2.22.0" loguru uv rich databricks-agents nest-asyncio

%restart_python

次にMLflowのカスタムチャットモデルを定義します。
最終的にこのモデルをDatabricks Mosaic AI Model Servingにデプロイして利用します。
今回もSGLangを推論エンジンに利用しました。
かなり長いので折り畳み。

カスタムモデル定義
%%writefile sglang_chat_model.py

import os
from typing import Generator

import torch
import mlflow
import sglang as sgl
from sglang.utils import (
    launch_server_cmd,
    wait_for_server,
    print_highlight,
    terminate_process,
)

from mlflow.pyfunc import ChatModel
from mlflow.types.llm import (
    ChatMessage,
    ChatCompletionResponse,
    ChatChoice,
    ChatParams,
    ChatCompletionChunk,
)
from mlflow.models import set_model

from loguru import logger
import openai
import requests
import psutil

SERVER_WAIT_TIMEOUT = 60 * 10
ENV_SGLANG_CLI_ARGS = "SGLANG_CLI_ARGS"

mlflow.openai.autolog()

def terminate_all_sglang_processes():
    """
    すべての実行中のSGLangプロセスを終了します。

    この関数は、すべての実行中のプロセスを反復処理し、SGLangサーバーの起動コマンドに一致する
    プロセスを終了します。
    """

    sglang_processes = []
    for proc in psutil.process_iter():
        check_cmd = ["python", "-m", "sglang.launch_server"]

        try:
            cmd = proc.cmdline()[0:3]
            if cmd == check_cmd:
                sglang_processes.append(proc)
                break
        except:
            pass  # アクセス許可なしの場合など

    if sglang_processes:
        for proc in sglang_processes:
            terminate_process(proc)


def launch_sglang_server(model_path, model_config):
    """
    指定されたモデルパスと構成でSGLangサーバーを起動します。

    Args:
        model_path (str): サーバーで使用するモデルのパス。
        model_config (dict): サーバーの構成ディクショナリ。

    Returns:
        tuple: サーバープロセスとポート番号を含むタプル。

    Raises:
        ValueError: model_pathが提供されていない場合。
    """
    # VRAMの使用量をチェック。First Deviceに十分な容量が無い場合、全てのSGLangプロセスを終了させる
    mem_free, mem_total = torch.cuda.mem_get_info(device=0)
    if mem_free / mem_total < 0.5:  # 50%未満だと全てのSGLangプロセスを終了させる
        logger.warning("Terminating all SGLang processes due to insufficient VRAM.")
        terminate_all_sglang_processes()

    if not model_path:
        raise ValueError("model_path is required")

    config = model_config.get("server", {})

    args_from_env = os.environ.get(
        ENV_SGLANG_CLI_ARGS, ""
    )  # 環境変数から追加引数を取得
    args = args_from_env + " " + " ".join(config.get("args", []))

    cli_args = (
        f"python -m sglang.launch_server "
        f"--model-path {model_path} "
        f"--host 0.0.0.0 "
        f"{args}"
    )

    logger.info(f"Launching server with args: {cli_args}")

    server_process, port = launch_server_cmd(cli_args)
    wait_for_server(f"http://localhost:{port}", timeout=SERVER_WAIT_TIMEOUT)

    return server_process, port


def health_check(port) -> bool:
    """
    指定されたポートで稼働しているサーバーのヘルスチェックを行います。

    Args:
        port (int): ヘルスチェックを行うサーバーのポート番号。

    Returns:
        bool: サーバーが正常に稼働している場合はTrue、そうでない場合はFalse。
    """
    response = requests.get(f"http://localhost:{port}/health")
    return response.ok


def build_messages_for_llm(messages: list[ChatMessage]) -> list[dict]:
    """
    チャットメッセージのリストをSGLangサーバー用の辞書型リストに変換します。

    Args:
        messages (list[ChatMessage]): チャットメッセージのリスト。

    Returns:
        list[dict]: SGLangサーバー用に変換された辞書型のメッセージリスト。
    """

    # contentのNone埋め(sglang利用における制約)
    for msg in messages:
        if msg.content is None:
            msg.content = ""

    # list[ChatAgentMessage]のメッセージ入力を辞書型に変換
    return [msg.to_dict() for msg in messages]


def build_parameters_for_llm(params: ChatParams, stream: bool = False) -> dict:
    """
    推論パラメータを構築します。

    Args:
        params (ChatParams): 予測パラメータ。
        stream (bool): ストリーミングを有効にするかどうか。

    Returns:
        dict: 構築された推論パラメータの辞書。

    Raises:
        ValueError: custom_inputs['response_format']またはcustom_inputs['extra_body']が辞書でない場合。
    """

    # 固有の標準パラメータを利用
    default_params_dict = {
        "temperature": 0.7,
        "top_p": 0.95,
        "max_tokens": 5000,
        "n": 1,
    }

    # default_params_dictに含まれるキーが存在しない場合のみ、そのキーの値を上書き
    dict_params = params.to_dict() if params else {}
    for key, value in default_params_dict.items():
        dict_params.setdefault(key, value)

    # logger.info(f"Sampling parameters: {dict_params}")

    # サーバに指定可能なパラメータ名の一覧
    valid_parameters = [
        "temperature",
        "max_tokens",
        "max_completion_tokens",
        "top_p",
        "n",
        "stop",
        "frequency_penalty",
        "presence_penalty",
        "tools",
        "response_format",
        "extra_body",
    ]

    # Custom Inputsの利用
    if "custom_inputs" in dict_params:
        custom_inputs = dict_params.pop("custom_inputs")
        dict_params.update(custom_inputs)

    # 入力可能なパラメータのみに限定
    dict_params = {k: v for k, v in dict_params.items() if k in valid_parameters}

    # Streamingの設定
    dict_params["stream"] = stream

    return dict_params


class SGLangChatModel(ChatModel):
    """
    SGLangChatModelは、SGLangサーバーを使用してチャットモデルを提供するMLflowのChatModelクラスです。

    Attributes:
        server_process (subprocess.Popen): サーバープロセスのインスタンス。
        port (int): サーバーがリッスンしているポート番号。
        auto_shutdown (bool): オブジェクトの破棄時にサーバーを自動的にシャットダウンするかどうか。
        model_name (str): 使用するモデルの名前。
        client (openai.Client): OpenAIクライアントのインスタンス。

    Methods:
        load_context(context): コンテキストからモデルをロードします。
        predict(messages, params): メッセージに基づいて予測を行います。
        predict_stream(messages, params): メッセージに基づいてストリーム予測を行います。
        _shutdown(): サーバープロセスをシャットダウンします。
        _build_completion_parameter(params, stream): 予測パラメータを構築します。
    """

    def __init__(
        self,
        server_process=None,
        port=None,
        model_name: str = "Unknown",
        auto_shutdown: bool = True,
    ):
        self.server_process = server_process
        self.port = port
        self.auto_shutdown = auto_shutdown
        self.model_name = model_name
        self.client = None

    def load_context(self, context):
        """
        コンテキストからモデルをロードします。

        Args:
            context (mlflow.pyfunc.PythonModelContext): MLflowのPythonモデルコンテキスト。

        Raises:
            ValueError: モデルパスが提供されていない場合。
        """

        if self.server_process is None:

            logger.info(f"Starting server...")

            model_path = context.artifacts["llm-model"]
            model_config = context.model_config or {}

            self.server_process, self.port = launch_sglang_server(
                model_path, model_config
            )
            self.model_name = model_config.get("model", self.model_name)

        self.client = openai.Client(
            base_url=f"http://localhost:{self.port}/v1", api_key="None"
        )

    def predict(self, messages: list[ChatMessage], params: ChatParams = None):
        """
        メッセージに基づいて予測を行います。

        Args:
            messages (list[ChatMessage]): チャットメッセージのリスト。
            params (ChatParams, optional): 予測パラメータ。

        Returns:
            ChatCompletionResponse: 予測結果のレスポンス。

        Raises:
            ValueError: サーバープロセスが存在しない場合。
        """

        # SGLangのサーバプロセスが存在しない場合、ダミーメッセージを返す。
        if self.client is None or not health_check(self.port):
            return ChatCompletionResponse(
                choices=[
                    {
                        "index": 0,
                        "message": {
                            "role": "asssitant",
                            "content": "no response from server.",
                        },
                    }
                ]
            )

        # list[ChatAgentMessage]のメッセージ入力を辞書型に変換
        llm_messages = build_messages_for_llm(messages)

        # 推論パラメータの構築
        dict_params = build_parameters_for_llm(params)

        # Chat Completionの実行
        response = self.client.chat.completions.create(
            model=self.model_name,
            messages=llm_messages,
            **dict_params,
        )

        # Reasoning ContentをCustom Outputとして保持
        dict_resp = response.to_dict()
        message = dict_resp["choices"][0]["message"]
        if (
            not "content" in message or message["content"] is None
        ):  # contentを必ず含めるようにする
            message["content"] = ""
        if "reasoning_content" in message:
            dict_resp["custom_outputs"] = {
                "reasoning_content": message["reasoning_content"]
            }

        # 結果の返却
        return ChatCompletionResponse.from_dict(dict_resp)

    def predict_stream(
        self, messages: list[ChatMessage], params: ChatParams = None
    ) -> Generator[ChatCompletionChunk, None, None]:
        """
        メッセージに基づいてストリーム予測を行います。

        Args:
            messages (list[ChatMessage]): チャットメッセージのリスト。
            params (ChatParams, optional): 予測パラメータ。

        Yields:
            ChatCompletionChunk: ストリーム予測結果のチャンク。

        Raises:
            ValueError: サーバープロセスが存在しない場合。
        """

        # SGLangのサーバプロセスが存在しない場合、ダミーメッセージを返す。
        if self.client is None or not health_check(self.port):
            return ChatCompletionResponse(
                choices=[
                    {
                        "index": 0,
                        "message": {
                            "role": "asssitant",
                            "content": "no response from server.",
                        },
                    }
                ]
            )

        # list[ChatAgentMessage]のメッセージ入力を辞書型に変換
        llm_messages = build_messages_for_llm(messages)

        # 推論パラメータの構築
        dict_params = build_parameters_for_llm(params, stream=True)

        # Chat Completionの実行
        response = self.client.chat.completions.create(
            model=self.model_name,
            messages=llm_messages,
            **dict_params,
        )

        # ChunkのStream返却
        for chunk in response:
            dict_chunk = chunk.to_dict()
            delta = dict_chunk["choices"][0]["delta"]
            if (
                not "content" in delta or delta["content"] is None
            ):  # contentを必ず含めるようにする
                delta["content"] = ""
            if "reasoning_content" in delta:
                dict_chunk["custom_outputs"] = {
                    "reasoning_content": delta["reasoning_content"]
                }

            yield ChatCompletionChunk.from_dict(dict_chunk)

    def _shutdown(self):
        """
        サーバープロセスをシャットダウンします。

        このメソッドは、サーバープロセスが存在し、かつauto_shutdownがTrueに設定されている場合に、
        サーバープロセスを終了させます。終了後、server_processをNoneに設定し、auto_shutdownをFalseに設定します。
        """
        if self.server_process and self.auto_shutdown:
            logger.info("shutdown sglang server.")
            terminate_process(self.server_process)
            self.server_process = None
            self.auto_shutdown = False

    def __del__(self):
        self._shutdown()


model = SGLangChatModel()
set_model(model)

Step2. モデルのロギング

Step1で定義したカスタムモデルを使ってMLflowにモデルをロギングします。
モデルの設定関連などでコード量が増えていますが、やっていることはlog_modelでモデルを保管することです。

import mlflow
import os

mlflow.set_registry_uri("databricks-uc")

# 依存関係の設定
extra_pip_requirements = [
    "torch==2.6.0 --index-url https://download.pytorch.org/whl/cu124",
    "https://github.com/flashinfer-ai/flashinfer/releases/download/v0.2.5/flashinfer_python-0.2.5+cu124torch2.6-cp38-abi3-linux_x86_64.whl",
    "sglang[srt]==0.4.6.post5",
    "openai==1.82.1",
    "loguru",
    "psutil",
]
pip_requirements = mlflow.pyfunc.get_default_pip_requirements() + extra_pip_requirements

# 利用する重みファイルの場所
model_path = "/Volumes/training/llm/model_snapshots/models--deepseek-ai--DeepSeek-R1-0528-Qwen3-8B/"

# モデル設定。主にSGLangの起動設定
model_config = {
    "model": os.path.basename(os.path.normpath(model_path)),
    "server": {
        "args": [
            "--mem-fraction-static 0.85",
            "--grammar-backend llguidance",
            "--tool-call-parser qwen25",
            "--max-running-requests 1",
            "--decode-log-interval 1000",
        ],
    },
}

artifacts = {
    "llm-model": model_path,
}

# 入力サンプル
input_example = {
    "messages": [
        {
            "role": "user",
            "content": "What is a good recipe for baking scones that doesn't require a lot of skill?",
        }
    ],
    "max_tokens":100,
    "temperature":0.6,
    "top_p":0.8,
    "presence_penalty":1.5,
    "custom_inputs":{"extra_body":{"chat_template_kwargs": {"enable_thinking": False}}}
}

with mlflow.start_run() as run:
    logged_model = mlflow.pyfunc.log_model(
        artifact_path="model",
        python_model="sglang_chat_model.py",
        artifacts=artifacts,
        model_config=model_config,
        input_example=input_example,        
        pip_requirements=pip_requirements,
        await_registration_for=3000,
    )

このままUnity Catalogにも登録します。

mlflow.set_registry_uri("databricks-uc")

run_id = logged_model.run_id
model_uri = f"runs:/{run_id}/model"
registered_model_name = "training.llm.sglang_qwen3"

uc_registered_model_info = mlflow.register_model(
    model_uri=model_uri,
    name=registered_model_name,
    await_registration_for=3000,
)

Step3. Model Servingへデプロイする

Mosaic AI Agent Frameworkを使ってDatabricks Mosaic AI Model Servingにエージェントとしてデプロイします。

from databricks import agents
import mlflow
from mlflow import MlflowClient

# エンドポイントの名前
endpoint_name = "chatmodel_" + registered_model_name.replace(".", "__")
print(endpoint_name)

# Unity Catalog上モデルの最新バージョン取得
client = MlflowClient()
versions = [
    mv.version for mv in client.search_model_versions(f"name='{registered_model_name}'")
]
print(versions)

agents.deploy(
    registered_model_name,
    versions[0],
    scale_to_zero=False,
    endpoint_name=endpoint_name,
)

このままだとクラスタータイプがCPUとしてデプロイされるため、サービングメニューのUIからクラスタータイプを変更します。

image.png

問題なければ、数十分後にはサービングエンドポイントへのデプロイが完了します。

Step4. Playgroundで試す

では、実際に利用してみましょう。

思考モデルが得意であろう質問を投げかけてみます。

image.png

実際にはもっと回答が続くのですが、内容的にはなかなか興味深いものでした。
若干日本語がおかしいところはありますので、そこは課題かな。

ちなみに、思考内容はこのようになっています。

image.png

最終出力とは変わって、割とカジュアルというか実際的な文言で思考していますね。
また、そこまで思考過程が長くないのも特徴な気がします。


次にコーディングもさせてみましょう。

image.png

HTMLを保存して動かすと以下のようなモックアプリができていました。

image.png

メモリ制約であまり大した確認までできていませんが、それなりのものを作れる能力がありそうです。
モック作る程度なら必要十分そう。

おわりに

DeepSeek-R1-0528-Qwen3-8BをDatabricks上で試してみました。

8Bという大きくないパラメータサイズの割に、性能の高さを感じました。
簡単に試しただけですが、確かにQwen3 8Bより推論能力が高まっている感覚はあります。

ちょっと残念なのは、Qwen3 8B Baseを基にしたモデルのため、ベースモデルではないQwen3で使える/no_thinkが有効ではないこと。Qwen3のそこを再現するモデルではないので、しょうがないですが。

一時期蒸留が急速に話題になりましたが、また再燃して様々なオープンウェイトモデルが出てくると面白くなるなと思います。最近はClaudeやGeminiを使うことが多いのですが、ローカルLLMは多くの可能性があると考えており、特に小パラメータで協力なLLMがこれからも登場することを期待しています。

1
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?