こちらの続きです。
導入
SGLangはまだサポートしているモデルが少ないのですが、AWQフォーマットのモデルをサポートしているようなので、量子化したモデルでの推論を試してみます。
合わせて、前回の記事で行わなかったChatTemplateの設定やバッチ実行を試してみます。
検証はDatabricks on AWS上で実施しました。
DBRは14.2ML、クラスタタイプはGPU(g4dn.xlarge)を使用しました。
Step1. パッケージのインストール
前回と同じです。
# torch, xformers
# pytorchのリポジトリから直接インストールする場合
# %pip install -U https://download.pytorch.org/whl/cu118/torch-2.1.2%2Bcu118-cp310-cp310-linux_x86_64.whl
# %pip install -U https://download.pytorch.org/whl/cu118/xformers-0.0.23.post1%2Bcu118-cp310-cp310-manylinux2014_x86_64.whl
# 今回は上記ファイルをUnity Catalog Volumesにダウンロードして利用
%pip install -U /Volumes/training/llm/tmp/torch-2.1.2+cu118-cp310-cp310-linux_x86_64.whl
%pip install -U /Volumes/training/llm/tmp/xformers-0.0.23.post1+cu118-cp310-cp310-manylinux2014_x86_64.whl
# vLLM
# GithubからvLLMを直接インストールする場合
# %pip install https://github.com/vllm-project/vllm/releases/download/v0.2.7/vllm-0.2.7+cu118-cp310-cp310-manylinux1_x86_64.whl
# 今回は上記ファイルをUnity Catalog Volumesにダウンロードして利用
%pip install /Volumes/training/llm/tmp/vllm-0.2.7+cu118-cp310-cp310-manylinux1_x86_64.whl
%pip install "sglang[srt]"
# tritonを2.2.0以上にアップグレード
%pip install -U "triton>=2.2.0"
dbutils.library.restartPython()
Step2. torch設定
これも前回同様。
import torch
torch.multiprocessing.set_start_method('spawn', force=True)
Step3. LLMの読み込み/SGLangランタイム(SRT)起動
モデルを指定してSGLangのバックエンドであるSGLangランタイム(SRT)を起動します。
今回はAWQ量子化フォーマットのモデルを利用します。
事前にダウンロードしておいた、お馴染み(?)の以下のモデルを使いました。
from sglang import function, system, user, assistant, gen, set_default_backend, Runtime
# 事前にダウンロードしておいたモデルを利用
model_path = "/Volumes/training/llm/model_snapshots/models--TheBloke--openchat-3.5-1210-AWQ"
runtime = Runtime(model_path)
set_default_backend(runtime)
Step4. ChatTemplateの設定
OpenChat用のChatTemplateは標準で用意されていないため、カスタム設定します。
ChatTemplate
クラスに、以下のようなロールに対応するprefix, postfixを設定すれば適切なテンプレートを作成できます。
このインスタンスを直接runtimeのchat_templateに設定してもよいのですが、register_chat_template
関数を使って登録を行ってからそれを取得する形にしてみました。
from sglang.lang.chat_template import (
get_chat_template,
register_chat_template,
ChatTemplate,
)
register_chat_template(
ChatTemplate(
name="openchat",
default_system_prompt=None,
role_prefix_and_suffix={
"system": ("", "\n"),
"user": ("GPT4 Correct User: ", "<|end_of_turn|>"),
"assistant": ("GPT4 Correct Assistant: ", "<|end_of_turn|>"),
},
)
)
runtime.endpoint.chat_template = get_chat_template("openchat")
Step5. 単純推論の実行と出力
では単純なQAを実行してみましょう。
@function
def simple_question(s, question):
s += user(question)
s += assistant(gen("answer_1", max_tokens=256))
state = simple_question.run(
question="Databricksとは何?",
temperature=0,
)
# 文字列出力
print(state.text())
GPT4 Correct User: Databricksとは何?<|end_of_turn|>GPT4 Correct Assistant: データブリックス(Databricks)は、アメリカのシリコンバレーに本拠地を持つ、データ処理や分析を行うクラウドベースのプラットフォームです。主に、大規模なデータセットを処理するために、データエンジニアやデータサイエンティストが使用するSparkエコシステムを提供しています。
データブリックスは、Apache Sparkをベースとしたアナリティクスエンジンを提供し、データを処理し、分析するための統合された環境を提供します。これにより、データエンジニアやデータサイエンティストは、データの処理、分析、機械学習モデルの構築<|end_of_turn|>
出力結果を見ると、ChatTemplateで設定した内容でロールごとのプロンプトが組み立てられているのがわかります。
Step6. バッチ実行
複数の推論をバッチ実行させてみます。
SGLangはvLLMなどと同様、continuous batchingをサポートしており、高効率にバッチ処理を実行できるようです。
Continuous batchingは以下の記事で解説されております。
SGLangでバッチ実行させる場合、run_batch
メソッドを使います。
states = simple_question.run_batch(
[
{"question": "What is Databricks?"},
{"question": "What is LLM?"},
{"question": "人工知能とは何?"},
],
progress_bar=True,
temperature=0,
)
for state in states:
print(state.text())
100%|██████████| 3/3 [00:07<00:00, 2.66s/it]
GPT4 Correct User: What is Databricks?<|end_of_turn|>GPT4 Correct Assistant:
Databricks is a cloud-based data analytics platform that provides a unified environment for data engineering, data science, and machine learning. It is built on top of Apache Spark, an open-source distributed computing framework, and is designed to handle large-scale data processing and analytics tasks.
Databricks offers a collaborative workspace where data engineers, data scientists, and machine learning engineers can work together in real-time. It provides a range of tools and libraries for data processing, machine learning, and artificial intelligence, including SQL, R, Python, and Scala.
The platform is available as a cloud service on Amazon Web Services (AWS), Microsoft Azure, and Google Cloud Platform (GCP), as well as on-premises for organizations that prefer to host their data and analytics infrastructure on their own servers.
Databricks is used by many large organizations for various use cases, including data warehousing, real-time analytics, machine learning, and artificial intelligence.<|end_of_turn|>
GPT4 Correct User: What is LLM?<|end_of_turn|>GPT4 Correct Assistant: 1. LLM stands for "Learning and Lifestyle Management" which is a field of study that focuses on the development of technologies and systems to help individuals manage their daily activities, learn new skills, and improve their overall well-being. This field combines elements of artificial intelligence, human-computer interaction, and psychology to create tools and applications that can assist people in achieving their goals and maintaining a healthy lifestyle.
2. LLM also stands for "Licensed Legal Manager," which is a professional designation for individuals who have completed a specialized program in legal management and have passed a certification exam. These professionals work in law firms, corporate legal departments, and other legal organizations, providing administrative and management support to attorneys and legal teams.
3. LLM can also refer to "Learning Management System," which is a software platform used by educational institutions, businesses, and other organizations to manage, track, and deliver educational courses and training programs. These systems allow administrators to create and organize course content, track student progress, and provide feedback and support.
4. LLM can also stand for "Learning Materials," which are resources, such as textbooks, workbooks, online courses, and other<|end_of_turn|>
GPT4 Correct User: 人工知能とは何?<|end_of_turn|>GPT4 Correct Assistant: 人工知能(Artificial Intelligence, AI)は、人間の思考や行動をコンピュータやロボットなどの機械に模倣させる科学と技術の分野です。人工知能の目的は、機械が人間の知性や理解能力を持って、複雑な問題を解決し、人間の労力を軽減することです。
人工知能は、さまざまな分野で活用されており、例えば医療、農業、ロボット技術、自動運転車、言語処理、画像認識、ゲーム、情報処理などです。人工知能の主要な分類は、単一のタスクを解決する単一の知能(narrow AI)と、複数のタスクを解決する通常の人間の<|end_of_turn|>
progress_barパラメータをTrueにすると、プログレスバーを表示してくれるので、実行状況を表示させる場合は非常に便利。
なお、このセルの実行時間は8秒でした。
Step7. デコード制約
もう少しSGLangの推論時機能を触ってみます。
gen
関数にregexパラメータを指定すると、生成結果を正規表現にマッチする形に制約をかけることができます。
以下、例。
@function
def regular_expression_gen(s):
s += "Q: What is the IP address of the Google DNS servers?\n"
s += "A: " + gen(
"answer",
temperature=0,
regex=r"((25[0-5]|2[0-4]\d|[01]?\d\d?).){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)",
)
state = regular_expression_gen.run()
print(state.text())
Q: What is the IP address of the Google DNS servers?
A: 8.8.8.88
IPアドレスの形で制約をかけて結果を得ることができました。
ちなみに、制約をかけない場合の回答は以下でした。(正規表現指定の追加考慮が必要ですね)
8.8.8.8 and 8.8.4.4
Step8. 並列化
fork
メソッドを使うことで、処理内の推論を並列で処理することができます。
以下、例。
@function
def tip_suggestion(s):
s += (
"Here are two tips for staying healthy: "
"1. Balanced Diet. 2. Regular Exercise.\n\n"
)
forks = s.fork(2)
for i, f in enumerate(forks):
f += f"Now, expand tip {i+1} into a paragraph:\n"
f += gen(f"detailed_tip", max_tokens=256, stop="\n\n")
s += "Tip 1:" + forks[0]["detailed_tip"] + "\n"
s += "Tip 2:" + forks[1]["detailed_tip"] + "\n"
s += "In summary" + gen("summary")
state = tip_suggestion.run()
print(state.text())
tip_suggestion
内で二つのTipを生成していますが、1個目と2個目の生成が並列生成されます。そのため、直列実行するよりは高速に推論ができる、ということのようです。
外部APIを利用する環境下など、推論の並列実行が有用な環境の場合はスループット向上が期待できると思います。
まとめ
SGLangでAWQフォーマットを使ったモデルの利用と、様々な推論を試してみました。
かなりシンプルに記述できるのがいいですね。
また、バッチ処理はかなり効率がよさそうだなという感覚を持ちました。
vLLMがGPTQなど他のフォーマットへも対応しているので、SGLangも今後他のフォーマットへの対応が進んでいったらと期待しています。
SGLangはもう少し機能を追っていく予定です。