以前と似たような記事を書いてる気もしますが、自分の学びとして書いてみます。
導入
前回、MLflowでカスタムモデルを作ってLLMを管理しました。
それを発展させて、MLflowで登録したモデルを実際のアプリ(API+UI)まで繋げてみます。
自分の勉強のためにも、以前の記事から、以下を変更します。
- モデル管理:MLflow上でモデル管理に変更
- APIサーバ:FastAPIからFlaskに変更
- UI:今回は簡易QA画面作成まで記事内に記載
検証はいつも通り全てDatabricks上で行っています。
注意
この記事で説明している方法は開発・実験用途での利用を推奨します。
本番用途は、Databricks標準のサーバレスモデルサービングエンドポイントの利用を検討ください。
早く日本にサービスこないかなあ。
Step1. モデルの登録
前回記事の内容を拡張したMLflowカスタムモデルを作って登録します。
apredict_tokens
メソッドおよびその内部からコールしている処理などが拡張部分。
生成したトークンをストリーミングで出力する処理です。詳しくはFastAPIでストリーミングしたときと同様なのでコード説明は割愛します。
コード(長いので折り畳み)
import mlflow
import pandas as pd
import torch
import ctranslate2
import transformers
from typing import Any, List, Dict, Union, Mapping, Optional, AsyncIterable, Awaitable
from ctranslate2llm import CTranslate2StreamLLM
from langchain.llms.base import LLM
from langchain import PromptTemplate
from langchain.chains import LLMChain
from langchain.callbacks.base import BaseCallbackHandler
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.callbacks import AsyncIteratorCallbackHandler
import asyncio
import os
# Define a custom PythonModel
class QAModel(mlflow.pyfunc.PythonModel):
def load_context(self, context):
"""モデルを初期化する。モデルのパスはartifactsから取得"""
model_path_local = context.artifacts["ct2-model"]
device = "cuda" if torch.cuda.is_available() else "cpu"
generator = ctranslate2.Generator(model_path_local, device=device)
tokenizer = transformers.AutoTokenizer.from_pretrained(
model_path_local, use_fast=False
)
self.generator = generator
self.tokenizer = tokenizer
self.verbose = False
async def _agenerate_tokens(
self,
instruction: str,
llm: CTranslate2StreamLLM,
prompt_template: PromptTemplate,
system_prompt: str = "",
verbose: bool = False,
):
"""生成したトークンを非同期で返す"""
callback = AsyncIteratorCallbackHandler()
llm.callbacks = [callback]
llm_chain = LLMChain(llm=llm, prompt=prompt_template)
async def wrap_done(fn: Awaitable, event: asyncio.Event):
"""Wrap an awaitable with a event to signal when it's done or an exception is raised."""
try:
await fn
except Exception as e:
# TODO: handle exception
print(f"Caught exception: {e}")
finally:
# Signal the aiter to stop.
event.set()
# Begin a task that runs in the background.
task = asyncio.create_task(
wrap_done(llm_chain.apredict(instruction=instruction), callback.done),
)
async for token in callback.aiter():
if token not in llm.tokenizer.all_special_tokens:
yield f"{token}"
await task
async def _agenerate_batch(self, instructions, llm, prompt):
"""
トークンの全生成後に結果を返す。
複数指示をシリアルに実行するため、バッチ効率はよくない。
"""
results = []
for inst in instructions:
sentence_buffer = []
async for t in self._agenerate_tokens(
inst,
llm=llm,
prompt_template=prompt,
):
sentence_buffer.append(t)
results.append("".join(sentence_buffer))
return results
def predict(
self,
context,
model_input: pd.DataFrame,
params: Optional[Dict[str, Any]] = None,
) -> pd.DataFrame:
template = params["template"]
temperature = params["temperature"]
max_length = params["max_length"]
verbose = params["verbose"]
instructions = model_input[["instruction"]].to_dict(orient="records")
prompt = PromptTemplate(input_variables=["instruction"], template=template)
llm = CTranslate2StreamLLM(
generator=self.generator,
tokenizer=self.tokenizer,
temperature=temperature,
max_length=max_length,
verbose=verbose,
)
loop = asyncio.get_event_loop()
results = loop.run_until_complete(
self._agenerate_batch(instructions, llm, prompt)
)
return pd.DataFrame(
{
"instruction": [q["instruction"] for q in instructions],
"answer": [r for r in results],
}
)
async def apredict_tokens(
self,
instruction: str,
template:str,
temperature:float=0.5,
max_length:int=256,
verbose:bool=False,
):
""" 非同期に指示の結果を推論し、トークン単位で結果を出力する """
prompt = PromptTemplate(input_variables=["instruction"], template=template)
llm = CTranslate2StreamLLM(
generator=self.generator,
tokenizer=self.tokenizer,
temperature=temperature,
max_length=max_length,
verbose=verbose,
)
async for t in self._agenerate_tokens(
instruction,
llm=llm,
prompt_template=prompt,
):
yield t
このカスタムクラスをMLflowに登録します。
今回はtraining.llm.ct2api
という名前で登録しました。また、エイリアスをstaging
で設定します。
MLflowの登録処理についても前回記事とほとんど同じなので割愛。
Step2. モデルの読込
MLflowに登録したモデルを読み込みます。エイリアス管理は便利。
# 最新モデルバージョンの特定
model_name = "training.llm.ct2api"
model_uri = f"models:/{model_name}@staging"
# mlflowからモデルを取得
model = mlflow.pyfunc.load_model(model_uri)
Step3. FlaskでAPIサーバを起動
こちらを参考にさせていただきました。いつもありがとうございます。
MLflowカスタムモデルをasync defで作っていたため、一部同期処理へ変換するなど冗長になっています。
大事なのはreturn Response(stream_with_context(iter), mimetype="text/event-stream")
の部分ですね。これでストリーミングのレスポンスを返すことが出来ます。
今回はAPIのルートを/chat
に設定しています。
from flask import Flask, jsonify, request, Response, stream_with_context
import time
app = Flask("llm-streaming")
app.config["JSON_AS_ASCII"] = False
api_port = "8000"
def iter_over_async(ait, loop):
ait = ait.__aiter__()
async def get_next():
try:
obj = await ait.__anext__()
return False, obj
except StopAsyncIteration:
return True, None
while True:
done, obj = loop.run_until_complete(get_next())
if done:
break
yield obj
@app.route("/chat", methods=["POST"])
def stream_chat():
"""
レスポンスをストリーミングで返す
"""
prompt = request.get_json(force=True).get("prompt", "")
# MLflowの登録Signatureから、parameterのデフォルト値を取得
params = dict(
[
(d["name"], d["default"])
for d in model.metadata.get_params_schema().to_dict()
]
)
# asyncのイテレーションを同期処理に変換
loop = asyncio.get_event_loop()
iter = iter_over_async(
model.unwrap_python_model().apredict_tokens(prompt, **params), loop
)
response = Response(stream_with_context(iter), mimetype="text/event-stream")
# Driver Proxy(nginx?)でSSE時にバッファしない設定
# https://qiita.com/willow-micro/items/5b245076101460d9dfd6
response.headers["Cache-Control"] = "no-cache"
response.headers["X-Accel-Buffering"] = "no"
print(response.headers)
return response
app.run(host="0.0.0.0", port=api_port)
Step4. StreamlitでQAのUIを作る
Databricks上でStreamlitを動かす方法は下記を参照ください。
Streamlitの処理を記載するui.py
ファイルを作成し、以下のコードを入力します。
(コメント少なくてごめんなさい)
"""
単純なQAアプリUI
"""
import os
import sys
import streamlit as st
import argparse
from typing import Any, Callable, Dict, List, Optional
import json
import requests
st.set_page_config(
page_title="Simple QA",
page_icon="📝",
layout="centered",
initial_sidebar_state="auto",
)
def parse_args(args):
"""今回はアクセストークンとAPIのURLをStreamlitの起動パラメータで渡す"""
parser = argparse.ArgumentParser("LLM")
parser.add_argument("--token", help="Access Token", required=True)
parser.add_argument("--api", help="API Server URL", required=True)
return parser.parse_args(args)
def call_api(url, api_token, prompt, placeholder):
"""APIサーバにリクエストを発行、ストリーミングで結果を取得しplaceholderに出力"""
data = {"prompt": prompt}
headers = {"Authorization": f"Bearer {api_token}", "Accept": "text/event-stream"}
full_response = ""
s = requests.Session()
with s.post(url, headers=headers, json=data, stream=True) as resp:
for token in resp.iter_content(decode_unicode=True):
if token:
full_response += token
placeholder.markdown(full_response + "▌")
if not resp.ok:
return None, resp.status_code
return full_response, None
if "messages" not in st.session_state:
st.session_state["messages"] = [{"role": "assistant", "message": "質問してください!"}]
st.title("Simple QA")
args = parse_args(sys.argv[1:])
for msg in st.session_state.messages:
st.chat_message(msg["role"]).write(msg["message"])
if prompt := st.chat_input("Send a message."):
st.chat_message("user").write(prompt)
with st.chat_message("assistant"):
message_placeholder = st.empty()
with st.spinner("typing..."):
history = st.session_state.messages
full_response, err = call_api(
args.api + "chat", args.token, prompt, message_placeholder
)
message_placeholder.markdown(full_response)
st.session_state.messages.append({"role": "user", "message": prompt})
st.session_state.messages.append({"role": "assistant", "message": full_response})
別でノートブックを作成し、以下のようにStreamlitを起動します。
※ ui.py
ファイルは事前に/tmp/dev/streamlit/
にファイルコピーし、そのパスを指定しています。
※ 起動時に、Streamlitの起動パラメータとしてAPIトークンやAPIサーバのURLを渡しています。
# APIサーバのURLをDriver ProxyのURLで指定する場合
# from dbruntime.databricks_repl_context import get_context
# ctx = get_context()
# api_url = f"https://{ctx.browserHostName}/driver-proxy-api/o/{ctx.workspaceId}/{ctx.clusterId}/8000/"
api_url = "http://127.0.0.1:8000/"
!streamlit run /tmp/dev/streamlit/ui.py --server.port {port} -- --token {api_token} --api {api_url}
注意
APIサーバのURLは、同一クラスタを前提とした127.0.0.1を直接指定しています。
これをdriver-proxy-apiのURL指定にすることもできるのですが、この場合、レスポンスが遅い(すべて終わってからストリームでデータが配信される)。
(FastAPIを使う方式でも確認してみましたが、同じ現象が起きました)
詳しいことはわからないのですが、おそらくdriver-proxy-api内でバッファリングしているせいなのではないかと。。。(driver proxy自体はnginxぽいのですが、バッファを外す設定してもダメでした)
driver-proxy-apiのURLでもストリーミングできたと考えていたのですが、いろいろ混ざって間違って認識していたようです。。。
実行結果はこんな感じ。
いろいろ制限付きではありますが、UIまで含めたストリーム処理の画面まで作成できました。
その他
- チャット履歴を推論に含めてないので、現時点でチャットボットというには不完全。
- RAG対応含めてそのあたりを入れていくとChatGPTぽいものになると思います。
終わりに
以前書いた記事の焼き直しっぽい内容ではありましたが、MLflowでLLMとAPI処理の大部分を管理できるのは個人的にいい学びでした。
ただ、プロダクション運用を考えると今回のようなやり方だといろいろ不十分なので、精進していきます。