0
0

ローカルLLMの生成文字をWebSocket経由で順次出力させる

Posted at

はじめに

最近ローカルLLMをAPI経由で飛ばすアプリケーションを開発しているのですが、APIを叩いた際に、出力文字が一括で出力されるからか、待ち時間が長くストレスになっていました。
ChatGPTはもちろんですが、FastChatなどローカルLLMでも同じように文字が順次出力されており、同じことができないかと調査したところ、transformersの場合streamerというストリームミングを実現するオプションの存在を知りました。
この記事はstreamerとWebSocketでLLMの出力を順次出力させる内容になります。

この記事のコードは以下リポジトリに置いてあります。
WSLLMComm

動作確認環境

OS:Windows 11(WSL+Docker)
GPU:RTX 2070 Super(VRAM:8GB)
バックエンド言語:Python
Webフレームワーク:FastAPI
フロントエンド:Vite+React+TypeScript

技術概要

WebSocketとは

Wikipediaより引用

WebSocket(ウェブソケット)は、単一のTCPコネクション上に双方向通信のチャンネルを提供する、コンピュータの通信プロトコルの1つである。

WebSocketはHTTPとは異なるプロトコルである。ともにOSI参照モデルのレイヤー7に位置し、レイヤー4のTCPに依存している。両者は異なるプロトコルであるが、RFC 6455では、WebSocketは「HTTPプロキシと仲介者をサポートするために、HTTPの443番および80番ポート上で動作するように設計されている」と述べられているように、HTTPプロトコルと互換性がある。互換性を実現するために、WebSocketのハンドシェイクはHTTP/1.1 Upgradeヘッダーを使用し、HTTPプロトコルをWebSocketプロトコルに変更するように実現されている。

つまり、Socket通信のようなことをHTTPポートで行えるプロトコルということです。

transformersのstreamerについて

transformersの出力をストリームミングでテキスト出力してくれるものになります。使用方法はtokenizerを引数として入れ、generateの引数に渡します。

from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer

tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3-mini-128k-instruct")
model = AutoModelForCausalLM.from_pretrained(
            "microsoft/Phi-3-mini-128k-instruct",
            device_map="cuda",
            torch_dtype="auto",
            trust_remote_code=True,
        )
inputs = tokenizer("Can you provide ways to eat combinations of bananas and dragonfruits?", return_tensors="pt")
streamer = TextStreamer(tokenizer)

_ = model.generate(**inputs, streamer=streamer, max_new_tokens=20)

for text in streamer:
    print(text, end="")

実装コード

バックエンド

FastAPIの実装になります。
まずmainとrouterから

main
from api.router import llm
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware

app = FastAPI()

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

app.include_router(llm.router)
router
from api.services.llm import LlmService
from fastapi import APIRouter, HTTPException, WebSocket

router = APIRouter()


@router.on_event("startup")
async def startup_event():
    global service
    service = LlmService()

@router.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    global service
    await websocket.accept()
    while True:
        data = await websocket.receive_text()
        print("WebSocket Data: ", data)
        await service.get_inference_result_as_ws(data, websocket)
        print("WebSocket Success!")

mainでrouterを呼んでいます。
routerの特徴は以下の通りです。

  • ロードしたモデルを保持するため、startupでサービスを初期化してグローバル変数にしています
  • @rouer.websocket("/ws")でWebSocketを定義しています

次はserviceです。

service
import asyncio
from threading import Thread

import torch
from fastapi import WebSocket
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer


class LlmService:
    def __init__(self):
        self.tokenizer = None
        self.model = None

    def get_inference_result(self, str: str) -> TextIteratorStreamer:
        if self.model is None:
            self.set_llm()
        chat = [
            {"role": "user", "content": str},
        ]
        prompt = self.tokenizer.apply_chat_template(
            chat, tokenize=False, add_generation_prompt=True
        )
        token_ids = self.tokenizer.encode(
            prompt, add_special_tokens=False, return_tensors="pt"
        )
        streamer = TextIteratorStreamer(
            self.tokenizer, skip_prompt=True, skip_special_tokens=False
        )
        generation_kwargs = dict(
            input_ids=token_ids.to(self.model.device),
            do_sample=True,
            temperature=0.6,
            max_new_tokens=256,
            streamer=streamer,
        )
        with torch.no_grad():
            thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
            thread.start()
        return streamer

    def set_llm(self):
        self.model = AutoModelForCausalLM.from_pretrained(
            "microsoft/Phi-3-mini-128k-instruct",
            device_map="cuda",
            torch_dtype="auto",
            trust_remote_code=True,
            load_in_4bit=True,
        )
        self.tokenizer = AutoTokenizer.from_pretrained(
            "microsoft/Phi-3-mini-128k-instruct"
        )

    async def get_inference_result_as_ws(self, input: str, ws: WebSocket) -> None:
        streamer = self.get_inference_result(input)
        for text in streamer:
            await ws.send_text(text)
            print(text, end="")
            await asyncio.sleep(0)
        await ws.send_text("<|ENDTEXT|>")

serviceではモデルのロードおよび文字の生成を行っています。
今回はmicrosoft/Phi-3-mini-128k-instructを使用させていただきました。

  • get_inference_resultで適切なパラメータをセットし、generateをしたstreamerを返しています
  • get_inference_result_as_wsでWebSocketへSendを行っています
    • asyncio.sleep(0)を入れないとSendパラメータ送信タイミングが処理後になるので注意
  • 最後に終わりを検知するテキスト<|ENDTEXT|>を送信します

フロントエンド

バックエンドで定義したWebSocketを受け取るフロントエンドの実装です。
今回はReact+TypeScriptの環境想定です。

container.tsx
import { SelectChangeEvent } from "@mui/material";
import { useEffect, useRef, useState } from "react";
import { Layout } from "../../components/layout/container";
import { useAPI } from "../../hooks/useAPI";
import { useWebSocket } from "../../hooks/useWebSocket";
import { LlmTextDto } from "../../types/LlmTextDto";
import { IndexPagePresenter } from "./presenter";

type LlmTextDto = {
  speaker: string;
  text: string;
};

export const IndexPage = () => {
  const [ws, setWs] = useState<WebSocket | null>(null);
  const [llmText, setLlmText] = useState<string>("");
  const [responseText, setResponseText] = useState<string>("");
  const responseTextRef = useRef<string>("");
  const [llmTextList, setLlmList] = useState<LlmTextDto[]>([]);
  const handleGetLlmText = async () => {
    setLlmText("");
    setIsSendButtonDisabled(true);
    setLlmList((llmTextList) => [
      ...llmTextList,
      { speaker: "user", text: llmText },
    ]);
    if (ws && ws.readyState === WebSocket.OPEN) {
      ws?.send(llmText);
    }
  };
  useEffect(() => {
    const ws = new WebSocket("ws://localhost:8000/ws");
    ws.onopen = () => {
      console.log("connected");
    };
    ws.onclose = () => {
      console.log("disconnected");
    };
    ws.onmessage = (event) => {
      const message = event.data;
      if (message === "<|ENDTEXT|>") {
        const text = responseTextRef.current;
        setLlmList((llmTextList) => [
          ...llmTextList,
          { speaker: "bot", text: text },
        ]);
        responseTextRef.current = "";
        setResponseText("");
        return;
      }
      responseTextRef.current += message;
      setResponseText(responseTextRef.current);
    };
    setWs(ws);
    return () => {
      ws.close();
    };
  }, [getWs]);
};

上記プログラムのコードは以下の通りです。

  • llmTextはユーザー入力のテキスト、responseTextは受信したテキスト、llmTextListは受信、送信両方のテキストをListとして持つ想定です
  • useEffectでWebSocketの定義を行っています
    • メッセージ受信時にResponseTextへ受信したテキストを追加しています
    • <|ENDTEXT|>が来た際にllmTextListへテキストを入れ、受信したテキストの初期化をおこなっております
  • handleGetLlmTextはボタンなどが押された際にテキストを送信するメソッドです

動作比較

通常の生成速度と比較をしてみます。

通常

rest20240602141717.gif
生成時間:25.6秒

WebSocket

websocket20240602142014.gif
生成開始時間:0.77秒
生成終了時間:24.7秒

どちらもほぼ同じ時間ですが、WebSocketは待ち時間が少なく、ストレスを感じにくいですね!

おわりに

今回の記事をまとめると以下の通りです。

  • WebSocketとstreamerを使うことでLLMの生成物を順次出力することができる
  • 順次出力のほうが、待ち時間が少なくユーザーにストレスを感じさせない

この記事が、ローカルLLM生成時間の長さで悩んでいる方の助けになれば幸いです。

参考・出展リンク

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0