こちらの続きです。
導入
前回の記事でMLFlowのChatModelを活用したダミーモデルを作って見ましたが、今度は実際に推論できるモデルを登録して、ついでにFastAPIを使ってOpenAI API互換のAPIサーバを立ててみます。
MLFlow ChatModelについては↑の記事を確認ください。
構築はDatabricks on AWS上で実施しました。DBRは14.3ML、クラスタタイプはg4dn.xlargeです。
Step1. パッケージインストール
ノートブックを作成し、必要なパッケージをインストール。
%pip install -U -qq "mlflow-skinny[databricks]>=2.11.1" "langchain==0.1.11" "transformers==4.38.2" "accelerate==0.27.2" "exllamav2>=0.0.15" "lm-format-enforcer==0.9.2" "pydantic==2.6.3"
dbutils.library.restartPython()
Step2. カスタムChatModelの作成
MLFlowのカスタムChatModelクラスを定義。
今回はpyfunc_custom_model2.pyというファイルを作成し、そこに以下のクラスを定義しました。
LLMによる推論はExLlamaV2を使っており、以下の記事で作成したLangChain用カスタムクラスChatExllamaV2Model
を流用しています。
from typing import List
import uuid
import mlflow
from mlflow.types.llm import ChatResponse
from langchain_core.messages import SystemMessage, HumanMessage, AIMessage, BaseMessage
from exllamav2_chat import ChatExllamaV2Model
# Define a custom PythonModel
class MyChatModel2(mlflow.pyfunc.ChatModel):
def __init__(self, model_path: str):
self._model_path = model_path
def _convert_chat_messages_to_langchain_messages(
self, messages
) -> list[BaseMessage]:
conv_dict = {
"system": SystemMessage,
"user": HumanMessage,
"assistant": AIMessage,
}
new_messages = []
for mes in messages:
cl = conv_dict.get(mes.role)
if cl:
new_messages.append(cl(content=mes.content, name=mes.name))
else:
new_messages.append(BaseMessage(content=mes.content, name=mes.name))
return new_messages
def load_context(self, context):
"""モデルを初期化 """
self._model = ChatExllamaV2Model.from_model_dir(
self._model_path,
cache_max_seq_len=2048,
system_message_template="[INST] <<SYS>>\n{}\n<</SYS>>\n",
human_message_template="{}[/INST]",
ai_message_template="{}",
temperature=0,
top_p=0.0001,
max_new_tokens=512,
repetition_penalty=1.15,
)
def predict(
self,
context,
messages: List[mlflow.types.llm.ChatMessage],
params: mlflow.types.llm.ChatParams,
):
# System Promptが含まれない場合、デフォルトのシステムプロンプトを設定
include_system_role = len([m for m in messages if m.role == "system"]) > 0
if not include_system_role:
messages = [
mlflow.types.llm.ChatMessage(
role="system",
content="You are a helpful AI assistant.",
name="default system prompt",
)
] + messages
# 推論
lc_messages = self._convert_chat_messages_to_langchain_messages(messages)
tmp_model = ChatExllamaV2Model.from_model(self._model)
tmp_model.temperature = params.temperature if params.temperature else 1.0
tmp_model.max_new_tokens = params.max_tokens if params.max_tokens else 512
result = tmp_model.invoke(lc_messages)
id = str(uuid.uuid4())
usage = { # 今回は割愛
"prompt_tokens": 0,
"completion_tokens": 0,
"total_tokens": 0,
}
response = {
"id": id,
"model": "MyChatModel",
"choices": [
{
"index": 0,
"message": {"role": "assistant", "content": result.content},
"finish_reason": "stop",
}
],
"usage": usage,
}
return ChatResponse(**response)
def __getstate__(self):
# ExLlamaV2モデルはPickle化から除外
state = self.__dict__.copy()
del state['_model']
return state
Step3. モデルの登録
MLFlowのモデルレジストリにモデルを登録します。
LLMは以下の記事でEXL2量子化したELYZA-japanese-Llama-2-13B
を利用しました。
import mlflow
import os
mlflow.set_registry_uri("databricks-uc")
# 変換済みモデルのパス
model_path = "/Volumes/training/llm/model_snapshots/models--local--ELYZA-japanese-Llama-2-13b-fast-instruct-4.0bpw-h6-exl2/"
code_path = [f"{os.getcwd()}/exllamav2_chat.py"]
extra_pip_requirements = [
"langchain==0.1.11",
"transformers==4.38.2",
"accelerate==0.27.2",
"exllamav2>=0.0.15",
"lm-format-enforcer==0.9.2",
"pydantic==2.6.3",
] # 依存ライブラリ
with mlflow.start_run() as run:
_ = mlflow.pyfunc.log_model(
artifact_path="model",
python_model=MyChatModel2(model_path),
extra_pip_requirements=extra_pip_requirements,
code_path=code_path,
registered_model_name="training.llm.test_chatmodel", # 登録モデル名 in Unity Catalog
)
Step4. モデルのテスト
登録したモデルをロードしてテスト実行してみます。
# セッション初期化
dbutils.library.restartPython()
import mlflow
import os
from pprint import pprint
mlflow.set_registry_uri("databricks-uc")
model_uri = f"models:/training.llm.test_chatmodel/1" #1はバージョン番号
model = mlflow.pyfunc.load_model(model_uri)
messages = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Write me a hello world program in python"},
]
result = model.predict({"messages": messages, "temperature": 0.1})
pprint(result)
{'choices': [{'finish_reason': 'stop',
'index': 0,
'message': {'content': ' Databricksは、Amazon Web Services (AWS) '
'やMicrosoft '
'Azureなどのパブリッククラウドやプライベートクラウドにデータウェアハウスを構築・管理するサービスを提供しています。',
'role': 'assistant'}}],
'created': 1710150392,
'id': '4899a3a3-017f-4617-8215-8ca79a40e6de',
'model': 'MyChatModel',
'object': 'chat.completion',
'usage': {'completion_tokens': 0, 'prompt_tokens': 0, 'total_tokens': 0}}
回答が得られていますね。
Step5. OpenAI API互換のチャットサーバ立ち上げ
せっかくなのでFastAPIを使ってOpenAI API互換のAPIサーバも立ててみます。
Databricksで本番運用する場合は、Model Servingの利用を検討ください。
別にノートブックを作り、APIサーバ起動の処理を作成します。
まずは必要なパッケージをインストール。
%pip install -U -qq "mlflow-skinny[databricks]>=2.11.1" "langchain==0.1.11" "transformers==4.38.2" "accelerate==0.27.2" "exllamav2>=0.0.15" "lm-format-enforcer==0.9.2" "pydantic==2.6.3"
%pip install fastapi uvicorn nest_asyncio
dbutils.library.restartPython()
先ほど登録したモデルをロード。
import mlflow
import os
mlflow.set_registry_uri("databricks-uc")
model_uri = f"models:/training.llm.test_chatmodel/1"
model = mlflow.pyfunc.load_model(model_uri)
/v1/chat/completions
にルーティングするようにFastAPIを定義。
モデルにリクエストボディを渡しているだけのシンプルな内容です。
これらを実行してAPIサーバを起動しておきます。
from fastapi import FastAPI
import nest_asyncio
import uvicorn
app = FastAPI()
@app.post("/v1/chat/completions")
async def chat(body: dict):
return model.predict(body)
def start():
nest_asyncio.apply()
uvicorn.run(app, host="0.0.0.0", port=8000)
if __name__ == "__main__":
start()
Step6. OpenAIクライアントで通信
OpenAIのpythonクライアントからリクエストを実行してみます。
別にノートブックを作成&Step5.のノートブックと同じクラスタをアタッチし、以下を実行。
%pip install -U openai
dbutils.library.restartPython()
import os
# OpenAI Client用の接続先設定
os.environ["OPENAI_BASE_URL"] = "http://localhost:8000/v1/"
os.environ["OPENAI_API_KEY"] = "EMPTY"
from openai import OpenAI
from pprint import pprint
client = OpenAI()
chat_completion = client.chat.completions.create(
messages=[
{
"role": "user",
"content": "What is Apache Spark?",
}
],
model="gpt-3.5-turbo",
)
pprint(chat_completion.dict())
{'choices': [{'finish_reason': 'stop',
'index': 0,
'logprobs': None,
'message': {'content': ' Apache '
'Sparkは、オープンソースのデータ処理フレームワークです。分散型コンピューティングアーキテクチャを採用しており、大規模データのリアルタイムな分析や機械学習を高速に実行することができます。また、RDD '
'(Resilient Distributed Datasets) '
'という概念を用いており、並列計算エンジンとしてHadoopのMapReduceに加え、Streaming、MLlibなどの機能を提供しています。',
'function_call': None,
'role': 'assistant',
'tool_calls': None}}],
'created': 1710151277,
'id': '542f60de-401e-4934-9e26-0e90454228ed',
'model': 'MyChatModel',
'object': 'chat.completion',
'system_fingerprint': None,
'usage': {'completion_tokens': 0, 'prompt_tokens': 0, 'total_tokens': 0}}
OpenAIのクライアントを使って推論回答を取得できました!
まとめ
MLFlow ChatModelを使ってOpenAI APIサーバと一部互換のAPIサーバを手軽に建ててみました。
MLFlowのモデルとして登録さえできれば、かなりシンプルなコードで互換APIサーバをたてることができます。
(エラーハンドリングなどの処理は必要ですが)
OpenAI APIと完全互換というわけではないですが、LangChain OpenAIのサーバとしても利用できるため、使い勝手は良いのではないでしょうか。
Databricks Model Servingとの組み合わせでより本番を見据えた運用もできると思います。
MLFlow、もっと日本でも普及していいと思う。
(私が知らないだけでみんな使ってるのかしら。。。)