1
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

Databricks Model Servingで非transformersのカスタムモデルを動かす

Posted at

導入

DatabricksはMLFlowに登録したモデル(LLM含む)を本番利用可能な性能で公開できるモデルサービング機能を備えています。
また、LLMにおいては、一部のモデルで「プロビジョニングされたスループットの基盤モデルAPIs」を利用でき、対応モデルであればスループットを最適化した推論を実行できます。
(Tensor-RT LLMを基にしたカスタム推論エンジンを利用していると思われる)

LLMのモデルサービングにおいては、極力「プロビジョニングされたスループットの基盤モデルAPIs」を利用するべきです。(現状、AI Playgroundなどの機能はプロビジョニングされたスループットの基盤モデルのみ対応など、対応の優先度が異なる)

一方、

  • 非対応のモデルを利用したい(Phi-3、CommandR、Qwenなど)
  • 単純なLLMのサービングではなく、いくつかの処理を組み合わせたものをサービングしたい
  • 最新の量子化したモデルを利用したい
    など、ユースケースによっては「プロビジョニングされたスループットの基盤モデルAPIs」を利用できないケースも多々あると思います。

加えて、こういうときにtransformersではない推論エンジン/フレームワーク(vLLMやExLlamaV2、llama.cppなど)を利用したいときもあるでしょう。

というわけで、今回はそんなニッチな要望があるときのカスタムモデルサービングをやってみます。

題材として、「EXL2量子化フォーマットのモデルを高速に推論したい」という内容でやってみます。
おそらく、同種のやり方でvLLMなどの利用もできると思います。(未検証)

今回の処理に使った環境はDatabricks on AWS、DBRは15.1ML、クラスタのインスタンスタイプはg5.xlargeです。

Model Servingのコンテナイメージ作成における挙動をあまり理解していないため、今回の内容はCUDA周りの考慮など足りていないかもしれません。
挙動がおかしい・将来的に動かなくなるといった可能性があることを承知ください。

Step1. カスタムモデルを準備する

試験用のカスタムモデルを準備します。
上記の題材のように、EXL2で量子化されたモデルファイルを利用して(つまりはExLlamaV2を利用して)、チャットモデルを準備します。

必要なパッケージのインストール

まず、パッケージをインストール。
こちらの記事で作成したrequirements.txtを使ってCUDA12.1用pytorch 2.3.0およびこちらこちらから、同じくCUDA12.1に対応したFlash AttensionとExLlamaV2のWheelファイルを事前にダウンロード・Unity Catalog Volumesにアップしておき、そちらを使ってインストールします。
※ Wheelファイルをダウンロードして利用した理由は、後ほど再利用するため。

また、MLFlowを合わせて最新化。

# pytorch 2.3.0 for CUDA 12.1
%pip install -r /Volumes/training/llm/packages/torch/requirements.txt

%pip install /Volumes/training/llm/packages/exllamav2/flash_attn-2.5.8+cu122torch2.3cxx11abiFALSE-cp311-cp311-linux_x86_64.whl
%pip install /Volumes/training/llm/packages/exllamav2/exllamav2-0.0.20+cu121-cp311-cp311-linux_x86_64.whl

%pip install "mlflow-skinny[databricks]>=2.12.2"

dbutils.library.restartPython()

pyfuncカスタムモデルの定義

MLFlowのpyfuncカスタムモデルを作成。
コードは長いので折り畳み。ほとんどがExLlamaV2での推論用処理です。

pyfunc カスタムモデルクラスの定義
from typing import List
import uuid

import mlflow
from mlflow.types.llm import ChatResponse

# Define a custom PythonModel
class ExLlamaV2ChatModel(mlflow.pyfunc.ChatModel):
    def __init__(self, prompt_map, pre_prompt="", post_prompt=""):        
        # チャット用のプロンプトテンプレート設定用
        self.prompt_map = prompt_map
        self.pre_prompt = pre_prompt
        self.post_prompt = post_prompt

    # Apply prompt format
    def format_messages(self, messages):
        prompt = self.pre_prompt
        for mes in messages:
            template = self.prompt_map.get(mes.role)
            if template:
                prompt += template.format(mes.content)

        prompt += self.post_prompt

        return prompt

    def load_context(self, context):
        from exllamav2 import (
            ExLlamaV2,
            ExLlamaV2Config,
            ExLlamaV2Cache_Q4,
            ExLlamaV2Tokenizer,
        )
        from exllamav2.generator import ExLlamaV2BaseGenerator

        model_directory = context.artifacts["llm-model"]

        config = ExLlamaV2Config(model_directory)

        model = ExLlamaV2(config)
        print("Loading model: " + model_directory)

        # Q4 KV-cache
        cache = ExLlamaV2Cache_Q4(
            model,
            lazy=True,
            max_seq_len=8192,
        ) 
        model.load_autosplit(cache)

        tokenizer = ExLlamaV2Tokenizer(config)

        self._model = model
        self._cache = cache
        self._tokenizer = tokenizer

    def predict(
        self,
        context,
        messages: List[mlflow.types.llm.ChatMessage],
        params: mlflow.types.llm.ChatParams,
    ):
        from exllamav2.generator import ExLlamaV2BaseGenerator, ExLlamaV2Sampler

        # プロンプト構築
        prompt = self.format_messages(messages)

        # サンプリングの設定
        settings = ExLlamaV2Sampler.Settings()
        settings.temperature = params.temperature
        settings.top_k = 50
        settings.top_p = 0.8
        settings.token_repetition_penalty = 1.05
        max_tokens = params.max_tokens if params.max_tokens else 512

        generator = ExLlamaV2BaseGenerator(self._model, self._cache, self._tokenizer)
        output = generator.generate_simple(
            prompt,
            settings,
            max_tokens,
            seed=1234,
            add_bos=True,
            completion_only=True,
        )

        prompt_tokens = (
            self._tokenizer.encode(prompt).size()[1] + 1
        )  # add_bos=Trueのため、1を追加カウント
        completion_tokens = self._tokenizer.encode(output).size()[1]
        usage = {
            "prompt_tokens": prompt_tokens,
            "completion_tokens": completion_tokens,
            "total_tokens": prompt_tokens + completion_tokens,
        }

        id = str(uuid.uuid4())
        finish_reason = "stop" if completion_tokens < max_tokens else "length"

        response = {
            "id": id,
            "model": "local model",
            "choices": [
                {
                    "index": 0,
                    "message": {"role": "assistant", "content": output},
                    "finish_reason": finish_reason,
                }
            ],
            "usage": usage,
        }

        return ChatResponse(**response)

    def __getstate__(self):
        # ExLlamaV2モデルはPickle化から除外
        state = self.__dict__.copy()
        del state["_model"]
        del state["_cache"]
        del state["_tokenizer"]
        return state

インスタンス作成とMLflowへの登録

上記で作成したクラスをインスタンス化します。
なお、今回利用するLLMは下記の記事で作成したカラクリ社のkarakuri-lm-8x7b-chat-v0.1(EXL2量子化版)を利用します。

model_path = "/Volumes/training/llm/model_snapshots/models--karakuri-ai--karakuri-lm-8x7b-chat-v0.1-exl2--3.5bpw/"
prompt_map = {
    "system": "[INST]<<SYS>>\\n{}\\n<</SYS>>\\n\\n[/INST]",
    "user": "[INST]{}[/INST]",
    "assistant": "{}</s>",
}

model = ExLlamaV2ChatModel(prompt_map)

次に、MLFlowへ登録する際のConda環境設定を準備。
何をしているかというと、ここで必要な外部依存のパッケージ類をリストアップし、MLflowへの登録時に利用するためです。
ここで必要な依存関係の設定ができないと、モデルサービングでのエンドポイント作成において、コンテナイメージ作成時もしくは推論実行時にエラーがでます。
今回はExLlamaV2を利用するのに必要なものをリストアップ。

import mlflow

extra_pip_requirements = [
    "astunparse==1.6.3",
    "cython==0.29.32",
    "dill==0.3.6",
    "opt-einsum==3.3.0",
    "pynvml==11.5.0",
    "tokenizers==0.15.0",
    "-r /Volumes/training/llm/packages/torch/requirements.txt",
    "/Volumes/training/llm/packages/exllamav2/flash_attn-2.5.8+cu122torch2.3cxx11abiFALSE-cp311-cp311-linux_x86_64.whl",
    "/Volumes/training/llm/packages/exllamav2/exllamav2-0.0.20+cu121-cp311-cp311-linux_x86_64.whl",
]

# conda_env
conda_env = mlflow.pyfunc.get_default_conda_env()
conda_env["dependencies"][-1] = {
    "pip": ["mlflow==2.12.2"]
    + mlflow.pyfunc.get_default_pip_requirements()
    + extra_pip_requirements
}

作成したカスタムモデルクラスのインスタンスおよび依存関係情報を利用して、MLflowへカスタムモデルを登録。

import mlflow
import os

mlflow.set_registry_uri("databricks-uc")

registered_model_name = "training.llm.test_exllamav2_chatmodel"

with mlflow.start_run() as run:
    _ = mlflow.pyfunc.log_model(
        artifact_path="model",
        python_model=model,
        artifacts={
            "llm-model": model_path,
        },
        conda_env=conda_env,
        example_no_conversion=True,
        await_registration_for=1200,  # モデルサイズが大きいので長めの待ち時間にします
        registered_model_name=registered_model_name,  # 登録モデル名 in Unity Catalog
    )

依存ライブラリの追加

今回のポイント。
通常であればここまでMLflowからモデル利用が可能になるのですが、モデルサービングで利用するためには事前にライブラリ関連ファイルをモデルのアーティファクトとして追加しておく必要があります。

from pprint import pprint
from mlflow import MlflowClient
import mlflow.models.utils

# 最新のバージョンを取得
client=MlflowClient()
model_name = registered_model_name
versions = [mv.version for mv in client.search_model_versions(f"name='{model_name}'")]

# add_libraries_to_modelを使って、必要なライブラリファイルをモデルのアーティファクトとして追加
mlflow.models.utils.add_libraries_to_model(f"models:/{registered_model_name}/{versions[0]}")

実行すると新たなバージョンのモデルが作成され、wheelsフォルダに必要なパッケージのwheelファイルが保管されます。

image.png

ここまでで準備完了です。

Step2. カスタムモデルをモデルサービングエンドポイントに登録する

Databricks Model ServingのREST APIを使って、Step1.で準備したモデルに対するエンドポイントを作成します。
(UI上からでももちろん作成可能です)

import requests
import json

# モデルレジストリの最新バージョンを取得しなおし
client=MlflowClient()
model_name = registered_model_name
versions = [mv.version for mv in client.search_model_versions(f"name='{model_name}'")]

# APIアクセス用のルートパスとAPIトークンを取得
API_ROOT = dbutils.notebook.entry_point.getDbutils().notebook().getContext().apiUrl().get() 
API_TOKEN = dbutils.notebook.entry_point.getDbutils().notebook().getContext().apiToken().get()

# REST API
data = {
    "name": "test-exllamav2-endpoint",
    "config":{
        "served_entities": [
        {
            "entity_name": registered_model_name,
            "entity_version": versions[0],
            "workload_type": "GPU_MEDIUM",
            "workload_size": "Small",
            "scale_to_zero_enabled": True
        }]
    },
}

headers = {"Context-Type": "text/json", "Authorization": f"Bearer {API_TOKEN}"}

response = requests.post(url=f"{API_ROOT}/api/2.0/serving-endpoints", json=data, headers=headers)

print(json.dumps(response.json(), indent=4))

問題なければモデルサービングにエンドポイントが作成されます。
※ 私のDatabricks環境下だと、イメージの作成に20~30分、その後のサービングまで20分ぐらいかかりました。

エンドポイントが出来たら後、確認のためにクエリを投げてみます。


data = {
  "messages": [
    {
      "role": "system",
      "content": "あなたは日本語を話すAIアシスタントです。"
    },
    {
      "role": "user",
      "content": "Databricksの特長を箇条書きで述べてください。"
    }
  ],
  "temperature": 1.0,
  "max_tokens": 100,
  "stop": [
    "\n"
  ],
  "n": 1,
  "stream": False
}
headers = {"Context-Type": "text/json", "Authorization": f"Bearer {API_TOKEN}"}

response = requests.post(
    url=f"{API_ROOT}/serving-endpoints/test-exllamav2-endpoint/invocations", json=data, headers=headers
)

print(json.dumps(response.json(), ensure_ascii=False))
出力
{"id": "66ff2b59-0c4d-4e87-9cb1-fefb816b3012", "model": "local model", "choices": [{"index": 0, "message": {"role": "assistant", "content": "* 大規模データセットの処理に適している\n* Apache Spark の利点を活用できる\n* UI が使いやすく、管理が容易\n* 高い拡張性と互換性\n* コラボレーション機能があり、チームでの作業が効率的"}, "finish_reason": "stop"}], "usage": {"prompt_tokens": 72, "completion_tokens": 95, "total_tokens": 167}, "object": "chat.completion", "created": 1715503289}

うまく動いていますね。

まとめ

カスタムモデルをDatabricks Model Servingでデプロイしてみました。
transformersだとパッケージの依存関係などをあまり気にしなくて良さそうなのですが、別のパッケージを利用する場合は上記のような流れになるかと思います。(もう少し洗練できそうですが)

繰り返しになりますが、特にチャット用のLLMをデプロイする分にはなるべく「プロビジョニングされたスループットの基盤モデル」が利用可能ではないか検討したほうがよいです。なるべくマネージドなサービスに合わせましょう。
とはいえ、カスタムモデルが必要なシチュエーションもあると思いますので、適切に使い分けて行ければなと思います。

参考リンク

1
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?