速さが足りない!
導入
最近知ったのですが、litestarという比較的新しめのAPIサーバ(Web)フレームワークがあり、FastAPIと同等もしくはそれ以上に高速に動作し、かつ機能もいろいろあるようです。
LLMと組み合わせるAPIサーバ関連もいろいろ試してみたいと思っていたので、今回はExLlamaV2とlitestarを使って、なるべく速い(と思われる)チャットAPIサーバを、Databricks上で構築してみます。
この記事で説明している方法は開発・実験用途での利用を推奨します。
本番用途は、Databricks標準のサーバレスモデルサービングエンドポイントの利用を検討ください。
これが最速、というわけではないと思います。
が、体感結構速い構成ではないかと。。。
DBRは14.1ML、GPUクラスタを利用しました。
今回のステップ
以下の3ステップで実装します。
- MLflowモデルの作成
- MLflowモデルのロギング
- APIサーバの構築
- テスト用の簡易チャットアプリ構築
Step1. MLflowモデルの作成
チャット用のモデルをMLFlowのカスタムモデルとして作成します。
exllamav2_pyfunc_chat_model.py
というファイルを作成し、MLflow用のpyfuncカスタムクラスを定義。
フルのソースコードは以下。
exllamav2_pyfunc_chat_model.py
import mlflow
import pandas as pd
import torch
import transformers
import os
import json
from typing import Any, List, Dict, Union, Mapping, Optional, AsyncIterable, Awaitable
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.schema.output_parser import StrOutputParser
from langchain.schema.runnable import RunnablePassthrough, RunnableLambda
from langchain.schema.runnable.base import RunnableSequence
from langchain.prompts import ChatPromptTemplate
from langchain.prompts.chat import (
ChatPromptTemplate,
SystemMessagePromptTemplate,
AIMessagePromptTemplate,
HumanMessagePromptTemplate,
MessagesPlaceholder,
)
from langchain.memory import ConversationBufferMemory, ConversationBufferWindowMemory
from operator import itemgetter
from exllamav2_chat import ChatExllamaV2Model
# Define a custom PythonModel
class ExllamaV2ChatChainModel(mlflow.pyfunc.PythonModel):
def load_context(self, context):
"""モデルを初期化する。モデルのパスはartifactsから取得"""
model_path_local = context.artifacts["llm-model"]
# LLM/Tokenizerのロード
_model = ChatExllamaV2Model.from_model_dir(model_path_local)
self.exllama_model = _model.exllama_model
self.exllama_config = _model.exllama_config
self.exllama_tokenizer = _model.exllama_tokenizer
self.exllama_cache = _model.exllama_cache
def chat_chain(
self,
history: List[Mapping[str, str]] = None,
memory_window_size: int = 4,
system_prompt: str = "",
params: Optional[Dict[str, Any]] = None,
) -> RunnableSequence:
"""
Generates a chain of responses for a given history of conversation.
Args:
history (List[Mapping[str, str]]): A list of conversation messages represented as a dictionary with "user" and "assistant" keys.
memory_window_size (int): An integer representing the number of previous conversation messages used to generate the current response.
system_prompt (str): An optional string representing a prompt that can be used by the system to start off the conversation.
params (Optional[Dict[str, Any]]): An optional dictionary containing additional parameters to be passed to the model.
Returns:
RunnableSequence: An object representing the entire conversation with chain responses.
Examples:
>>> params = {
"human_message_template": "GPT4 Correct User: {}<|end_of_turn|>",
"ai_message_template": "GPT4 Correct Assistant: {}",
"repetition_penalty": 1.2,
"temperature": 0.1,
"max_new_tokens": 512,
}
>>> history = [
{"user": "Hi there!", "assistant": "Hello!"},
{"user": "How are you doing today?", "assistant": "I'm doing well, thanks for asking. How about you?"},
]
>>> system_prompt = "You are a helpful chatbot."
>>> chain = chatbot.chat_chain(history=history, system_prompt=system_prompt, params=params)
>>> chain.invoke(""I'm doing pretty well too. Do you have any plans for the weekend?")
"""
_history = history or []
_params = params or {}
_doc_line_sep = _params.get("prompt_line_separator", "\n")
# 会話履歴から、Memoryを構成
memory = ConversationBufferWindowMemory(
k=memory_window_size,
memory_key="history",
input_key="question",
return_messages=True,
)
for h in _history:
memory.save_context(
{"question": h["user"]},
{"output": h["assistant"]},
)
response_prompt = ChatPromptTemplate.from_messages(
[
SystemMessagePromptTemplate.from_template(system_prompt),
MessagesPlaceholder(variable_name="history"),
HumanMessagePromptTemplate.from_template("{question}"),
AIMessagePromptTemplate.from_template(""),
]
)
chat_model = ChatExllamaV2Model(
exllama_config=self.exllama_config,
exllama_model=self.exllama_model,
exllama_tokenizer=self.exllama_tokenizer,
exllama_cache=self.exllama_cache,
**_params,
)
chain = (
{
"history": RunnableLambda(memory.load_memory_variables)
| itemgetter("history"),
"question": RunnablePassthrough(),
}
| response_prompt
| chat_model
| StrOutputParser()
)
return chain
def predict(
self,
context,
model_input: List[str],
params: Optional[Dict[str, Any]] = None,
) -> List[str]:
_params = params or {}
system_prompt = _params.get("system_prompt", "")
chain = self.chat_chain(system_prompt=system_prompt, params=_params)
return chain.batch(model_input)
ポイントは下記のchat_chain
メソッドで、この中でlangchainのChainを作成しています。
また、カスタムシステムプロンプトと会話履歴を与えることができるようにしています。
ChatExllamaV2Model
はこちらの記事で作成したExLlama V2をlangchainのチャットモデルとして利用できるようにするためのクラスです。
こちらを利用して、LCELを使ってChainを定義します。
def chat_chain(
self,
history: List[Mapping[str, str]] = None,
memory_window_size: int = 4,
system_prompt: str = "",
params: Optional[Dict[str, Any]] = None,
) -> RunnableSequence:
_history = history or []
_params = params or {}
_doc_line_sep = _params.get("prompt_line_separator", "\n")
# 会話履歴から、Memoryを構成
memory = ConversationBufferWindowMemory(
k=memory_window_size,
memory_key="history",
input_key="question",
return_messages=True,
)
for h in _history:
memory.save_context(
{"question": h["user"]},
{"output": h["assistant"]},
)
response_prompt = ChatPromptTemplate.from_messages(
[
SystemMessagePromptTemplate.from_template(system_prompt),
MessagesPlaceholder(variable_name="history"),
HumanMessagePromptTemplate.from_template("{question}"),
AIMessagePromptTemplate.from_template(""),
]
)
chat_model = ChatExllamaV2Model(
exllama_config=self.exllama_config,
exllama_model=self.exllama_model,
exllama_tokenizer=self.exllama_tokenizer,
exllama_cache=self.exllama_cache,
**_params,
)
chain = (
{
"history": RunnableLambda(memory.load_memory_variables)
| itemgetter("history"),
"question": RunnablePassthrough(),
}
| response_prompt
| chat_model
| StrOutputParser()
)
return chain
Step2. MLflowモデルのロギング
新たにノートブックを作成し、LLMをMLflow上に記録します。
今回もOpenChat 3.5のGPTQ量子化モデルを事前にダウンロードしたものを使います。
まずは必要なモジュールをインストール。
%pip install -U -qq "mlflow-skinny[databricks]>=2.8.1" "langchain==0.0.340" "transformers==4.35.2" "accelerate==0.24.1" "exllamav2==0.0.9"
dbutils.library.restartPython()
MLflowへモデルを登録するための準備。保管するモデルの名前や、必要なメタ情報を作成します。
ExllamaV2ChatChainModel
はStep1.で作成したMLFlowカスタムモデルです。
import mlflow
import os
from mlflow.models.signature import infer_signature
from exllamav2_pyfunc_chat_model import ExllamaV2ChatChainModel
# Unity Catalogのモデル機能を利用
mlflow.set_registry_uri("databricks-uc")
# 変換済みモデルのパス
model_path = "/Volumes/training/llm/model_snapshots/models--TheBloke--openchat_3.5-GPTQ"
# 登録モデル情報
model_name = "training.llm.simple_chat_model"
model_alias = "dev"
# モデルが推論で保持するデフォルトのパラメータ
default_params = {
"prompt_line_separator": "\n",
"human_message_template": "GPT4 User: {}<|end_of_turn|>",
"ai_message_template": "GPT4 Assistant: {}",
"system_message_template": "{}",
"temperature": 0.5,
"top_k":50,
"top_p":0.8,
"repetition_penalty": 1.1,
"max_new_tokens": 512,
"verbose": False,
}
sample_input = ["LLMとは何ですか?"]
sample_output = ["LLMとは大規模言語モデルのことです。"]
# mlflow保存用のsignature作成
signature = infer_signature(sample_input, sample_output, default_params)
code_path = [f"{os.getcwd()}/exllamav2_chat.py"]
準備が終わりましたので、MLFlowにモデルをロギングします。
with mlflow.start_run() as run:
_ = mlflow.pyfunc.log_model(
artifact_path="model",
python_model=ExllamaV2ChatChainModel(),
extra_pip_requirements=[
"langchain>=0.0.340",
"exllamav2==0.0.9",
"transformers>=4.35.2",
"accelerate>=0.24.1",
], # 依存ライブラリ
signature=signature,
artifacts={
"llm-model": model_path,
},
code_path=code_path,
await_registration_for=1200, # モデルサイズが大きいので長めの待ち時間にします
input_example=sample_input,
registered_model_name=model_name, # 登録モデル名 in Unity Catalog
)
最後に、登録したモデルのAliasを設定。今回はdevに変更します。
from exllamav2_pyfunc_chat_model import ExllamaV2ChatChainModel
import mlflow
mlflow.set_registry_uri("databricks-uc")
# 最新モデルバージョンの特定
model_name = "training.llm.simple_chat_model"
model_alias = "dev"
model_uri = f"models:/{model_name}@{model_alias}"
# mlflowからモデルを取得
model = mlflow.pyfunc.load_model(model_uri)
こんな感じでカタログエクスプローラ上でも登録されたモデルを確認できます。
Step3. APIサーバの構築
今回の本題であるlitestarを使ってAPIサーバを構築します。
新たにノートブックを作成し、必要なモジュールをインストール。
typing-extensions
の最新化が必要でした。
%pip install -U -qq "mlflow-skinny[databricks]>=2.8.1" "langchain==0.0.340" "transformers==4.35.2" "accelerate==0.24.1" "exllamav2==0.0.9"
%pip install -U typing-extensions
%pip install litestar uvicorn nest_asyncio
dbutils.library.restartPython()
MLflowのモデルレジストリからモデルを読み込み。
from exllamav2_pyfunc_chat_model import ExllamaV2ChatChainModel
import mlflow
mlflow.set_registry_uri("databricks-uc")
# モデルの指定
model_name = "training.llm.simple_chat_model"
model_alias = "dev"
model_uri = f"models:/{model_name}@{model_alias}"
# mlflowからモデルを取得
model = mlflow.pyfunc.load_model(model_uri)
litestarの方式でAPIを定義します。
感覚的にはFastAPIとよく似ています。
また、ストリーム処理も対応しており、ジェネレータをStream
クラスでラップしたレスポンスを返せば実現できます。
また、Server Sent Event(SSE)専用のレスポンスクラスServerSentEvent
もありました。
用途によってはこちらを使う方がいいかもしれません。
最後にuvicornを使ってサーバを起動しています。
from litestar import Litestar, get, post
from litestar.response import Stream, ServerSentEvent
from dataclasses import dataclass, field
from collections.abc import AsyncGenerator
import nest_asyncio
import uvicorn
import logging
# 不要なログ情報が表示されてしまうのを抑制
logging.getLogger("py4j").setLevel(logging.ERROR)
# データクラス群
@dataclass
class ChatHistory:
user: str
assistant: str
@dataclass
class ChatRequest:
user_input: str
history: list[ChatHistory] = field(default_factory=list)
system_prompt: str = ""
@post("/")
async def chat(data: ChatRequest) -> Stream:
# デフォルトパラメータの取得&上書き
params = dict(
[
(d["name"], d["default"])
for d in model.metadata.get_params_schema().to_dict()
]
)
params.update({"temperature": 0.1, "max_new_tokens": 512})
# システムプロンプト
system_prompt = data.system_prompt
# チャット履歴
history = [{"user": h.user, "assistant": h.assistant} for h in data.history]
# Chainの取得
chain = model.unwrap_python_model().chat_chain(
system_prompt=system_prompt,
history=history,
params=params,
)
# ストリーム処理
async def generator() -> AsyncGenerator[bytes, None]:
async for text in chain.astream(data.user_input):
yield text
return Stream(generator)
app = Litestar([chat])
if __name__ == "__main__":
nest_asyncio.apply()
uvicorn.run(app, port=5000, log_level="info")
Step4. テスト用の簡易チャットアプリ構築
API動作を確認するために、Streamlitで簡単なチャットアプリを作ります。
手前みそですが、StreamlitをDatabricksで動作させる方法は以下を参照ください。
アプリ側のコードは以下です。
"""
単純なQAアプリUI
"""
import os
import sys
import streamlit as st
import argparse
from typing import Any, Callable, Dict, List, Optional
import json
import requests
st.set_page_config(
page_title="Simple QA",
page_icon="📝",
layout="centered",
initial_sidebar_state="auto",
)
def parse_args(args):
"""今回はアクセストークンとAPIのURLをStreamlitの起動パラメータで渡す"""
parser = argparse.ArgumentParser("LLM")
parser.add_argument("--token", help="Access Token", required=True)
parser.add_argument("--api", help="API Server URL", required=True)
return parser.parse_args(args)
def call_api(url, api_token, prompt, history, system_prompt, placeholder):
"""APIサーバにリクエストを発行、ストリーミングで結果を取得しplaceholderに出力"""
data = {
"user_input": prompt,
"history": history,
"system_prompt": system_prompt,
}
headers = {"Authorization": f"Bearer {api_token}", "Accept": "text/event-stream"}
full_response = ""
s = requests.Session()
with s.post(url, headers=headers, json=data, stream=True) as resp:
for token in resp.iter_content(decode_unicode=True):
if token:
full_response += token
placeholder.markdown(full_response + "▌")
if not resp.ok:
return None, resp.status_code
return full_response, None
if "messages" not in st.session_state:
st.session_state["messages"] = []
st.title("Simple QA")
reset_button = st.button("New Chat")
if reset_button:
st.session_state["messages"].clear()
args = parse_args(sys.argv[1:])
system_prompt = st.text_input("System Prompt", "You are helpful assistant.")
st.chat_message("assistant").write("何についてチャットしますか?")
for msg in st.session_state.messages:
with st.chat_message("user"):
st.write(msg["user"])
with st.chat_message("assistant"):
st.write(msg["assistant"])
if prompt := st.chat_input("Send a message."):
st.chat_message("user").write(prompt)
with st.chat_message("assistant"):
message_placeholder = st.empty()
with st.spinner("typing..."):
history = st.session_state.messages
full_response, err = call_api(
args.api,
args.token,
prompt,
history,
system_prompt,
message_placeholder,
)
message_placeholder.markdown(full_response)
st.session_state.messages.append({"user": prompt, "assistant": full_response})
というわけで、動かしてみました。
GIF画像なので環境によって再生速度が違うかもしれませんが、体感それなりに高速に出力されました。
まとめ
litestarというかExLlama V2の速さを体感するためにやった感じになりましたが、悪くない構成ではないでしょうか。
MLflowを使うとモデル管理も楽になりますので、これからもいろいろ試してみたいと思います。