導入
こちらのさらに続き(?)です。
SGLangの動向を日々見守っているのですが、以下のDocを見るとflashinferというカーネルライブラリに対応しているようです。
さらに速い推論が試せるかと思い、SGLang上でflashinferを使った推論を試してみます。
検証はDatabricks on AWS上で実施しました。
DBRは14.3ML LTS、クラスタタイプはg5.xlarge(A10G)のGPUクラスタです。
Flashinferとは?
READMEを一部邦訳。
FlashInfer は、FlashAttention、PageAttention、LoRA などの LLM GPU カーネルの高性能な実装を提供する言語言語モデル用のライブラリです。FlashInferは、LLMのサービスと推論に重点を置き、さまざまなシナリオで最先端のパフォーマンスを提供します。
FlashInferのユニークな機能は次のとおりです。
包括的なアテンションカーネル:KV-Cacheのさまざまな形式(パディングテンソル、ラグドテンソル、ページテーブル)での、プリフィル、デコード、アペンドカーネルのシングルリクエストおよびバッチバージョンを含む、LLMサービングのすべての一般的なユースケースをカバーするアテンションカーネル。
最適化された共有プレフィックス バッチ デコード: FlashInfer は、カスケードによって共有プレフィックス バッチ デコードのパフォーマンスを向上させ、ベースラインの vLLM PageAttention 実装 (31 トークンの長いプロンプトと 32768 の大きなバッチ サイズ) と比較して、最大 256 倍の大幅な高速化を実現します。
圧縮/量子化KVキャッシュへの注目の加速:最新のLLMは、メモリトラフィックを削減するために、量子化/圧縮KVキャッシュを使用して展開されることがよくあります。FlashInfer は、Grouped-Query Attention、Fused-RoPE Attention、Quantized Attention のパフォーマンスを最適化することで、これらのシナリオを高速化します。FlashInfer は PyTorch、TVM、C++ (ヘッダーのみ) API をサポートしており、既存のプロジェクトに簡単に統合できます。
技術体系はBlogでも解説されているようなのですが、まだ未読。
アテンション周りの処理の高速化とKVキャッシュの圧縮・最適化によって高速化している、ということなのかな。
今回はどれくらいパフォーマンスが変わるかを見たかったので、ひとまず動作を見てみます。
Step1. パッケージインストール
SGLangやFlashinferなど必要なパッケージをインストールします。
まずは前提となるpytorchやvllmをインストール。
容量がそれなりに大きいため、事前にダウンロードしたWheelファイルからインストールしています。
# torch, xformers
# pytorchのリポジトリから直接インストールする場合
# %pip install -U https://download.pytorch.org/whl/cu118/torch-2.1.2%2Bcu118-cp310-cp310-linux_x86_64.whl
# %pip install -U https://download.pytorch.org/whl/cu118/xformers-0.0.23.post1%2Bcu118-cp310-cp310-manylinux2014_x86_64.whl
%pip install -U /Volumes/training/llm/tmp/torch-2.1.2+cu118-cp310-cp310-linux_x86_64.whl
%pip install -U /Volumes/training/llm/tmp/xformers-0.0.23.post1+cu118-cp310-cp310-manylinux2014_x86_64.whl
# vLLM
# GithubからvLLMを直接インストールする場合
# %pip install https://github.com/vllm-project/vllm/releases/download/v0.3.0/vllm-0.2.7+cu118-cp310-cp310-manylinux1_x86_64.whl
%pip install /Volumes/training/llm/tmp/vllm-0.3.0+cu118-cp310-cp310-manylinux1_x86_64.whl
次にFlashinferをインストール。CUDA11.8用のパッケージをインストールします。
Wheelファイルはこちらにあります。
# Github上のパッケージをインストールする場合。
# %pip install flashinfer -i https://flashinfer.ai/whl/cu118/
%pip install /Volumes/training/llm/tmp/flashinfer-0.0.1+cu118-cp310-cp310-linux_x86_64.whl
最後にSGLangとtritonをインストール。SGLangはソースコードからインストールします。
こちらの記事と同様に、Databricks Repos機能を使い、リポジトリのクローンを作った上でインストールしました。
git clone
を使ってリポジトリを取ってきても問題ないと思います。
# !git clone https://github.com/flashinfer-ai/flashinfer.git
!cd /Workspace/Repos/リポジトリのパス/sglang && pip install -U -e "python[srt]"
%pip install -U "triton>=2.2.0"
dbutils.library.restartPython()
Step2. torch設定
前回同様、torchのmultiprocessingの処理種別を変更。
import torch
torch.multiprocessing.set_start_method('spawn', force=True)
Step3. ChatTemplateの登録
今回もモデルとしてOpenChat v1.5を利用します。
また、SGLangがGPTQの量子化フォーマットに対応したため、以下のモデルを事前にダウンロードして利用することとします。
OpenChat v1.5を適切に利用するためのプロンプトテンプレートを以下のように登録しておきます。
from sglang.lang.chat_template import (
get_chat_template,
register_chat_template,
ChatTemplate,
)
register_chat_template(
ChatTemplate(
name="openchat",
default_system_prompt=None,
role_prefix_and_suffix={
"system": ("", "\n"),
"user": ("GPT4 Correct User: ", "<|end_of_turn|>"),
"assistant": ("GPT4 Correct Assistant: ", "<|end_of_turn|>"),
},
)
)
Step4. ランタイムの起動
LLMをロードして推論準備をします。
RuntimeとしてFlashinferを使う場合、model_mode='flashinfer'
をパラメータとして指定する必要があります。
from sglang import function, system, user, assistant, gen, set_default_backend, Runtime
model_path = "/Volumes/training/llm/model_snapshots/models--TheBloke--openchat-3.5-0106-GPTQ"
# flashinferを使わない場合
# runtime = Runtime(model_path)
# flashinferを使う場合
runtime = Runtime(model_path, model_mode="flashinfer")
# ChatTemplateの適用
runtime.endpoint.chat_template = get_chat_template("openchat")
# OpenChat-3.5-0106のtokenizerファイルのバグ対応
runtime.get_tokenizer().eos_token_id = 32000
# バックエンド設定
set_default_backend(runtime)
Step5. 推論
推論を実行し、レイテンシを測ってみます。
今回は、「バッチ」「単一」「ストリーミング」の3種類を実行して、レイテンシと秒あたりの生成トークン数(TPS)をflashinferを使うケースと使わないケースでざっくり測ってみました。計測は実行2回目以降の数値を取っています。
バッチ推論
import time
max_tokens = 200
@function
def simple_question(s, question):
s += user(question)
s += assistant(gen("answer", max_tokens=max_tokens))
tic = time.time()
states = simple_question.run_batch(
[
{"question": "Databricksとは何?詳細に説明して。"},
{"question": "LLMとは何?詳細に説明して。"},
{"question": "生成AIとは何?詳細に説明して。"},
],
temperature=0,
progress_bar=True,
)
for s in states:
print(s.text())
latency = time.time() - tic
print(f"Latency: {latency:.3f}")
print(f"Token per sec: {max_tokens*len(states) / latency:.2f}")
実行例:
単一推論
tic = time.time()
state = simple_question.run(
question="Databricksとは何?詳細に説明して。",
temperature=0,
)
print(state.text())
latency = time.time() - tic
print(f"Latency: {latency:.3f}")
print(f"Token per sec: {max_tokens/ latency:.2f}")
実行例:
ストリーミング
tic = time.time()
state = simple_question.run(
question="Databricksとは何?詳細に説明して。",
temperature=0,
stream=True,
)
async for c in state.text_async_iter("answer"):
print(c, end="", flush=True)
print()
latency = time.time() - tic
print(f"Latency: {latency:.3f}")
print(f"Token per sec: {max_tokens/ latency:.2f}")
実行例:
計測結果は以下の通り。(バッチ推論のTPSは総生成トークン数で計算)
※ 複数試行して平均とっているわけではないので、雑な傾向として見てください。
推論種別 | Flashinferあり | Flashinferなし |
---|---|---|
バッチ | Latency:2.60秒、TPS:230.9 | Latency:2.76秒、TPS:217.2 |
単一 | Latency:2.42秒、TPS:82.6 | Latency:2.64秒、TPS:75.8 |
ストリーミング | Latency:3.21秒、TPS:62.3 | Latency:3.43秒、TPS:58.4 |
結果として、FlashInferを使う方がTPSベースで10%弱ほど高速化されました。
というか、使わなくても十分速いですね。。。
ストリーミングはこちらの処理の問題なのかバッチに比べると少し遅くなりました。
まとめ
SGLangでFlashinferを使った推論の試行と雑な速度計測をしてみました。
SGLangをNVidia系のGPUクラスタを使う場合、積極的にFlashinferを使ってもいいんじゃないかと思いました。
まだ試行不足ですが、現時点でデメリットはあまり感じなかったので。
とはいえ、SGLang自体がまだ発展途上のパッケージなので、今後の進展が楽しみです。
※ SGLangの高速JSONデコードとかも試しているのですが、なかなかびっくりします。。。こちらはまた別で記事にする予定。