0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

Databricks Mosaic AIを用いたLLMバッチ推論

Last updated at Posted at 2024-11-21

個人的にはLLMのバッチ推論にユースケースの可能性を感じています。なぜなら、以下のブログ記事で触れているように、多くの企業ではバッチ処理を適用できるようなテキストを格納しているテーブルがあるからです。

Databricksで提供しているこちらの機能をサンプルノートブックで試していきます。

プロビジョニングされたスループットエンドポイントを使用したバッチ推論

このノートブックは、大規模なバッチ推論ワークロードのためにプロビジョニングされたスループットエンドポイントを設定し、ai_queryを使用して行を処理します。

要件

続行する前に、以下が必要です:

  • 次のいずれかの構成を持つコンピュート:
    • オールパーパスコンピュート:
      • コンピュートサイズ i3.2xlarge 以上。
      • Databricks DBR 15.4 LTS。
      • 少なくとも2つのワーカーを持つマルチノード。
    • SQL ウェアハウス:
      • SQL ウェアハウス Medium 以上。
  • Databricks Unity Catalog に登録された大規模言語モデル (LLM)。
    以下のいずれかを行うことができます:
    • Unity Catalog または Marketplace で事前トレーニングされた LLM を使用する (AWS |Azure)
    • 独自の LLM モデルを手動でトレーニングして Unity Catalog に登録する。
    • Databricks はバッチ推論ワークロードに Meta Llama 3.1 8B または 70B を使用することを強く推奨します。

パラメータリファレンス

パラメータ名 値の例 説明
model_uc_path system.ai.meta_llama_v3_1_70b バッチ推論に使用するUnity CatalogのDatabricksモデルの名前。最新バージョンのモデルが使用されます。
prompt 提供されたZIPコードに対応する米国の州名を教えてください。 推論に使用するプロンプト。各入力行はこのプロンプトに連結されます。
request_params {"max_tokens": 1000, "temperature": 0} 推論に使用するリクエストパラメータ。
input_table_name samples.nyctaxi.trips Unity Catalogの入力テーブルの名前。
input_column_name pickup_zip 入力テーブルで使用される列の名前。
output_table_name main.default.batch_nyctaxi_trips Unity Catalogの出力テーブルの名前。出力テーブルの既存の行は上書きされます。
output_column_name state_name 生成されたデータを格納する出力テーブルの列の名前。
endpoint_name llama_batch_inference (オプション) 作成されるプロビジョニングされたスループットエンドポイントの名前。
model_uc_version 2 (オプション) 使用するUnity CatalogのDatabricksモデルのバージョン。指定しない場合、最新バージョンのモデルが使用されます。
input_num_rows 100000 (オプション) 処理するinput_table_nameの行数。空の場合、input_table_nameのすべての行が処理されます。
%pip install transformers
dbutils.library.restartPython()
import requests, json, math, os, time, uuid
import numpy as np
import pandas as pd
from pyspark.sql.functions import col, rand
# ウィジェットの設定
dbutils.widgets.text("model_uc_path", "system.ai.meta_llama_v3_1_70b_instruct", "Model UC Path")
dbutils.widgets.text("input_table_name", "samples.nyctaxi.trips", "Input Table Name")
dbutils.widgets.text("input_column_name", "pickup_zip", "Input Column Name")
dbutils.widgets.text("output_table_name", "main.default.batch_nyctaxi_trips", "Output Table Name")
dbutils.widgets.text("output_column_name", "state_name", "Output Column Name")
dbutils.widgets.text("prompt", "Could you tell me the state name of the zip code provided? zip code: ", "Prompt")
dbutils.widgets.text("request_params", '{"max_tokens": 1000, "temperature": 0}', "Request Parmaeters")
dbutils.widgets.text("endpoint_name", "", "(optional) Endpoint Name")
dbutils.widgets.text("model_uc_version", "", "(optional) Model Version")
dbutils.widgets.text("input_num_rows", "", "(optional) Input Number of Rows")

# パラメータの取得
model_uc_path = dbutils.widgets.get("model_uc_path")
input_table_name = dbutils.widgets.get("input_table_name")
input_column_name = dbutils.widgets.get("input_column_name")
output_table_name = dbutils.widgets.get("output_table_name")
output_column_name = dbutils.widgets.get("output_column_name")
prompt = dbutils.widgets.get("prompt")
request_params = dbutils.widgets.get("request_params")
endpoint_name = dbutils.widgets.get("endpoint_name")
model_uc_version = dbutils.widgets.get("model_uc_version")
input_num_rows = dbutils.widgets.get("input_num_rows")

# パラメータの検証
assert model_uc_path.count(".") == 2, "model_uc_pathはcatalog.schema.model_nameの形式である必要があります"
assert input_table_name.count(".") == 2, "input_table_nameはcatalog.schema.tableの形式である必要があります"
assert input_column_name, "input_column_nameは空にできません"
assert output_table_name.count(".") == 2, "output_table_nameはcatalog.schema.tableの形式である必要があります"
assert output_column_name, "output_column_nameは空にできません"
try:
    request_params = json.loads(dbutils.widgets.get("request_params"))
except:
    assert False, "request_paramsは有効なjsonである必要があります"
if model_uc_version:
    assert model_uc_version.isdigit(), "model_uc_versionは数字である必要があります"
    model_uc_version = int(model_uc_version)
    assert model_uc_version >= 1, "model_uc_versionは正の整数である必要があります"
if input_num_rows:
    assert input_num_rows.isdigit(), "input_num_rowsは数字である必要があります"
    input_num_rows = int(input_num_rows)
    assert input_num_rows >= 1, "input_num_rowsは正の整数である必要があります"

ノートブックの上部にウィジェットが表示されます。

Screenshot 2024-11-21 at 10.43.55.png

パート1: プロビジョニングされたスループットエンドポイントの作成

これはバッチ推論のための大規模なプロビジョニングされたスループットエンドポイントを作成します。エンドポイントのサイズはPTU(プロビジョニングされたスループットユニット)で指定されます。バッチ推論が完了した後、エンドポイントは停止されます。

from mlflow.utils.databricks_utils import get_databricks_host_creds, get_workspace_url

host = get_databricks_host_creds("databricks").host
assert host, "ホストURLを取得できません"

def getHeaders():
    credentials = get_databricks_host_creds("databricks")
    return {
        "Authorization": f"Bearer {credentials.token}",
        "Content-Type": "application/json",
    }

def create_endpoint(name, model_name, model_version, pt):
    data = {
        "name": name,
        "config": {
            "served_entities": [
                {
                    "entity_name": model_name,
                    "entity_version": model_version,
                    "min_provisioned_throughput": pt,
                    "max_provisioned_throughput": pt,
                }
            ]
        },
    }
    response = requests.post(
        url=f"{host}/api/2.0/serving-endpoints", json=data, headers=getHeaders()
    )
    print(json.dumps(response.json(), indent=4))

def monitor_endpoint(name):
    # エンドポイントが準備完了になるまで監視します。これには数分かかる場合があります。
    start_time = time.time()
    timeout = 2 * 60 * 60
    while True:
        response = requests.get(
            url=f"{host}/api/2.0/serving-endpoints/{name}", headers=getHeaders()
        )
        if response.status_code == 404:
            raise Exception(f"エンドポイント {name} が見つかりません")
        elif response.status_code != 200:
            print(f"リクエストがステータス {response.status_code} で失敗しました, {response.text}")
        else:
            serving_endpoint = response.json()
            state = serving_endpoint["state"]
            if state["ready"] == "READY" and state["config_update"] == "NOT_UPDATING":
                print(f"エンドポイント {name} は準備完了です")
                break
            else:
                print(f"エンドポイント {name} は準備ができていません, 状態: {state}")
        time.sleep(30)
        if time.time() - start_time > timeout:
            raise Exception(f"エンドポイント {name} の作成がタイムアウトしました")

バッチ推論エンドポイントのサイズをPTUで設定します。PTUを増やすと推論スループットが向上します。GPU容量の制限により、エンドポイントの作成時にエラーが発生した場合は、PTUの数を減らすことを検討してください。

私は上述のエラーに遭遇したので1にしています。

num_ptus = 1
# エンドポイントを構成する
if not model_uc_version:
    model_versions = requests.get(
        url=f"{host}/api/2.1/unity-catalog/models/{model_uc_path}/versions",
        headers=getHeaders(),
    ).json()
    model_uc_version = max(version['version'] for version in model_versions['model_versions'])

if not endpoint_name:
    endpoint_name = str(uuid.uuid4())

optimizable_info = requests.get(
    url=f"{host}/api/2.0/serving-endpoints/get-model-optimization-info/{model_uc_path}/{model_uc_version}",
    headers=getHeaders(),
).json()
if "optimizable" not in optimizable_info or not optimizable_info["optimizable"]:
    raise ValueError(
        f"モデル {model_uc_path} バージョン {model_uc_version} は存在しないか、プロビジョニングスループットの対象ではありません"
    )
chunk_size = optimizable_info["throughput_chunk_size"]

# エンドポイントを作成する
print(f"プロビジョニングスループットエンドポイントを作成中:")
print(f"エンドポイント: {endpoint_name}, モデル: {model_uc_path}, バージョン: {model_uc_version}")
print(f"同時実行数: {num_ptus*chunk_size} ({num_ptus} PTU)")

create_endpoint(endpoint_name, model_uc_path, model_uc_version, num_ptus*chunk_size)
monitor_endpoint(endpoint_name)
プロビジョニングスループットエンドポイントを作成中:
エンドポイント: taka_batch_endpoint, モデル: system.ai.meta_llama_v3_1_70b_instruct, バージョン: 3
同時実行数: 6000 (1 PTU)
{
    "name": "taka_batch_endpoint",
    "creator": "takaaki.yayoi@databricks.com",
    "creation_timestamp": 1732152126000,
    "last_updated_timestamp": 1732152126000,
    "state": {
        "ready": "NOT_READY",
        "config_update": "IN_PROGRESS",
        "suspend": "NOT_SUSPENDED"
    },
    "pending_config": {
        "served_entities": [
            {
                "name": "meta_llama_v3_1_70b_instruct-3",
                "entity_name": "system.ai.meta_llama_v3_1_70b_instruct",
                "entity_version": "3",
                "min_provisioned_throughput": 6000,
                "max_provisioned_throughput": 6000,
                "workload_size": "Small",
                "workload_type": "GPU_XLARGE_8",
                "min_dbus": 424.286,
                "max_dbus": 424.286,
                "state": {
                    "deployment": "DEPLOYMENT_CREATING",
                    "deployment_state_message": "Creating resources for served entity."
                },
                "creator": "takaaki.yayoi@databricks.com",
                "creation_timestamp": 1732152126000
            }
        ],
        "served_models": [
            {
                "name": "meta_llama_v3_1_70b_instruct-3",
                "min_provisioned_throughput": 6000,
                "max_provisioned_throughput": 6000,
                "workload_size": "Small",
                "workload_type": "GPU_XLARGE_8",
                "model_name": "system.ai.meta_llama_v3_1_70b_instruct",
                "model_version": "3",
                "min_dbus": 424.286,
                "max_dbus": 424.286,
                "state": {
                    "deployment": "DEPLOYMENT_CREATING",
                    "deployment_state_message": "Creating resources for served entity."
                },
                "creator": "takaaki.yayoi@databricks.com",
                "creation_timestamp": 1732152126000
            }
        ],
        "traffic_config": {
            "routes": [
                {
                    "served_model_name": "meta_llama_v3_1_70b_instruct-3",
                    "traffic_percentage": 100,
                    "served_entity_name": "meta_llama_v3_1_70b_instruct-3"
                }
            ]
        },
        "config_version": 1,
        "start_time": 1732152126000
    },
    "id": "9731288b510d4d32a01445b4b3aed4c6",
    "permission_level": "CAN_MANAGE",
    "route_optimized": false,
    "creator_display_name": "Takaaki Yayoi",
    "creator_kind": "User"
}
エンドポイント taka_batch_endpoint は準備ができていません, 状態: {'ready': 'NOT_READY', 'config_update': 'IN_PROGRESS', 'suspend': 'NOT_SUSPENDED'}
エンドポイント taka_batch_endpoint は準備ができていません, 状態: {'ready': 'NOT_READY', 'config_update': 'IN_PROGRESS', 'suspend': 'NOT_SUSPENDED'}
エンドポイント taka_batch_endpoint は準備ができていません, 状態: {'ready': 'NOT_READY', 'config_update': 'IN_PROGRESS', 'suspend': 'NOT_SUSPENDED'}
エンドポイント taka_batch_endpoint は準備ができていません, 状態: {'ready': 'NOT_READY', 'config_update': 'IN_PROGRESS', 'suspend': 'NOT_SUSPENDED'}
エンドポイント taka_batch_endpoint は準備ができていません, 状態: {'ready': 'NOT_READY', 'config_update': 'IN_PROGRESS', 'suspend': 'NOT_SUSPENDED'}
エンドポイント taka_batch_endpoint は準備完了です

少し待つと、モデルサービングエンドポイントが稼働します。

uploading...0

パート2: バッチ推論に ai_query を使用する

以下は、プロビジョニングスループットエンドポイントをクエリして、データでバッチ推論を実行するための ai_query の設定です。

# ソーステーブルから読み込む
df = spark.table(input_table_name)
if input_num_rows:
    df = df.limit(input_num_rows)
input_num_rows = df.count()
print(f"テーブル {input_table_name} には {input_num_rows} 行があります")

# (オプション) 入力データを前処理する
pass

最初は以下のパラメーターで実行します:

  • endpoint_name: taka_batch_endpoint
  • input_table_name: samples.nyctaxi.trips
  • input_column_name: pickup_zip
  • output_table_name: takaakiyayoi_catalog.batch.batch_result
  • output_column_name: state
  • prompt: 提供されたZIPコードに対応する米国の州名を教えてください。
# バッチ推論に ai_query を使用する
model_params_str = "named_struct(" + ", ".join([f"'{k}', {v}" for k, v in request_params.items()]) + ")"
ai_query_expr = f"""ai_query('{endpoint_name}', CONCAT('{prompt}', {input_column_name}), modelParameters => {model_params_str}) as {output_column_name}"""
df_out = df.selectExpr(ai_query_expr)

# 出力テーブルに書き込む
df_out.write.mode("overwrite").option("mergeSchema", "true").saveAsTable(output_table_name)

郵便番号を判定していますが、元データがそもそもニューヨークのものしかないので、あまり面白くないです。

Screenshot 2024-11-21 at 11.14.00.png

ソーステーブルを変更して、パラメーターを変更します。

  • endpoint_name: taka_batch_endpoint
  • input_table_name: takaakiyayoi_catalog.qiita_2023.taka_qiita_2023
  • input_column_name: body
  • output_table_name: takaakiyayoi_catalog.batch.summary_result
  • output_column_name: summary
  • prompt: 提供されたテキストを100字程度の要約にしてください

ソーステーブルはこのようになっています。

Screenshot 2024-11-21 at 11.17.26.png

パラメーターの読み込みと、ai_queryの呼び出し部分のみを再度実行します。

# ウィジェットの設定
dbutils.widgets.text("model_uc_path", "system.ai.meta_llama_v3_1_70b_instruct", "Model UC Path")
dbutils.widgets.text("input_table_name", "samples.nyctaxi.trips", "Input Table Name")
dbutils.widgets.text("input_column_name", "pickup_zip", "Input Column Name")
dbutils.widgets.text("output_table_name", "main.default.batch_nyctaxi_trips", "Output Table Name")
dbutils.widgets.text("output_column_name", "state_name", "Output Column Name")
dbutils.widgets.text("prompt", "Could you tell me the state name of the zip code provided? zip code: ", "Prompt")
dbutils.widgets.text("request_params", '{"max_tokens": 1000, "temperature": 0}', "Request Parmaeters")
dbutils.widgets.text("endpoint_name", "", "(optional) Endpoint Name")
dbutils.widgets.text("model_uc_version", "", "(optional) Model Version")
dbutils.widgets.text("input_num_rows", "", "(optional) Input Number of Rows")

# パラメータの取得
model_uc_path = dbutils.widgets.get("model_uc_path")
input_table_name = dbutils.widgets.get("input_table_name")
input_column_name = dbutils.widgets.get("input_column_name")
output_table_name = dbutils.widgets.get("output_table_name")
output_column_name = dbutils.widgets.get("output_column_name")
prompt = dbutils.widgets.get("prompt")
request_params = dbutils.widgets.get("request_params")
endpoint_name = dbutils.widgets.get("endpoint_name")
model_uc_version = dbutils.widgets.get("model_uc_version")
input_num_rows = dbutils.widgets.get("input_num_rows")

# パラメータの検証
assert model_uc_path.count(".") == 2, "model_uc_pathはcatalog.schema.model_nameの形式である必要があります"
assert input_table_name.count(".") == 2, "input_table_nameはcatalog.schema.tableの形式である必要があります"
assert input_column_name, "input_column_nameは空にできません"
assert output_table_name.count(".") == 2, "output_table_nameはcatalog.schema.tableの形式である必要があります"
assert output_column_name, "output_column_nameは空にできません"
try:
    request_params = json.loads(dbutils.widgets.get("request_params"))
except:
    assert False, "request_paramsは有効なjsonである必要があります"
if model_uc_version:
    assert model_uc_version.isdigit(), "model_uc_versionは数字である必要があります"
    model_uc_version = int(model_uc_version)
    assert model_uc_version >= 1, "model_uc_versionは正の整数である必要があります"
if input_num_rows:
    assert input_num_rows.isdigit(), "input_num_rowsは数字である必要があります"
    input_num_rows = int(input_num_rows)
    assert input_num_rows >= 1, "input_num_rowsは正の整数である必要があります"
# ソーステーブルから読み込む
df = spark.table(input_table_name)
if input_num_rows:
    df = df.limit(input_num_rows)
input_num_rows = df.count()
print(f"テーブル {input_table_name} には {input_num_rows} 行があります")

# (オプション) 入力データを前処理する
pass
# バッチ推論に ai_query を使用する
model_params_str = "named_struct(" + ", ".join([f"'{k}', {v}" for k, v in request_params.items()]) + ")"
ai_query_expr = f"""ai_query('{endpoint_name}', CONCAT('{prompt}', {input_column_name}), modelParameters => {model_params_str}) as {output_column_name}"""
df_out = df.selectExpr(ai_query_expr)

# 出力テーブルに書き込む
df_out.write.mode("overwrite").option("mergeSchema", "true").saveAsTable(output_table_name)

作成されたテーブルを確認します。期待した通りに要約が作成されています。

Screenshot 2024-11-21 at 11.18.17.png

ソーステーブルと成し遂げたいことから、モデルの選択、プロンプトの定義をすることで、大量のテキストデータに様々な処理(要約、感情分析、翻訳、分類、情報抽出など)を迅速かつ容易に行う事ができます!是非ご活用ください。

はじめてのDatabricks

はじめてのDatabricks

Databricks無料トライアル

Databricks無料トライアル

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?