導入
こちらのさらに続き(?)です。
SGLangについて、以下のBlogでローカルLLMを用いたJSONの高速デコードが紹介されていました。
こちらの内容を見るに、驚きの生成速度です。
詳細なロジックまで理解できているわけではないのですが、興味が出たので実際に試してみました。
JSON高速デコードに関するコードサンプルは以下にあります。
今回はこちらをベースに、LLMを使った感情分析のバッチ処理を行うサンプルを作成して実行速度を見てみます。
具体的には、テキスト文章をインプットとして、ネガティブ・ポジティブのラベルをJSON形式で返す処理を数十程度バッチで実行させてみます。
検証はDatabricks on AWS上で実施しました。
DBRは14.3ML LTS、クラスタタイプはg5.xlarge(A10G)のGPUクラスタです。
Step1. パッケージインストール
SGLangなど必要なパッケージをインストールします。
前回と同じです。
# torch, xformers
# pytorchのリポジトリから直接インストールする場合
# %pip install -U https://download.pytorch.org/whl/cu118/torch-2.1.2%2Bcu118-cp310-cp310-linux_x86_64.whl
# %pip install -U https://download.pytorch.org/whl/cu118/xformers-0.0.23.post1%2Bcu118-cp310-cp310-manylinux2014_x86_64.whl
%pip install -U /Volumes/training/llm/tmp/torch-2.1.2+cu118-cp310-cp310-linux_x86_64.whl
%pip install -U /Volumes/training/llm/tmp/xformers-0.0.23.post1+cu118-cp310-cp310-manylinux2014_x86_64.whl
# vLLM
# GithubからvLLMを直接インストールする場合
# %pip install https://github.com/vllm-project/vllm/releases/download/v0.3.0/vllm-0.2.7+cu118-cp310-cp310-manylinux1_x86_64.whl
%pip install /Volumes/training/llm/tmp/vllm-0.3.0+cu118-cp310-cp310-manylinux1_x86_64.whl
# Reposにクローンしておいたsglangのソースリポジトリから最新のSGLangインストール
!cd /Workspace/Repos/リポジトリのパス/sglang && pip install -U -e "python[srt]"
%pip install -U "triton>=2.2.0"
dbutils.library.restartPython()
Step2. torch設定
これも前回同様、torchのmultiprocessingの処理種別を変更。
import torch
torch.multiprocessing.set_start_method('spawn', force=True)
Step3. 感情分析用データセットの準備
今回はサンプルとして、公開されているchABSA-datasetに対して、ネガポジ判定をさせてみます。
そのためのデータセットを取得・加工します。
import requests
import zipfile
import io
import json
# chABSA-datasetを取得&Zipファイルを展開
url = "https://s3-ap-northeast-1.amazonaws.com/dev.tech-sketch.jp/chakki/public/chABSA-dataset.zip"
resp = requests.get(url)
content = io.BytesIO(resp.content)
with zipfile.ZipFile(content, "r") as f:
f.extractall("/tmp/chABSA")
# 展開したファイル群のうち、1ファイルのみロード
with open("/tmp/chABSA/chABSA-dataset/e05714_ann.json","br") as f:
data = json.load(f)
data
ファイルは以下のようなJSONデータになっています。
{'header': {'document_id': 'E05714',
'document_name': 'ソニーフィナンシャルホールディングス株式会社',
'doc_text': '有価証券報告書',
'edi_id': 'E05714',
'security_code': '87290',
'category33': '保険業',
'category17': '金融(除く銀行)',
'scale': '4'},
'sentences': [{'sentence_id': 0,
'sentence': '当連結会計年度(平成28年4月1日~平成29年3月31日)における日本経済は、企業収益の堅調な推移や雇用・所得環境の着実な改善などを背景に、緩やかな景気回復の動きが見られる一方で、英国や米国における経済政策の変化や中国をはじめとする新興国経済の下振れリスクを含む海外経済動向の影響などにより、先行きの不透明感は高まりました',
'opinions': [{'target': '日本経済',
'category': 'NULL#general',
'polarity': 'positive',
'from': 33,
'to': 37},
{'target': '企業収益',
'category': 'NULL#general',
'polarity': 'positive',
'from': 39,
'to': 43},
--- 以下省略 ---
これを適当に加工して、文章と感情のテーブルデータに変換します。
変換の過程で、ネガポジ判定ができない文は除去しました。
import pyspark.sql.functions as F
import pandas as pd
pdf = pd.DataFrame(data["sentences"])
df = spark.createDataFrame(pdf)
df = df.select("sentence_id", "sentence", "opinions.polarity")
df = df.withColumn("polarity_size", F.size("polarity"))
df = df.withColumn(
"positive_score", F.size(F.filter("polarity", lambda x: x == "positive"))
)
df = df.withColumn(
"negative_score", F.size(F.filter("polarity", lambda x: x == "negative"))
)
df = df.withColumn(
"original_sentiment",
F.when(F.col("positive_score") >= F.col("negative_score"), "positive").otherwise(
"negative"
),
)
df = df.drop("polarity")
# 判断付け出来ないデータはカット
df = df.filter("polarity_size > 0")
display(df)
69件のテキストデータセットが出来ました。
Step4. ChatTemplateの登録&モデルのロード
今回もモデルとしてOpenChat v1.5を利用します。
事前にダウンロードしておいた以下のモデルを利用します。
まずはプロンプトテンプレートを以下のように登録。
from sglang.lang.chat_template import (
get_chat_template,
register_chat_template,
ChatTemplate,
)
register_chat_template(
ChatTemplate(
name="openchat",
default_system_prompt=None,
role_prefix_and_suffix={
"system": ("", "\n"),
"user": ("GPT4 Correct User: ", "<|end_of_turn|>"),
"assistant": ("GPT4 Correct Assistant: ", "<|end_of_turn|>"),
},
)
)
モデルをロード。今回はFlashinferを使わないことにします。
from sglang import function, system, user, assistant, gen, set_default_backend, Runtime
model_path = "/Volumes/training/llm/model_snapshots/models--TheBloke--openchat-3.5-0106-GPTQ"
runtime = Runtime(model_path)
runtime.endpoint.chat_template = get_chat_template("openchat")
# OpenChat-3.5-0106のtokenizerファイルのバグ対応
runtime.get_tokenizer().eos_token_id = 32000
set_default_backend(runtime)
ここまで準備でした。
Step5. 推論
今回のポイントである、LLMによる推論とJSON出力を試します。
今回は、入力に埋め込んだsentence_id
とネガポジ分類、確度の3項目をJSON形式で出力するようにしてみました。
まずはそのための処理を定義します。
import time
import sglang as sgl
# JSONテンプレート
sentiment_regex = (
r"""\{\n"""
+ r""" "sentence_id": [-+]?[0-9]{1,9},\n"""
+ r""" "sentiment": "(positive|negative)",\n"""
+ r""" "sentiment_probability": [0-1]{1,1}\.[0-9]{1,1}\n"""
+ r"""\}"""
)
@sgl.function
def sentiment_gen(s, sentence_id, sentence):
s += "Please fill in the sentiment class(positive or negative), classification probability(between 0.0 and 1.0), and sentence ID with the following sentence and id.\n"
s += "SENTENCE:\n" + sentence + "\nSENTENCE ID:" + str(sentence_id) + "\n\n"
s += sgl.gen("json_output", max_tokens=128, regex=sentiment_regex)
def batch_sentiment_gen(df, num_threads:int=1):
pdf = df.select("sentence_id","sentence").toPandas()
sentences = pdf.to_dict("records")
tic = time.time()
states = sentiment_gen.run_batch(
sentences,
temperature=0.8,
num_threads=num_threads,
progress_bar=True,
)
latency = time.time() - tic
return states, latency
ウォームアップのための初期ロードを兼ねて、上記の処理を1件実行してみます。
# 1件のみ実行
states, latency = batch_sentiment_gen(df.limit(1))
for s in states:
print(s.text())
以下のような出力が得られます。
Please fill in the sentiment class(positive or negative), classification probability(between 0.0 and 1.0), and sentence ID with the following sentence and id.
SENTENCE:
当連結会計年度(平成28年4月1日~平成29年3月31日)における日本経済は、企業収益の堅調な推移や雇用・所得環境の着実な改善などを背景に、緩やかな景気回復の動きが見られる一方で、英国や米国における経済政策の変化や中国をはじめとする新興国経済の下振れリスクを含む海外経済動向の影響などにより、先行きの不透明感は高まりました
SENTENCE ID:0
{
"sentence_id": 0,
"sentiment": "positive",
"sentiment_probability": 0.5
}
では、本番実行です。
全件出力させ、かつレイテンシを計測します。
# 全件実行(並列なし)
states, latency = batch_sentiment_gen(df, num_threads=1)
for s in states:
print(s.text())
print(f"Latency: {latency:.3f} sec.")
Please fill in the sentiment class(positive or negative), classification probability(between 0.0 and 1.0), and sentence ID with the following sentence and id.
SENTENCE:
当連結会計年度(平成28年4月1日~平成29年3月31日)における日本経済は、企業収益の堅調な推移や雇用・所得環境の着実な改善などを背景に、緩やかな景気回復の動きが見られる一方で、英国や米国における経済政策の変化や中国をはじめとする新興国経済の下振れリスクを含む海外経済動向の影響などにより、先行きの不透明感は高まりました
SENTENCE ID:0
{
"sentence_id": 0,
"sentiment": "negative",
"sentiment_probability": 0.9
}
Please fill in the sentiment class(positive or negative), classification probability(between 0.0 and 1.0), and sentence ID with the following sentence and id.
SENTENCE:
債券市場では、低下が続いていた国内長期金利が平成28年7月以降上昇に転じ、同年11月の米国の大統領選挙の結果を受けて世界的に長期国債利回りが上昇した流れもあり、小幅ながらさらに上昇したものの、日銀の緩和的な金融政策により依然として低水準にとどまっています
SENTENCE ID:1
{
"sentence_id": 1,
"sentiment": "positive",
"sentiment_probability": 0.6
}
--- 中略 ---
Please fill in the sentiment class(positive or negative), classification probability(between 0.0 and 1.0), and sentence ID with the following sentence and id.
SENTENCE:
また、国際業務部門の資金運用収支は、41億4百万円、役務取引等収支は84百万円、その他業務収支は47億80百万円となりました
SENTENCE ID:189
{
"sentence_id": 189,
"sentiment": "positive",
"sentiment_probability": 0.9
}
Latency: 17.804 sec.
69件の推論出力が直列で18秒足らずで完了。1件0.3秒かからずJSON出力できています。
え、速い。。。
当然並列出力すると、さらにスループットがあがります。
# 全件実行&4スレッド並列
states, latency = batch_sentiment_gen(df, num_threads=4)
for s in states:
print(s.text())
print(f"Latency: {latency:.3f} sec.")
--- 前略 ---
Please fill in the sentiment class(positive or negative), classification probability(between 0.0 and 1.0), and sentence ID with the following sentence and id.
SENTENCE:
また、国際業務部門の資金運用収支は、41億4百万円、役務取引等収支は84百万円、その他業務収支は47億80百万円となりました
SENTENCE ID:189
{
"sentence_id": 189,
"sentiment": "positive",
"sentiment_probability": 0.8
}
Latency: 3.841 sec.
69件の推論出力が4並列で3.8秒。既にLLMで処理している実感がない。
最後に得られた結果をSparkデータフレームに変換し、元の表と結合して表示。
ret = []
for s in states:
ret.append(json.loads(s.get_var("json_output")))
sentiment_df = spark.createDataFrame(pd.DataFrame(ret))
result_df = df.join(sentiment_df, ["sentence_id"], "left")
result_df = result_df.withColumn(
"check", F.col("original_sentiment") == F.col("sentiment")
)
result_df = result_df.select(
"sentence_id",
"sentence",
"original_sentiment",
"sentiment",
"sentiment_probability",
"check",
)
display(result_df)
ちなみに8割程度がオリジナルの分類結果と一致していました。
LLMやプロンプトによってこのあたりの結果は大きく変わると思います。
Step6. 終了処理
最後に、ラインタイムの終了処理。
runtime.shutdown()
まとめ
SGLangの高速JSONデコードを試してみました。
JSONでLLMから結果を得たい際、特に大量データの分類処理において非常によさそうだと感じました。シンプルに凄い。
SGLang、まだまだ発展途上感は強いのですが、もう少し使い込んでみようと思います。