5
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

FlaskでLLMのストリームAPIサーバを作ってStreamlitで表示する on Databricks

Last updated at Posted at 2023-09-16

以前と似たような記事を書いてる気もしますが、自分の学びとして書いてみます。

導入

前回、MLflowでカスタムモデルを作ってLLMを管理しました。

それを発展させて、MLflowで登録したモデルを実際のアプリ(API+UI)まで繋げてみます。

自分の勉強のためにも、以前の記事から、以下を変更します。

  • モデル管理:MLflow上でモデル管理に変更
  • APIサーバ:FastAPIからFlaskに変更
  • UI:今回は簡易QA画面作成まで記事内に記載

検証はいつも通り全てDatabricks上で行っています。

注意
この記事で説明している方法は開発・実験用途での利用を推奨します。
本番用途は、Databricks標準のサーバレスモデルサービングエンドポイントの利用を検討ください。

早く日本にサービスこないかなあ。

Step1. モデルの登録

前回記事の内容を拡張したMLflowカスタムモデルを作って登録します。

apredict_tokensメソッドおよびその内部からコールしている処理などが拡張部分。
生成したトークンをストリーミングで出力する処理です。詳しくはFastAPIでストリーミングしたときと同様なのでコード説明は割愛します。

コード(長いので折り畳み)
import mlflow
import pandas as pd
import torch
import ctranslate2
import transformers
from typing import Any, List, Dict, Union, Mapping, Optional, AsyncIterable, Awaitable

from ctranslate2llm import CTranslate2StreamLLM

from langchain.llms.base import LLM
from langchain import PromptTemplate
from langchain.chains import LLMChain
from langchain.callbacks.base import BaseCallbackHandler
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.callbacks import AsyncIteratorCallbackHandler

import asyncio
import os

# Define a custom PythonModel
class QAModel(mlflow.pyfunc.PythonModel):
    def load_context(self, context):
        """モデルを初期化する。モデルのパスはartifactsから取得"""

        model_path_local = context.artifacts["ct2-model"]

        device = "cuda" if torch.cuda.is_available() else "cpu"
        generator = ctranslate2.Generator(model_path_local, device=device)
        tokenizer = transformers.AutoTokenizer.from_pretrained(
            model_path_local, use_fast=False
        )

        self.generator = generator
        self.tokenizer = tokenizer
        self.verbose = False

    async def _agenerate_tokens(
        self,
        instruction: str,
        llm: CTranslate2StreamLLM,
        prompt_template: PromptTemplate,
        system_prompt: str = "",
        verbose: bool = False,
    ):
        """生成したトークンを非同期で返す"""

        callback = AsyncIteratorCallbackHandler()
        llm.callbacks = [callback]
        llm_chain = LLMChain(llm=llm, prompt=prompt_template)

        async def wrap_done(fn: Awaitable, event: asyncio.Event):
            """Wrap an awaitable with a event to signal when it's done or an exception is raised."""
            try:
                await fn
            except Exception as e:
                # TODO: handle exception
                print(f"Caught exception: {e}")
            finally:
                # Signal the aiter to stop.
                event.set()

        # Begin a task that runs in the background.
        task = asyncio.create_task(
            wrap_done(llm_chain.apredict(instruction=instruction), callback.done),
        )

        async for token in callback.aiter():
            if token not in llm.tokenizer.all_special_tokens:
                yield f"{token}"

        await task

    async def _agenerate_batch(self, instructions, llm, prompt):
        """
        トークンの全生成後に結果を返す。
        複数指示をシリアルに実行するため、バッチ効率はよくない。
        """

        results = []
        for inst in instructions:
            sentence_buffer = []
            async for t in self._agenerate_tokens(
                inst,
                llm=llm,
                prompt_template=prompt,
            ):
                sentence_buffer.append(t)
            results.append("".join(sentence_buffer))

        return results

    def predict(
        self,
        context,
        model_input: pd.DataFrame,
        params: Optional[Dict[str, Any]] = None,
    ) -> pd.DataFrame:

        template = params["template"]
        temperature = params["temperature"]
        max_length = params["max_length"]
        verbose = params["verbose"]

        instructions = model_input[["instruction"]].to_dict(orient="records")
        prompt = PromptTemplate(input_variables=["instruction"], template=template)

        llm = CTranslate2StreamLLM(
            generator=self.generator,
            tokenizer=self.tokenizer,
            temperature=temperature,
            max_length=max_length,
            verbose=verbose,
        )

        loop = asyncio.get_event_loop()
        results = loop.run_until_complete(
            self._agenerate_batch(instructions, llm, prompt)
        )

        return pd.DataFrame(
            {
                "instruction": [q["instruction"] for q in instructions],
                "answer": [r for r in results],
            }
        )

    async def apredict_tokens(
        self,
        instruction: str,
        template:str,
        temperature:float=0.5,
        max_length:int=256,
        verbose:bool=False,

    ):
        """ 非同期に指示の結果を推論し、トークン単位で結果を出力する """

        prompt = PromptTemplate(input_variables=["instruction"], template=template)

        llm = CTranslate2StreamLLM(
            generator=self.generator,
            tokenizer=self.tokenizer,
            temperature=temperature,
            max_length=max_length,
            verbose=verbose,
        )

        async for t in self._agenerate_tokens(
            instruction,
            llm=llm,
            prompt_template=prompt,
        ):
            yield t

このカスタムクラスをMLflowに登録します。
今回はtraining.llm.ct2apiという名前で登録しました。また、エイリアスをstagingで設定します。

MLflowの登録処理についても前回記事とほとんど同じなので割愛。

Step2. モデルの読込

MLflowに登録したモデルを読み込みます。エイリアス管理は便利。

# 最新モデルバージョンの特定
model_name = "training.llm.ct2api"
model_uri = f"models:/{model_name}@staging"

# mlflowからモデルを取得
model = mlflow.pyfunc.load_model(model_uri)

Step3. FlaskでAPIサーバを起動

こちらを参考にさせていただきました。いつもありがとうございます。

MLflowカスタムモデルをasync defで作っていたため、一部同期処理へ変換するなど冗長になっています。
大事なのはreturn Response(stream_with_context(iter), mimetype="text/event-stream")の部分ですね。これでストリーミングのレスポンスを返すことが出来ます。

今回はAPIのルートを/chatに設定しています。

from flask import Flask, jsonify, request, Response, stream_with_context
import time

app = Flask("llm-streaming")
app.config["JSON_AS_ASCII"] = False
api_port = "8000"


def iter_over_async(ait, loop):
    ait = ait.__aiter__()

    async def get_next():
        try:
            obj = await ait.__anext__()
            return False, obj
        except StopAsyncIteration:
            return True, None

    while True:
        done, obj = loop.run_until_complete(get_next())
        if done:
            break
        yield obj


@app.route("/chat", methods=["POST"])
def stream_chat():
    """
    レスポンスをストリーミングで返す
    """
    prompt = request.get_json(force=True).get("prompt", "")

    # MLflowの登録Signatureから、parameterのデフォルト値を取得
    params = dict(
        [
            (d["name"], d["default"])
            for d in model.metadata.get_params_schema().to_dict()
        ]
    )

    # asyncのイテレーションを同期処理に変換
    loop = asyncio.get_event_loop()
    iter = iter_over_async(
        model.unwrap_python_model().apredict_tokens(prompt, **params), loop
    )

    response = Response(stream_with_context(iter), mimetype="text/event-stream")

    # Driver Proxy(nginx?)でSSE時にバッファしない設定
    # https://qiita.com/willow-micro/items/5b245076101460d9dfd6
    response.headers["Cache-Control"] = "no-cache"
    response.headers["X-Accel-Buffering"] = "no"

    print(response.headers)

    return response


app.run(host="0.0.0.0", port=api_port)

Step4. StreamlitでQAのUIを作る

Databricks上でStreamlitを動かす方法は下記を参照ください。

Streamlitの処理を記載するui.pyファイルを作成し、以下のコードを入力します。
(コメント少なくてごめんなさい)

ui.py
"""
単純なQAアプリUI
"""

import os
import sys
import streamlit as st
import argparse
from typing import Any, Callable, Dict, List, Optional
import json
import requests

st.set_page_config(
    page_title="Simple QA",
    page_icon="📝",
    layout="centered",
    initial_sidebar_state="auto",
)


def parse_args(args):
    """今回はアクセストークンとAPIのURLをStreamlitの起動パラメータで渡す"""
    parser = argparse.ArgumentParser("LLM")
    parser.add_argument("--token", help="Access Token", required=True)
    parser.add_argument("--api", help="API Server URL", required=True)
    return parser.parse_args(args)


def call_api(url, api_token, prompt, placeholder):
    """APIサーバにリクエストを発行、ストリーミングで結果を取得しplaceholderに出力"""

    data = {"prompt": prompt}
    headers = {"Authorization": f"Bearer {api_token}", "Accept": "text/event-stream"}

    full_response = ""

    s = requests.Session()

    with s.post(url, headers=headers, json=data, stream=True) as resp:
        for token in resp.iter_content(decode_unicode=True):
            if token:
                full_response += token
                placeholder.markdown(full_response + "")
    if not resp.ok:
        return None, resp.status_code
    return full_response, None


if "messages" not in st.session_state:
    st.session_state["messages"] = [{"role": "assistant", "message": "質問してください!"}]
st.title("Simple QA")

args = parse_args(sys.argv[1:])

for msg in st.session_state.messages:
    st.chat_message(msg["role"]).write(msg["message"])

if prompt := st.chat_input("Send a message."):

    st.chat_message("user").write(prompt)

    with st.chat_message("assistant"):
        message_placeholder = st.empty()

        with st.spinner("typing..."):

            history = st.session_state.messages

            full_response, err = call_api(
                args.api + "chat", args.token, prompt, message_placeholder
            )
        message_placeholder.markdown(full_response)
    st.session_state.messages.append({"role": "user", "message": prompt})
    st.session_state.messages.append({"role": "assistant", "message": full_response})

別でノートブックを作成し、以下のようにStreamlitを起動します。

※ ui.pyファイルは事前に/tmp/dev/streamlit/にファイルコピーし、そのパスを指定しています。
※ 起動時に、Streamlitの起動パラメータとしてAPIトークンやAPIサーバのURLを渡しています。

# APIサーバのURLをDriver ProxyのURLで指定する場合
# from dbruntime.databricks_repl_context import get_context
# ctx = get_context()
# api_url = f"https://{ctx.browserHostName}/driver-proxy-api/o/{ctx.workspaceId}/{ctx.clusterId}/8000/"

api_url = "http://127.0.0.1:8000/"

!streamlit run /tmp/dev/streamlit/ui.py --server.port {port} -- --token {api_token} --api {api_url}

注意
APIサーバのURLは、同一クラスタを前提とした127.0.0.1を直接指定しています。

これをdriver-proxy-apiのURL指定にすることもできるのですが、この場合、レスポンスが遅い(すべて終わってからストリームでデータが配信される)。
(FastAPIを使う方式でも確認してみましたが、同じ現象が起きました)

詳しいことはわからないのですが、おそらくdriver-proxy-api内でバッファリングしているせいなのではないかと。。。(driver proxy自体はnginxぽいのですが、バッファを外す設定してもダメでした)

driver-proxy-apiのURLでもストリーミングできたと考えていたのですが、いろいろ混ざって間違って認識していたようです。。。

実行結果はこんな感じ。

チャット.gif

いろいろ制限付きではありますが、UIまで含めたストリーム処理の画面まで作成できました。

その他

  • チャット履歴を推論に含めてないので、現時点でチャットボットというには不完全。
  • RAG対応含めてそのあたりを入れていくとChatGPTぽいものになると思います。

終わりに

以前書いた記事の焼き直しっぽい内容ではありましたが、MLflowでLLMとAPI処理の大部分を管理できるのは個人的にいい学びでした。
ただ、プロダクション運用を考えると今回のようなやり方だといろいろ不十分なので、精進していきます。

5
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
5
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?