LoginSignup
14
10

FastAPIを使ってストリーム対応のLLM API RESTサーバを作る on Databricks

Last updated at Posted at 2023-08-20

Databricksを使って動作確認しています。一部、Databricks固有の処理が含まれます。

導入

下記の前回記事では、langchain + CTranslate2を使ってストリーミングする処理を実装しました。

LLMの多くはGPU上の演算だったり多くのメモリを必要とするため、ローカル環境で動作させるよりバックエンドのAPIサーバとして構築し、そこにフロントエンドからつなぎに行く構成が多いと思います。

Huggingface TGIvLLMLlama.cppはAPIサーバ機能を備えてる、もしくは拡張機能が出ていたりするのですが、CTranslate2は標準でAPIサーバ機能を持っていません。(たぶん。Translation用のAPI サーバはあるぽいんですが)

無いものは作ればいいということで、前回のクラスを流用して、ストリーミングに対応した指示応答サーバをFastAPIを使って簡易実装してみます。

あくまで簡易実装です。
エラー処理など非常に脆弱なので参考にする場合は気を付けてください。

FastAPIとは

Databricksで開発用サーバを立てる際にはFlaskのサンプルをよく見ますが、個人的にはこっちの方が好きです。

FlaskにもStreaming Contentという機能があり、こっちでもいけそうなんですが、今回はFastAPIのStreamingResponseを使ってストリーミングに対応します。

今回の実装にあたっては、↓の内容を参考にしています。というかここを読んだらこの先の内容を読まなくてもいいのではないか。。。

準備

今回はAPI Serverを起動するノートブックと、動作確認用のクライアントを動作する二つのノートブックを使います。
モデルはline-corporation/japanese-large-lm-3.6b-instruction-sftをCTranslate2で変換したモデルを使ってみます。
DBRは13.1ML、CPUのみのクラスタで動作確認しています。

いつも通りに必要なモジュールをDatabricksのサーバを動作させる方のノートブック上でインストールします。

%pip install -U -qq sentencepiece ctranslate2 langchain transformers accelerate fastapi uvicorn nest_asyncio

dbutils.library.restartPython()

fastapi、uvicorn、そしてnest_asyncioが追加になったモジュールです。
最初の二つはfastapiを動作させるために必要であり、最後のモジュールは後ほど説明します。

Generator/Tokenizerのロード

いつものように読み込みます。

import ctranslate2
import transformers
import torch

model_path = "変換済みモデルのパス/line-corporation/japanese-large-lm-3.6b-instruction-sft"

# ジェネレーターとトークナイザーの準備
device = "cuda" if torch.cuda.is_available() else "cpu"
generator = ctranslate2.Generator(model_path, device=device)
tokenizer = transformers.AutoTokenizer.from_pretrained(model_path, use_fast=False)

Langchain用のカスタムLLMクラスをインポート

前回作ったクラスを流用します。

class CTranslate2StreamLLM(LLM):

    generator: ctranslate2.Generator
    tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast]

### 省略

APIサーバを定義

必要なモジュールをインポート+FastAPIで利用するクラスInstructionRequestを定義します。
ついでに、config変数に各種設定値を保持させます。
実際には、temperatureなどはリクエストのパラメータにしたほうがいいですね。

from typing import Any, List, Union, Mapping, Optional, AsyncIterable, Awaitable

from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.callbacks import AsyncIteratorCallbackHandler
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate

import asyncio
import os
import uvicorn
from fastapi import FastAPI
from fastapi.responses import StreamingResponse
from pydantic import BaseModel

# FastAPIのリクエスト内で使うクラス
class InstructionRequest(BaseModel):
    instruction: str

# 処理中に使う各種設定値
config = {}
config["template"] = "ユーザー: {instruction}\nシステム: "
config["max_tokens"] = 256
config["temperature"] = 0.3

ストリーム処理する関数generateを作成。
これをFastAPIのルーティング処理の中で実行します。

async def generate(
    instruction: str,
    verbose: bool = False,
) -> AsyncIterable[str]:
    
    # langchainのコールバックハンドラーを作成
    callback = AsyncIteratorCallbackHandler()

    prompt_tmp = PromptTemplate(
        input_variables=["instruction"], template=config["template"]
    )

    llm = CTranslate2StreamLLM(
        generator=generator,
        tokenizer=tokenizer,
        verbose=False,
        callbacks=[callback],
        max_length=config["max_tokens"],
        temperature=config["temperature"],
    )

    llm_chain = LLMChain(
        llm=llm,
        prompt=prompt_tmp,
        verbose=verbose,
    )

    async def wrap_done(fn: Awaitable, event: asyncio.Event):
        """渡された関数をawaitし、doneか例外発生で終了する"""
        try:
            await fn
        except Exception as e:
            # TODO: handle exception
            print(f"Caught exception: {e}")
        finally:
            event.set()

    # バックグラウンドでタスクを実行
    task = asyncio.create_task(
        wrap_done(llm_chain.apredict(instruction=instruction), callback.done),
    )

    async for token in callback.aiter():
        word = f"{token}"
        # </s>は表示しない
        word = word.replace("</s>", "")
        # server-sent-eventsを使ってresponseへストリーミング
        yield word

    # タスク完了まで待って終了
    await task

流れ的には、関数内部でLLMChainを作成し、callbackから得られるtokenをyieldで出力しているだけです。
llmに渡すコールバックとして、AsyncIteratorCallbackHandlerを指定しています。
これでトークンのenqueue/dequeueを実現しています。

最後にFastAPIのルーティングを定義。

app = FastAPI()

@app.post("/instruct")
async def instruct(body: InstructionRequest):

    return StreamingResponse(
        generate(
            instruction=body.instruction,
            verbose=False,
        ),
        media_type="text/event-stream",
    )

StreamingResponseとして結果を返却しているのがポイント。
これでストリーム対応のレスポンスを返します。
Flaskの記述とは若干違いがありますが、割と似てると思います。

FastAPIサーバの起動

import nest_asyncio

def start():
    nest_asyncio.apply()
    uvicorn.run(app, host="0.0.0.0", port=8000)

if __name__ == "__main__":
    start()

これを実行するとFastAPIのサーバが起動します。

補足ですが、ノートブック内で単純にuvicorn.run(app, host="0.0.0.0", port=8000)を実行するとエラーが発生します。
これは、ノートブック上では既に別のevent loopが実行されているためであり、同一スレッド上でuvicornのevent loopが実行できないため。

stackoverflowにもQ/Aが記載されており、今回はnest_asyncioを使う方法を参考にしました。

※ asyncio周りは、正直ちゃんと理解できている自信がない。。。

クライアントからの接続

無事サーバが起動したら、クライアントから繋いでみます。

別のノートブックを作成し、同じクラスタにつないだ上で以下を実行。

import requests
import json

def get_stream(data):

    headers = {"Accept": "text/event-stream"}
    output_text = []

    s = requests.Session()
    with s.post("http://127.0.0.1:8000/instruct", headers=headers, json=data, stream=True) as resp:
        for line in resp.iter_content(decode_unicode=True):
            if line:
                print(line, end="", flush=True)
                output_text.append(line)
        return "".join(output_text)

data = {"instruction": "織田信長は何をした人?"}
get_stream(data)
結果
織田信長は、1534年(天正元年)に尾張国(現在の愛知県西部)の戦国大名・織田信秀として生まれました。その後、信秀が早世すると、信長は跡を継いで織田家の家督となる。信長は、父・信秀から与えられた「清州城」を拠点に、尾張国の統一と拡大を進めました。また、美濃国(現在の岐阜県南部)や近江国(現在の滋賀県東部)への進出も進めた。信長は、強力な軍事力と経済力を持つようになり、その支配力は急速に拡大していきました。そして、1560年に桶狭間の戦いで今川義元を破り、ついに独立王国を築き上げます。しかし、この勝利の後、信長は徐々に勢力を拡大し、やがて天下人と呼ばれるようになる。信長は、革新的な政治手法と戦術によって、多くの敵を打ち破った。彼は、商業都市である堺や、貿易港であった堺など、経済的基盤を確立した。また、鉄砲などの新しい武器を積極的に取り入れ、戦闘における優位性を確保しました。このように、織田信長は、革新的な戦略と戦術を駆使して、日本の歴史を大きく変えた人物です。

実際には標準出力にストリーミングで表示されます。

requestsでpostをする際にstream=Trueを指定することでresponseをストリーミングで受け取ることが出来ます。

今回は同一クラスタで実行したため、URLを127.0.0.1固定で実行しましたが、driver-proxy-apiを経由して異なるホストからのアクセスもできます。(この場合、リクエストヘッダにBearerトークンの指定が必要)。

driver-proxyについては、以下を参照ください。

上記の記事でも触れられていますが、driver-proxyの利用はデモや開発用途に限定するのが良さそうです。

まとめ

ストリーミング対応のAPIサーバをFastAPIを使って実装しました。
今回は単純な指示に対して回答を返すだけで、対話を繰り返すような仕組みにはなっていません。
対話履歴をAPI側で受け取って推論に生かすような形にすると、ChatGPT相当の機能を持ったアプリを作れるかなと思います。

Streamlitで簡単なUIを作るとこんな感じです。
(対話できるような対応もしてみました)

チャット.gif

GPUクラスタ(g4dn.xlarge)で動作させていますが、かなり高速にストリーミング出力されます。

さて、DatabricksにはLLMも含めて本番環境にも利用できるModel Servingの機能があります。
今回のようなAPIサーバはデモ用途にはいいのですが、本番運用を考えるとやはりマネージドなサービス上で動作させるほうがよいです。
一方、DatabricksのModel Servingで動かすためにはモデルレジストリに登録する必要がある=MLFlow上でモデル管理できる必要があります。

次回は、CTranslate2で変換したモデルをMLFlow上でロギングしたいと思います。

14
10
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
14
10