またまた凄いのが出てきた。。。
導入
LMSYS OrgがLLMの推論フレームワーク(と言っていいのか)であるSGLangを公開しました。
Githubの説明によると、SGLangは以下のコア機能を備えています。
柔軟なフロントエンド言語: これにより、複数の連鎖生成呼び出し、高度なプロンプト手法、制御フロー、複数のモダリティ、並列処理、および外部との相互作用を使用して、LLMアプリケーションを簡単にプログラミングできます。
RadixAttention を使用した高性能ランタイム: この機能は、複数の呼び出し間で KV キャッシュを自動的に再利用することにより、複雑な LLM プログラムの実行を大幅に高速化します。また、連続バッチ処理やテンソル並列処理など、他の一般的な手法もサポートしています。
SGLangは一般的なLLMのワークロードにおいて、vLLMのような既存システムより最大5倍程度のスループットを発揮すると発表されています。
(詳細は上記リンク先の比較を確認ください)
え、あのvLLMの5倍のスループット?
というわけで、実際にどんなものなのか試してみたいと思います。
検証はDatabricks on AWS上で実施しました。
DBRは14.2ML、クラスタタイプはGPU(g5.xlarge)を使用しました。
Step1. パッケージインストール
SGLangに必要なパッケージをインストール。
SGLangは依存パッケージにvLLMの比較的新しいバージョンを使用しています。
そしてvLLMはpytorchの2.1.0以降を要求するため、DatabricksのDBRに合わせたpytorchおよびxformersのパッケージをインストールしています。(CUDA 18を利用するように指定)
また、SGLangは現状OpenAIやAnthropicのAPIと連携できるように設計されており、それらを利用する場合はsglang[all]
をインストールしてください。
今回はそれらを利用せず、ローカルLLMで動かすため、より限定的なsglang[srt]
でインストールを行っています。
# torch, xformers
# pytorchのリポジトリから直接インストールする場合
# %pip install -U https://download.pytorch.org/whl/cu118/torch-2.1.2%2Bcu118-cp310-cp310-linux_x86_64.whl
# %pip install -U https://download.pytorch.org/whl/cu118/xformers-0.0.23.post1%2Bcu118-cp310-cp310-manylinux2014_x86_64.whl
# 今回は上記ファイルをUnity Catalog Volumesにダウンロードして利用
%pip install -U /Volumes/training/llm/tmp/torch-2.1.2+cu118-cp310-cp310-linux_x86_64.whl
%pip install -U /Volumes/training/llm/tmp/xformers-0.0.23.post1+cu118-cp310-cp310-manylinux2014_x86_64.whl
# vLLM
# GithubからvLLMを直接インストールする場合
# %pip install https://github.com/vllm-project/vllm/releases/download/v0.2.7/vllm-0.2.7+cu118-cp310-cp310-manylinux1_x86_64.whl
# 今回は上記ファイルをUnity Catalog Volumesにダウンロードして利用
%pip install /Volumes/training/llm/tmp/vllm-0.2.7+cu118-cp310-cp310-manylinux1_x86_64.whl
%pip install "sglang[srt]"
# 念のため、tritonを2.2.0以上にアップグレード
%pip install -U "triton>=2.2.0"
dbutils.library.restartPython()
Step2. torch設定
torchのmultiprocessingに対して、以下の設定を行います。
これをしないと、私の環境ではエラーが出ました。
import torch
torch.multiprocessing.set_start_method('spawn', force=True)
Step3. LLMの読み込み/SGLangランタイム(SRT)起動
モデルを指定してSGLangのバックエンドであるSGLangランタイム(SRT)を起動します。
GithubのREADMEではLlama2-7B-chatを使っていましたが、せっかくなので以下の日本語が使えるLLMであるELYZA-7B-instructモデルを使うことにします。
from sglang import function, system, user, assistant, gen, set_default_backend, Runtime
# 事前にダウンロードしておいたモデルを利用
model_path = "/Volumes/training/llm/model_snapshots/models--elyza--ELYZA-japanese-Llama-2-7b-fast-instruct"
runtime = Runtime(model_path)
set_default_backend(runtime)
ちらっとコードを見ただけですが、Runtime
はFastAPIで実装されたサーバを別プロセスで起動しているようです。
次に、起動したランタイムに対して、Llama2-chatのチャットテンプレートを適用します。
特に指定しないと、実行するチャットプロントがSYSTEM: や USER: から始まるものに変換されるようで、これを適切なプロンプトテンプレート(ELYZA-7Bだと[INST]などで始まったりする)ように設定します。
なお、Llama2-chat用のテンプレートは既にSGLangの中に組み込まれているため、それを取得するget_chat_template
を呼び出しています。(カスタムチャットテンプレートも作成できるようです)
# Elyza用に、chat_templateをLlama2-chat用のテンプレートに変更する
from sglang.lang.chat_template import get_chat_template
runtime.endpoint.chat_template = get_chat_template("llama-2-chat")
これで準備完了です。
Step4. 単純な推論実行
では、まずシンプルなQAを実行してみます。
@function
def simple_question(s, question):
s += system("あなたは誠実で優秀な日本人のアシスタントです。")
s += user(question)
s += assistant(gen("answer", max_tokens=512))
state = simple_question.run(
question="Databricksとは何?",
temperature=0,
)
# 文字列出力
print(state.text())
[INST] <<SYS>>
あなたは誠実で優秀な日本人のアシスタントです。
<</SYS>>
Databricksとは何? [/INST] Databricksは、データの可視化、分析、機械学習を実行するためのクラウドベースのプラットフォームです。Apache Sparkを使用して、大規模なデータセットを処理し、洞察を得ることができます。Databricksは、データサイエンティスト、データエンジニア、ビジネスユーザーを対象としています。 </s><s>
SGLangは@function
デコレータを付与した関数内に実行するプロンプトを記述し、run
等のメソッドで処理を記述する形になっています。
かなり直感的な記述ができるなという感想を持ちました。
返り値であるstate
のメソッドをコールすると実際の生成処理が遅延実行されます。
text
メソッドは、入力プロンプト含めて生成結果を取得する処理のようです。
なお、このセルの実行時間は2.59秒でした。
text
メソッド以外に、各ロールごとの入出力を得るメソッドmessages
も用意されています。
state = simple_question.run(
question="Databricksとは何?",
temperature=0,
)
# Message
for m in state.messages():
print(m["role"], ":", m["content"])
system : あなたは誠実で優秀な日本人のアシスタントです。
user : Databricksとは何?
assistant : Databricksは、データの可視化、分析、機械学習を実行するためのクラウドベースのプラットフォームです。Apache Sparkを使用して、大規模なデータセットを処理し、洞察を得ることができます。Databricksは、データサイエンティスト、データエンジニア、ビジネスユーザーを対象としています。
Step5. ストリーミング出力
ストリーミング出力用のインターフェースも提供されています。
state = simple_question.run(
question="Databricksとは何?400文字以上で詳細に説明して。",
temperature=0,
stream=True,
)
# asyncの場合。asyncではない場合はtext_iterでストリーム出力できる。
async for out in state.text_async_iter():
print(out, end="", flush=True)
print()
[INST] <<SYS>>
あなたは誠実で優秀な日本人のアシスタントです。
<</SYS>>
Databricksとは何?400文字以上で詳細に説明して。 [/INST] Databricksは、Apache Sparkを使用してデータを処理するクラウドベースのプラットフォームです。Sparkの機能を簡単に利用できるようにするAPIや、Sparkのパフォーマンスを最適化する機能などを提供しています。
Databricksは、2013年にカリフォルニア州で創業された会社です。2016年には、GoogleやNECなどから総額1億ドルの資金調達を行っています。2018年には、NASDAQに上場しました。
Databricksのプラットフォームでは、SparkのAPIや、Sparkのパフォーマンスを最適化する機能などを提供しています。これにより、開発者はSparkの機能を簡単に利用でき、IT管理者はSparkのパフォーマンスを最適化し、セキュリティを確保し、コストを削減することができます。
Databricksのプラットフォームでは、SparkのAPIや、Sparkのパフォーマンスを最適化する機能などを提供しています。以下に、その例を示します。
* SparkのAPI: Databricksのプラットフォームでは、SparkのAPIを簡単に利用できるようにするRESTful APIを提供しています。これにより、開発者はSparkの機能を簡単に利用できます。
* Sparkのパフォーマンスを最適化する機能: Databricksのプラットフォームでは、Sparkのパフォーマンスを最適化する機能を提供しています。これにより、IT管理者はSparkのパフォーマンスを最適化し、コストを削減できます。
* セキュリティ: Databricksのプラットフォームでは、Sparkのセキュリティを確保する機能を提供しています。これにより、IT管理者はデータの漏洩や改ざんを防ぐことができます。
* コスト削減: Databricksのプラットフォームでは、Sparkのコストを削減する機能を提供しています。これにより、IT管理者はSparkのコストを削減できます。
Databricksのプラットフォームは、Sparkの機能を簡単に利用できるようにするAPIや、Sparkのパフォーマンスを最適化する機能などを提供しています。これにより、開発者はSparkの機能を簡単に利用でき、IT管理者はSpark </s><s>
run実行時にstream=True
を指定することでストリーミングで結果を得ることができます。
なお、このセルの実行時間は18.01秒でした。(512トークン分出力しているはず)
Step6. バックエンドの終了
runtimeのshutdown
メソッドを呼ぶことでバックエンドサービスを終了できます。
runtime.shutdown()
まとめ
SGLangをとりあえず触ってみました。
ごくごく基本的なところしか使えてないので、もう少しいろいろやってみたいと思っています。
ちなみに上記ぐらいの使い方だと、パフォーマンスはvLLMと同等ぐらいという感覚です。
マルチターンの推論やバッチ処理などをするとまたスループットは大きく違ってくるかもしれません。
また、記述方式は非常にシンプルで好感が持てます。
今回はやっていないのですが、FewShotPromptやマルチターンクエスチョンなども書きやすそうですし、複雑なChainを作る必要性がないユースケースはこれで十分かなと思いました。
他にも、出力を正規表現で処理するなどいろいろと機能があるようなので、このあたりも試していきたいと思います。
かなり期待の持てる仕組かなと思うので、どんどん発展していって欲しい!