はじめに
最近ローカルLLMをAPI経由で飛ばすアプリケーションを開発しているのですが、APIを叩いた際に、出力文字が一括で出力されるからか、待ち時間が長くストレスになっていました。
ChatGPTはもちろんですが、FastChatなどローカルLLMでも同じように文字が順次出力されており、同じことができないかと調査したところ、transformersの場合streamerというストリームミングを実現するオプションの存在を知りました。
この記事はstreamerとWebSocketでLLMの出力を順次出力させる内容になります。
この記事のコードは以下リポジトリに置いてあります。
WSLLMComm
動作確認環境
OS:Windows 11(WSL+Docker)
GPU:RTX 2070 Super(VRAM:8GB)
バックエンド言語:Python
Webフレームワーク:FastAPI
フロントエンド:Vite+React+TypeScript
技術概要
WebSocketとは
WebSocket(ウェブソケット)は、単一のTCPコネクション上に双方向通信のチャンネルを提供する、コンピュータの通信プロトコルの1つである。
WebSocketはHTTPとは異なるプロトコルである。ともにOSI参照モデルのレイヤー7に位置し、レイヤー4のTCPに依存している。両者は異なるプロトコルであるが、RFC 6455では、WebSocketは「HTTPプロキシと仲介者をサポートするために、HTTPの443番および80番ポート上で動作するように設計されている」と述べられているように、HTTPプロトコルと互換性がある。互換性を実現するために、WebSocketのハンドシェイクはHTTP/1.1 Upgradeヘッダーを使用し、HTTPプロトコルをWebSocketプロトコルに変更するように実現されている。
つまり、Socket通信のようなことをHTTPポートで行えるプロトコルということです。
transformersのstreamerについて
transformersの出力をストリームミングでテキスト出力してくれるものになります。使用方法はtokenizerを引数として入れ、generateの引数に渡します。
from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer
tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3-mini-128k-instruct")
model = AutoModelForCausalLM.from_pretrained(
"microsoft/Phi-3-mini-128k-instruct",
device_map="cuda",
torch_dtype="auto",
trust_remote_code=True,
)
inputs = tokenizer("Can you provide ways to eat combinations of bananas and dragonfruits?", return_tensors="pt")
streamer = TextStreamer(tokenizer)
_ = model.generate(**inputs, streamer=streamer, max_new_tokens=20)
for text in streamer:
print(text, end="")
実装コード
バックエンド
FastAPIの実装になります。
まずmainとrouterから
from api.router import llm
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
app = FastAPI()
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
app.include_router(llm.router)
from api.services.llm import LlmService
from fastapi import APIRouter, HTTPException, WebSocket
router = APIRouter()
@router.on_event("startup")
async def startup_event():
global service
service = LlmService()
@router.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
global service
await websocket.accept()
while True:
data = await websocket.receive_text()
print("WebSocket Data: ", data)
await service.get_inference_result_as_ws(data, websocket)
print("WebSocket Success!")
mainでrouterを呼んでいます。
routerの特徴は以下の通りです。
- ロードしたモデルを保持するため、startupでサービスを初期化してグローバル変数にしています
-
@rouer.websocket("/ws")
でWebSocketを定義しています
次はserviceです。
import asyncio
from threading import Thread
import torch
from fastapi import WebSocket
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
class LlmService:
def __init__(self):
self.tokenizer = None
self.model = None
def get_inference_result(self, str: str) -> TextIteratorStreamer:
if self.model is None:
self.set_llm()
chat = [
{"role": "user", "content": str},
]
prompt = self.tokenizer.apply_chat_template(
chat, tokenize=False, add_generation_prompt=True
)
token_ids = self.tokenizer.encode(
prompt, add_special_tokens=False, return_tensors="pt"
)
streamer = TextIteratorStreamer(
self.tokenizer, skip_prompt=True, skip_special_tokens=False
)
generation_kwargs = dict(
input_ids=token_ids.to(self.model.device),
do_sample=True,
temperature=0.6,
max_new_tokens=256,
streamer=streamer,
)
with torch.no_grad():
thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
thread.start()
return streamer
def set_llm(self):
self.model = AutoModelForCausalLM.from_pretrained(
"microsoft/Phi-3-mini-128k-instruct",
device_map="cuda",
torch_dtype="auto",
trust_remote_code=True,
load_in_4bit=True,
)
self.tokenizer = AutoTokenizer.from_pretrained(
"microsoft/Phi-3-mini-128k-instruct"
)
async def get_inference_result_as_ws(self, input: str, ws: WebSocket) -> None:
streamer = self.get_inference_result(input)
for text in streamer:
await ws.send_text(text)
print(text, end="")
await asyncio.sleep(0)
await ws.send_text("<|ENDTEXT|>")
serviceではモデルのロードおよび文字の生成を行っています。
今回はmicrosoft/Phi-3-mini-128k-instruct
を使用させていただきました。
-
get_inference_result
で適切なパラメータをセットし、generateをしたstreamerを返しています -
get_inference_result_as_ws
でWebSocketへSendを行っています- asyncio.sleep(0)を入れないとSendパラメータ送信タイミングが処理後になるので注意
- 最後に終わりを検知するテキスト
<|ENDTEXT|>
を送信します
フロントエンド
バックエンドで定義したWebSocketを受け取るフロントエンドの実装です。
今回はReact+TypeScriptの環境想定です。
import { SelectChangeEvent } from "@mui/material";
import { useEffect, useRef, useState } from "react";
import { Layout } from "../../components/layout/container";
import { useAPI } from "../../hooks/useAPI";
import { useWebSocket } from "../../hooks/useWebSocket";
import { LlmTextDto } from "../../types/LlmTextDto";
import { IndexPagePresenter } from "./presenter";
type LlmTextDto = {
speaker: string;
text: string;
};
export const IndexPage = () => {
const [ws, setWs] = useState<WebSocket | null>(null);
const [llmText, setLlmText] = useState<string>("");
const [responseText, setResponseText] = useState<string>("");
const responseTextRef = useRef<string>("");
const [llmTextList, setLlmList] = useState<LlmTextDto[]>([]);
const handleGetLlmText = async () => {
setLlmText("");
setIsSendButtonDisabled(true);
setLlmList((llmTextList) => [
...llmTextList,
{ speaker: "user", text: llmText },
]);
if (ws && ws.readyState === WebSocket.OPEN) {
ws?.send(llmText);
}
};
useEffect(() => {
const ws = new WebSocket("ws://localhost:8000/ws");
ws.onopen = () => {
console.log("connected");
};
ws.onclose = () => {
console.log("disconnected");
};
ws.onmessage = (event) => {
const message = event.data;
if (message === "<|ENDTEXT|>") {
const text = responseTextRef.current;
setLlmList((llmTextList) => [
...llmTextList,
{ speaker: "bot", text: text },
]);
responseTextRef.current = "";
setResponseText("");
return;
}
responseTextRef.current += message;
setResponseText(responseTextRef.current);
};
setWs(ws);
return () => {
ws.close();
};
}, [getWs]);
};
上記プログラムのコードは以下の通りです。
-
llmText
はユーザー入力のテキスト、responseText
は受信したテキスト、llmTextList
は受信、送信両方のテキストをListとして持つ想定です - useEffectでWebSocketの定義を行っています
- メッセージ受信時にResponseTextへ受信したテキストを追加しています
-
<|ENDTEXT|>
が来た際にllmTextList
へテキストを入れ、受信したテキストの初期化をおこなっております
-
handleGetLlmText
はボタンなどが押された際にテキストを送信するメソッドです
動作比較
通常の生成速度と比較をしてみます。
通常
WebSocket
どちらもほぼ同じ時間ですが、WebSocketは待ち時間が少なく、ストレスを感じにくいですね!
おわりに
今回の記事をまとめると以下の通りです。
- WebSocketとstreamerを使うことでLLMの生成物を順次出力することができる
- 順次出力のほうが、待ち時間が少なくユーザーにストレスを感じさせない
この記事が、ローカルLLM生成時間の長さで悩んでいる方の助けになれば幸いです。