2
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

DatabricksのLLM最適化モデルサービングにLlama2モデルをデプロイしてみる

Posted at

こちらのサンプルノートブックをウォークスルーします。

最適化Llama2サービングのサンプル

最適化LLMサービングによって、最先端のOSSのLLMを活用し、GPUによって改善されたレーテンシーとスループットを提供する自動最適化機能を持つDatabricksモデルサービングにデプロイすることができます。現時点では、Llama2とMosaic MPTクラスのモデルの最適化をサポートしています。

このサンプルでは以下をウォークスルーします:

  1. Hugging Face transformersからモデルをダウンロード
  2. DatabricksのUnity Catalogあるいはワークスペースレジストリに、最適化されたサービングサポートフォーマットでモデルを記録
  3. モデルの最適化サービングの有効化

前提条件

  • ノートブックを十分なメモリーを持つクラスターにアタッチ
  • MLflowバージョン2.7.0以降がインストールされていること
  • 7B以上のサイズのモデルを取り扱う際には特にUCでのモデル管理を行うこと

ステップ1: 最適化LLMサービングのためにモデルを記録

# 必要な依存関係を更新/インストール
!pip install -U mlflow
!pip install -U transformers
!pip install -U accelerate
!pip install -U flash_attn # DBR 13.3 MLでエラーになったので追加
dbutils.library.restartPython()

Hugging faceのトークンを入力します。詳細はこちら

import huggingface_hub
# 既にhugging faceにログインしている際にはスキップ
huggingface_hub.login()
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-chat-hf", 
    torch_dtype=torch.bfloat16
)
tokenizer = AutoTokenizer.from_pretrained(
    "meta-llama/Llama-2-7b-chat-hf"
)
import mlflow
from mlflow.models.signature import ModelSignature
from mlflow.types.schema import ColSpec, Schema
import numpy as np

# モデルの入出力スキーマの定義
input_schema = Schema([
    ColSpec("string", "prompt"),
    ColSpec("double", "temperature", optional=True),
    ColSpec("integer", "max_tokens", optional=True),
    ColSpec("string", "stop", optional=True),
    ColSpec("integer", "candidate_count", optional=True)
])

output_schema = Schema([
    ColSpec('string', 'predictions')
])

signature = ModelSignature(inputs=input_schema, outputs=output_schema)

# 入力サンプルの定義
input_example = {
    "prompt": np.array([
        "Below is an instruction that describes a task. "
        "Write a response that appropriately completes the request.\n\n"
        "### Instruction:\n"
        "What is Apache Spark?\n\n"
        "### Response:\n"
    ])
}

最適化サービングを有効化するには、モデルを記録する際に以下のようにmlflow.transformers.log_modelを呼び出して追加のメタデータのディクショナリーを含めます:

metadata = {"task": "llm/v1/completions"}

ここで、モデルサービングエンドポイントで使用されるAPIシグネチャを指定します。

import mlflow

# UCでのモデル管理を行わない場合には以下の行をコメントアウトします
# 3レベルの名前空間ではなく、シンプルにモデル名を指定します
mlflow.set_registry_uri('databricks-uc')
CATALOG = "quickstart_catalog_taka"
SCHEMA = "llm-catalog"
registered_model_name = f"{CATALOG}.{SCHEMA}.llama2-7b"

# 新規MLflowランをスタート
with mlflow.start_run():
    components = {
        "model": model,
        "tokenizer": tokenizer,
    }
    mlflow.transformers.log_model(
        transformers_model=components,
        task = "text-generation",
        artifact_path="model",
        registered_model_name=registered_model_name,
        signature=signature,
        input_example=input_example,
        metadata={"task": "llm/v1/completions"}
    )

モデルが記録されます。4分弱かかります。
Screenshot 2023-10-08 at 17.12.08.png

カタログエクスプローラからも確認できます。
Screenshot 2023-10-08 at 17.11.42.png

ステップ2: モデルサービングGPUエンドポイントの設定および作成

エンドポイント名を変更するには以下のセルを修正します。エンドポイントAPIを呼び出した後は、記録されたLlama2モデルは自動的に最適化LLMサービングにデプロイされます。workload_typeは環境で利用できるスペックのGPUを指定する必要があります。

# MLflowエンドポイント名の設定
endpoint_name = "taka-llama2-7b"

# 登録されたMLflowモデル名
model_name = registered_model_name

# MLflowモデルの最新バージョンの取得
model_version = 1

# 計算資源タイプの指定 (CPU, GPU_SMALL, GPU_MEDIUM, etc.)
workload_type = "GPU_MEDIUM"

# 計算資源のスケールアウトサイズの指定 (Small, Medium, Large, etc.)
workload_size = "Small" 

# ゼロへのスケールの指定 (CPUエンドポイントでのみサポート)
scale_to_zero = False 

# 現在のノートブックコンテキストからAPIエンドポイントとトークンを取得
API_ROOT = dbutils.notebook.entry_point.getDbutils().notebook().getContext().apiUrl().get() 
API_TOKEN = dbutils.notebook.entry_point.getDbutils().notebook().getContext().apiToken().get()
import requests
import json

data = {
    "name": endpoint_name,
    "config": {
        "served_models": [
            {
                "model_name": model_name,
                "model_version": model_version,
                "workload_size": workload_size,
                "scale_to_zero_enabled": scale_to_zero,
                "workload_type": workload_type,
            }
        ]
    },
}

headers = {"Context-Type": "text/json", "Authorization": f"Bearer {API_TOKEN}"}

response = requests.post(url=f"{API_ROOT}/api/2.0/serving-endpoints", json=data, headers=headers)

print(json.dumps(response.json(), indent=4))

エンドポイントの参照

エンドポイントのより詳細な情報を参照するには、左のナビゲーションバーのServingにアクセスし、エンドポイント名で検索します。
Screenshot 2023-10-09 at 15.00.59.png

ステップ3: エンドポイントへのクエリー

エンドポイントの準備ができたら、APIリクエストを行うことでクエリーできるようになります。モデルのサイズや複雑性に依存しますが、準備ができるまでに30分以上を要します。

GPU Mediumの場合、30分程度でReadyになりました。

data = {
    "inputs": {
        "prompt": [
            "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\nWhat is Apache Spark?\n\n### Response:\n"
        ]
    },
    "params": {
        "max_tokens": 100, 
        "temperature": 0.0
    }
}

headers = {
    "Context-Type": "text/json",
    "Authorization": f"Bearer {API_TOKEN}"
}

response = requests.post(
    url=f"{API_ROOT}/serving-endpoints/{endpoint_name}/invocations",
    json=data,
    headers=headers
)

print(json.dumps(response.json(), indent=4))

クエリーできました!

JSON
{
    "predictions": [
        {
            "candidates": [
                {
                    "text": "Apache Spark is an open-source data processing engine that can handle large-scale data processing tasks. It was developed at the University of California, Berkeley and is now maintained by Apache Software Foundation. Spark provides high-level APIs in Java, Python, Scala, and R, and supports a wide range of data sources, including Hadoop Distributed File System (HDFS), Cassandra, HBase, and Hive. Spark's core features include in-memory computing, fault",
                    "metadata": {
                        "finish_reason": "length"
                    }
                }
            ],
            "metadata": {
                "input_tokens": 41,
                "output_tokens": 100,
                "total_tokens": 141
            }
        }
    ]
}

Databricksクイックスタートガイド

Databricksクイックスタートガイド

Databricks無料トライアル

Databricks無料トライアル

2
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?