1
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

MLFlow ChatModelでOpenAI API互換のChat APIサーバを立ててみる on Databricks

Posted at

こちらの続きです。

導入

前回の記事でMLFlowのChatModelを活用したダミーモデルを作って見ましたが、今度は実際に推論できるモデルを登録して、ついでにFastAPIを使ってOpenAI API互換のAPIサーバを立ててみます。

MLFlow ChatModelについては↑の記事を確認ください。

構築はDatabricks on AWS上で実施しました。DBRは14.3ML、クラスタタイプはg4dn.xlargeです。

Step1. パッケージインストール

ノートブックを作成し、必要なパッケージをインストール。

%pip install -U -qq "mlflow-skinny[databricks]>=2.11.1" "langchain==0.1.11" "transformers==4.38.2" "accelerate==0.27.2" "exllamav2>=0.0.15" "lm-format-enforcer==0.9.2" "pydantic==2.6.3"

dbutils.library.restartPython()

Step2. カスタムChatModelの作成

MLFlowのカスタムChatModelクラスを定義。
今回はpyfunc_custom_model2.pyというファイルを作成し、そこに以下のクラスを定義しました。

LLMによる推論はExLlamaV2を使っており、以下の記事で作成したLangChain用カスタムクラスChatExllamaV2Modelを流用しています。

from typing import List
import uuid

import mlflow
from mlflow.types.llm import ChatResponse

from langchain_core.messages import SystemMessage, HumanMessage, AIMessage, BaseMessage
from exllamav2_chat import ChatExllamaV2Model

# Define a custom PythonModel
class MyChatModel2(mlflow.pyfunc.ChatModel):
    def __init__(self, model_path: str):
        self._model_path = model_path

    def _convert_chat_messages_to_langchain_messages(
        self, messages
    ) -> list[BaseMessage]:

        conv_dict = {
            "system": SystemMessage,
            "user": HumanMessage,
            "assistant": AIMessage,
        }

        new_messages = []
        for mes in messages:
            cl = conv_dict.get(mes.role)
            if cl:
                new_messages.append(cl(content=mes.content, name=mes.name))
            else:
                new_messages.append(BaseMessage(content=mes.content, name=mes.name))

        return new_messages

    def load_context(self, context):
        """モデルを初期化 """

        self._model = ChatExllamaV2Model.from_model_dir(
            self._model_path,
            cache_max_seq_len=2048,
            system_message_template="[INST] <<SYS>>\n{}\n<</SYS>>\n",
            human_message_template="{}[/INST]",
            ai_message_template="{}",
            temperature=0,
            top_p=0.0001,
            max_new_tokens=512,
            repetition_penalty=1.15,
        )

    def predict(
        self,
        context,
        messages: List[mlflow.types.llm.ChatMessage],
        params: mlflow.types.llm.ChatParams,
    ):

        # System Promptが含まれない場合、デフォルトのシステムプロンプトを設定
        include_system_role = len([m for m in messages if m.role == "system"]) > 0
        if not include_system_role:
            messages = [
                mlflow.types.llm.ChatMessage(
                    role="system",
                    content="You are a helpful AI assistant.",
                    name="default system prompt",
                )
            ] + messages

        # 推論
        lc_messages = self._convert_chat_messages_to_langchain_messages(messages)
        tmp_model = ChatExllamaV2Model.from_model(self._model)
        tmp_model.temperature = params.temperature if params.temperature else 1.0
        tmp_model.max_new_tokens = params.max_tokens if params.max_tokens else 512

        result = tmp_model.invoke(lc_messages)

        id = str(uuid.uuid4())
        usage = {  # 今回は割愛
            "prompt_tokens": 0,
            "completion_tokens": 0,
            "total_tokens": 0,
        }

        response = {
            "id": id,
            "model": "MyChatModel",
            "choices": [
                {
                    "index": 0,
                    "message": {"role": "assistant", "content": result.content},
                    "finish_reason": "stop",
                }
            ],
            "usage": usage,
        }

        return ChatResponse(**response)
    
    def __getstate__(self):
        # ExLlamaV2モデルはPickle化から除外
        state = self.__dict__.copy()
        del state['_model']
        return state

Step3. モデルの登録

MLFlowのモデルレジストリにモデルを登録します。
LLMは以下の記事でEXL2量子化したELYZA-japanese-Llama-2-13Bを利用しました。

import mlflow
import os

mlflow.set_registry_uri("databricks-uc")

# 変換済みモデルのパス

model_path = "/Volumes/training/llm/model_snapshots/models--local--ELYZA-japanese-Llama-2-13b-fast-instruct-4.0bpw-h6-exl2/"

code_path = [f"{os.getcwd()}/exllamav2_chat.py"]

extra_pip_requirements = [
    "langchain==0.1.11",
    "transformers==4.38.2",
    "accelerate==0.27.2",
    "exllamav2>=0.0.15",
    "lm-format-enforcer==0.9.2",
    "pydantic==2.6.3",
]  # 依存ライブラリ
with mlflow.start_run() as run:
    _ = mlflow.pyfunc.log_model(
        artifact_path="model",
        python_model=MyChatModel2(model_path),
        extra_pip_requirements=extra_pip_requirements,
        code_path=code_path,
        registered_model_name="training.llm.test_chatmodel",  # 登録モデル名 in Unity Catalog
    )

Step4. モデルのテスト

登録したモデルをロードしてテスト実行してみます。

# セッション初期化
dbutils.library.restartPython()
import mlflow
import os
from pprint import pprint

mlflow.set_registry_uri("databricks-uc")
model_uri = f"models:/training.llm.test_chatmodel/1" #1はバージョン番号
model = mlflow.pyfunc.load_model(model_uri)

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Write me a hello world program in python"},
]

result = model.predict({"messages": messages, "temperature": 0.1})
pprint(result)
出力
{'choices': [{'finish_reason': 'stop',
              'index': 0,
              'message': {'content': ' Databricksは、Amazon Web Services (AWS) '
                                     'やMicrosoft '
                                     'Azureなどのパブリッククラウドやプライベートクラウドにデータウェアハウスを構築・管理するサービスを提供しています。',
                          'role': 'assistant'}}],
 'created': 1710150392,
 'id': '4899a3a3-017f-4617-8215-8ca79a40e6de',
 'model': 'MyChatModel',
 'object': 'chat.completion',
 'usage': {'completion_tokens': 0, 'prompt_tokens': 0, 'total_tokens': 0}}

回答が得られていますね。

Step5. OpenAI API互換のチャットサーバ立ち上げ

せっかくなのでFastAPIを使ってOpenAI API互換のAPIサーバも立ててみます。

Databricksで本番運用する場合は、Model Servingの利用を検討ください。

別にノートブックを作り、APIサーバ起動の処理を作成します。

まずは必要なパッケージをインストール。

%pip install -U -qq "mlflow-skinny[databricks]>=2.11.1" "langchain==0.1.11" "transformers==4.38.2" "accelerate==0.27.2" "exllamav2>=0.0.15" "lm-format-enforcer==0.9.2" "pydantic==2.6.3"
%pip install fastapi uvicorn nest_asyncio

dbutils.library.restartPython()

先ほど登録したモデルをロード。

import mlflow
import os

mlflow.set_registry_uri("databricks-uc")

model_uri = f"models:/training.llm.test_chatmodel/1"
model = mlflow.pyfunc.load_model(model_uri)

/v1/chat/completionsにルーティングするようにFastAPIを定義。
モデルにリクエストボディを渡しているだけのシンプルな内容です。
これらを実行してAPIサーバを起動しておきます。

from fastapi import FastAPI
import nest_asyncio
import uvicorn

app = FastAPI()

@app.post("/v1/chat/completions")
async def chat(body: dict):
    return model.predict(body)

def start():
    nest_asyncio.apply()
    uvicorn.run(app, host="0.0.0.0", port=8000)

if __name__ == "__main__":
    start()

Step6. OpenAIクライアントで通信

OpenAIのpythonクライアントからリクエストを実行してみます。
別にノートブックを作成&Step5.のノートブックと同じクラスタをアタッチし、以下を実行。

%pip install -U openai
dbutils.library.restartPython()
import os

# OpenAI Client用の接続先設定
os.environ["OPENAI_BASE_URL"] = "http://localhost:8000/v1/"
os.environ["OPENAI_API_KEY"] = "EMPTY"
from openai import OpenAI
from pprint import pprint

client = OpenAI()

chat_completion = client.chat.completions.create(
    messages=[
        {
            "role": "user",
            "content": "What is Apache Spark?",
        }
    ],
    model="gpt-3.5-turbo",
)

pprint(chat_completion.dict())
出力
{'choices': [{'finish_reason': 'stop',
              'index': 0,
              'logprobs': None,
              'message': {'content': ' Apache '
                                     'Sparkは、オープンソースのデータ処理フレームワークです。分散型コンピューティングアーキテクチャを採用しており、大規模データのリアルタイムな分析や機械学習を高速に実行することができます。また、RDD '
                                     '(Resilient Distributed Datasets) '
                                     'という概念を用いており、並列計算エンジンとしてHadoopのMapReduceに加え、Streaming、MLlibなどの機能を提供しています。',
                          'function_call': None,
                          'role': 'assistant',
                          'tool_calls': None}}],
 'created': 1710151277,
 'id': '542f60de-401e-4934-9e26-0e90454228ed',
 'model': 'MyChatModel',
 'object': 'chat.completion',
 'system_fingerprint': None,
 'usage': {'completion_tokens': 0, 'prompt_tokens': 0, 'total_tokens': 0}}

OpenAIのクライアントを使って推論回答を取得できました!

まとめ

MLFlow ChatModelを使ってOpenAI APIサーバと一部互換のAPIサーバを手軽に建ててみました。
MLFlowのモデルとして登録さえできれば、かなりシンプルなコードで互換APIサーバをたてることができます。
(エラーハンドリングなどの処理は必要ですが)

OpenAI APIと完全互換というわけではないですが、LangChain OpenAIのサーバとしても利用できるため、使い勝手は良いのではないでしょうか。
Databricks Model Servingとの組み合わせでより本番を見据えた運用もできると思います。

MLFlow、もっと日本でも普及していいと思う。
(私が知らないだけでみんな使ってるのかしら。。。)

1
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?