3
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

Databricksにおける要約生成モデルのファインチューニング

Last updated at Posted at 2024-10-10

ようやくファインチューニングがわかってきたような。

ウォークスルーしたノートブックはこちらです。

Databricksにおける要約生成モデルのファインチューニング

要件

  • Mosaic AI Model Training APIを使用するため、サポートしているリージョンである必要があります。
  • ワークスペースでUnity Catalogが有効化されている必要があります。
  • 15.4 LTS MLシングルノードクラスターで動作確認しています。GPUクラスターを使用する必要はありません。

ディレクトリ構成

以下のNotebooksを番号順に実行していきます。

  • Includes
    • Config: カタログ、データベース名などを設定します。実行する環境に応じて変更してください。
    • Demo-Create-Tables: デモ用のデータを格納するテーブルを作成します。
  • Notebooks
    • 01_Create_Instruction_Dataset: 1: 指示データセットの作成
    • 02_Instruction_Fine_Tuning: 2: インストラクションのファインチューニング
    • 03_Create_a_Provisioned_Throughput_Serving_Endpoint: 3: プロビジョニングされたスループット サービング エンドポイントの作成
    • 04_Query_Endpoint_and_Batch_Inference: 4: エンドポイントのクエリーとバッチ推論
    • 05_Offline_Evaluation: 5: オフライン評価
    • blog.jsonl: デモデータを格納するjsonl

環境設定

必要に応じてIncludes/Configを修正します。

# 事前にカタログとスキーマを作成し、以下で設定してください
CATALOG = "takaakiyayoi_catalog"
USER_SCHEMA = "finetune_summarizer"

# モデルサービングエンドポイント名。必要に応じて変更ください。
MODEL_ENDPOINT_NAME = "ft-summarize-endpoint"

# ファインチューニングするモデル
BASE_MODEL =  "meta-llama/Meta-Llama-3-8B-Instruct" # ファインチューニングする基盤モデル

# ファインチューニングしたモデルの名前
UC_MODEL_NAME = "blog_summarization_llm"

# 推論テーブル名のプレフィクス
INFERENCE_TABLE_PREFIX = "ift_request_response"
print(f"このデモでは {CATALOG}.{USER_SCHEMA} を使用します。データを書き込むために必要なカタログやスキーマを設定するには、Configノートブックで `CATALOG` や `SHARED_SCHEMA` 変数を指定してください。")
# UC指示データセット
INPUT_TABLE = "blogs_bronze"
OUTPUT_VOLUME = "blog_ift_data"
OUTPUT_TRAIN_TABLE = "blog_title_generation_train_ift_data"
OUTPUT_EVAL_TABLE = "blog_title_generation_eval_ift_data"

# デモで使用するブログデータのjsonlファイル名
RAW_JSONL_NAME = "blog.jsonl"
print(f"このデモでは、ボリューム {OUTPUT_VOLUME} に配置される {RAW_JSONL_NAME} から記事データを読み込み、テーブル {INPUT_TABLE} を作成します。このテーブルからトレーニングデータ {OUTPUT_TRAIN_TABLE} と評価用データ {OUTPUT_EVAL_TABLE} を作成します。")
print(f"このデモでは、基盤モデル {BASE_MODEL} をファインチューニングし、モデル {UC_MODEL_NAME} を作成します。ファインチューンしたモデルは、モデルサービングエンドポイント {MODEL_ENDPOINT_NAME} にデプロイされます。")

1: 指示データセットの作成

Notebooks/01_Create_Instruction_Datasetを実行します。

このノートブックでは、指示のファインチューニングに使用する指示データセットを作成します。解決しようとするユースケースは、過去のDatabricksブログ投稿のスタイルでブログ投稿のタイトルを生成することです。このため、過去のブログ投稿とそのタイトルを指示データセットに準備します。

手順:

  1. 生のDatabricks記事データを読み込む
  2. Unity Catalog Volumeにデータを書き込み、そこからテーブルを作成する。
  3. 空の行をフィルタリングし、データの重複を排除する
  4. ブログテキストをプロンプトに構造化する
  5. promptresponseの列を持つテーブルを作成する。ここで、response列はブログ投稿のタイトルです

注意
ここではデモデータに要約データがないためタイトルを用いていますが、実際には要約を準備するようにしてください。

インポート

環境を設定して、必要な変数とデータセットを読み込みます。

%run ../Includes/Config
このデモでは takaakiyayoi_catalog.finetune_summarizer を使用します。データを書き込むために必要なカタログやスキーマを設定するには、Configノートブックで `CATALOG` や `SHARED_SCHEMA` 変数を指定してください。
このデモでは、モデルサービングエンドポイント ft-summarize-endpoint を作成して使用します。

このデモでは、ボリューム blog_ift_data に配置される blog.jsonl から記事データを読み込み、テーブル blogs_bronze を作成します。このテーブルからトレーニングデータ blog_title_generation_train_ift_data と評価用データ blog_title_generation_eval_ift_data を作成します。
%run ../Includes/Demo-Create-Tables
スキーマの初期化...
スキーマ takaakiyayoi_catalog.finetune_summarizer を作成しました。

ボリューム blog_ift_data を作成しました

テーブル takaakiyayoi_catalog.finetune_summarizer.blogs_bronze を作成しました

データの初期化を完了しました。
from typing import Iterator, List
import pandas as pd
from pyspark.sql import DataFrame
import pyspark.sql.functions as F
from pyspark.sql.types import StringType

データの読み込みとフィルタリング

指定されたUnity Catalogテーブルから生のブログデータを読み込み、'text'または'title'列にnullまたは空の値がある行をフィルタリングします。

def load_and_filter(table_name: str, response_col: str = "title") -> DataFrame:
    """
    テーブルをロードし、'text'または`response_col`でnullまたは空の文字列をフィルタリングします。

    引数:
        table_name: ロードするテーブルの名前。
        response_col: nullまたは空の文字列をフィルタリングする列。

    戻り値:
        フィルタリングされたDataFrame。
    """
    print(f"テーブルをロード中: {table_name}")
    df = spark.table(table_name)
    original_count = df.count()
    print(f"行数: {original_count}")

    print(f"\n'text'または'{response_col}'でnullまたは空の文字列をフィルタリング")
    filtered_df = filter_null_or_empty(df, ["text", response_col])
    filtered_count = filtered_df.count()
    print(f"削除された行数: {original_count - filtered_count}")
    print(f"フィルタリングされた数: {filtered_count}")

    return filtered_df
  

def filter_null_or_empty(df: DataFrame, columns: List[str]) -> DataFrame:
    """
    指定された列のいずれかがnullまたは空である行をフィルタリングします。

    引数:
        df: フィルタリングするDataFrame。
        columns: nullまたは空の値をチェックする列のリスト。

    戻り値:
        フィルタリングされたDataFrame。
    """
    print("指定された列のいずれかがnullまたは空である行をフィルタリング中...")
    for col in columns:
        print(f"\t列: {col}")
        df = df.filter((F.col(col).isNotNull()) & (F.col(col) != ""))
    return df
filtered_df = load_and_filter(table_name=f"{CATALOG}.{USER_SCHEMA}.{INPUT_TABLE}")
テーブルをロード中: takaakiyayoi_catalog.finetune_summarizer.blogs_bronze
行数: 100

'text'または'title'でnullまたは空の文字列をフィルタリング
指定された列のいずれかがnullまたは空である行をフィルタリング中...
	列: text
	列: title
削除された行数: 0
フィルタリングされた数: 100

重複排除

textおよびtitle列に基づいてフィルタリングされたデータセットの重複を排除し、ユニークなブログ投稿を保証します。

filtered_deduped_df = filtered_df.drop_duplicates(subset=["text", "title"])
filtered_deduped_count = filtered_deduped_df.count()
print(f"重複排除後の件数: {filtered_deduped_count}")
重複排除後の件数: 100

プロンプト列の追加

PromptTemplate クラスが以下を返すように空欄を埋めます:

  • 指示
  • ブログキー
  • ブログテキスト
  • 応答
class PromptTemplate:
    """クラスを使用して、インストラクションデータセットの生成のためのプロンプトテンプレートを表すクラス。"""

    def __init__(self, instruction: str, blog_key: str, response_key: str) -> None:
        self.instruction = instruction
        self.blog_key = blog_key
        self.response_key = response_key

    def generate_prompt(self, blog_text: str) -> str:
        """
        テンプレートと与えられたブログテキストを使用してプロンプトを生成します。

        Args:
            blog_text: ブログのテキスト。

        Returns:
            プロンプトテンプレート。
        """
        return f"""{self.instruction}
{self.blog_key}
{blog_text}
{self.response_key}
"""

プロンプトテンプレートの作成

  • instruction には、提供されたブログ記事に基づいてタイトルを生成するためのLLMへのポインタを含める必要があります
blog_title_generation_template = PromptTemplate(
    instruction="以下はDatabricksに関するブログ記事のテキストです。指定されたブログ記事の要約を作成してください。",
    blog_key="### ブログ記事:",
    response_key="### 要約:"
)  
def add_instruction_prompt_column(df: DataFrame, prompt_template: PromptTemplate) -> DataFrame:
    """
    DataFrameに指定されたテンプレートを使用して 'prompt' 列を追加します。

    Args:
        df: 入力のDataFrame。
        prompt_template: プロンプトを生成するために使用するテンプレート。

    Returns:
        'prompt' 列を持つDataFrame。
    """
    @F.pandas_udf(StringType())
    def generate_prompt(batch_iter: Iterator[pd.Series]) -> Iterator[pd.Series]:
        for texts in batch_iter:
            prompts = texts.apply(prompt_template.generate_prompt)
            yield prompts

    return df.withColumn("prompt", generate_prompt(df["text"]))
# プロンプト列を追加
instruction_df = add_instruction_prompt_column(filtered_deduped_df, blog_title_generation_template)

# プロンプト列とタイトル列を選択し、タイトル列をレスポンス列に名前変更
instruction_df = instruction_df.selectExpr("prompt", "title as response")
display(instruction_df)

プロンプトとレスポンスを保持するデータが準備できました。
Screenshot 2024-10-10 at 17.30.38.png

プロンプトの例

print(instruction_df.select("prompt").limit(1).collect()[0]["prompt"])
以下はDatabricksに関するブログ記事のテキストです。指定されたブログ記事の要約を作成してください。
### ブログ記事:
前見た時には、Delta Sharingで共有できるのはDeltaテーブルだけだったのですが、今日になってファイルやノートブックも共有できることに気づきました(遅い)。

https://docs.databricks.com/ja/data-sharing/index.html

いずれもDatabricks間での共有ですが、クラウドやアカウントを越えて共有できるのでユースケースは結構あると思います。

# 共有の設定

共有側(プロバイダー)をワークスペースA、利用側(受信者)をワークスペースBとします。

## 共有の作成:ワークスペースAでの作業

カタログエクスプローラで共有を作成します。Delta Sharingの**自分が共有**にアクセスし、**データを共有**をクリックして新規共有を作成します。
![Screenshot 2023-12-27 at 18.32.27.png](https://qiita-image-store.s3.ap-northeast-1.amazonaws.com/0/1168882/1c51d2d9-5dca-0912-5b86-c3aecb8473b7.png)

**アセットを追加**をクリックします。左のペインからテーブルやボリュームを選択して保存します。
![Screenshot 2023-12-27 at 18.34.15.png](https://qiita-image-store.s3.ap-northeast-1.amazonaws.com/0/1168882/c21a416d-adcf-cd65-ceb7-e21791254abb.png)

これでテーブルとボリュームが共有に追加されました。
![Screenshot 2023-12-27 at 18.35.00.png](https://qiita-image-store.s3.ap-northeast-1.amazonaws.com/0/1168882/bcef94d3-6280-4381-6ed3-4c4a318679dc.png)

さらに右上の**アセットを管理**をクリックし、**ノートブックファイルを追加します**を選択します。
![Screenshot 2023-12-27 at 18.35.24.png](https://qiita-image-store.s3.ap-northeast-1.amazonaws.com/0/1168882/9b032c48-4776-f84e-b18c-0caabd3dfdd5.png)

ノートブックを選択して保存します。
![Screenshot 2023-12-27 at 18.36.32.png](https://qiita-image-store.s3.ap-northeast-1.amazonaws.com/0/1168882/2641f003-1174-d6f8-450f-b8520e0e2022.png)
![Screenshot 2023-12-27 at 18.36.46.png](https://qiita-image-store.s3.ap-northeast-1.amazonaws.com/0/1168882/227ee433-a4ec-9b2a-07a5-1f36c3ee772a.png)

## 受信者の作成:ワークスペースBでの作業

別のクラウドやアカウントにあるDatabricksと共有する際には、受信側の[共有識別子](https://docs.databricks.com/ja/data-sharing/recipient.html#get-access-in-the-databricks-to-databricks-model)が必要となります。共有識別子は`<cloud>:<region>:<uuid>`の形式となっており、以前は特定するのが面倒でしたが今では受信側のカタログエクスプローラで容易にコピーできます。

ワークスペースBのカタログエクスプローラにアクセスし、**Delta Sharing > 自分と共有**にアクセスします。画面上部に**共有識別子**のコピーボタンがあるのでこれをクリックしてコピーします。
![Screenshot 2023-12-27 at 18.40.32.png](https://qiita-image-store.s3.ap-northeast-1.amazonaws.com/0/1168882/e2419d62-f2ca-d95e-89d9-8bb9c291e389.png)

## 受信者の作成:ワークスペースAでの作業

提供側の**Delta Sharing > 自分が共有**にアクセスし、上で作成した共有オブジェクトにアクセスし、**受信者**タブをクリックします。**受信者を追加**をクリックします。
![Screenshot 2023-12-27 at 18.43.30.png](https://qiita-image-store.s3.ap-northeast-1.amazonaws.com/0/1168882/60c98e76-af9e-74ae-3eb2-d1d36910931c.png)

**+ 新規受信者を作成** をクリックします。
![Screenshot 2023-12-27 at 18.43.58.png](https://qiita-image-store.s3.ap-northeast-1.amazonaws.com/0/1168882/e9bd93e8-98a7-8d25-1f75-4c4f9b004b3d.png)

受信者名を入力し、上でコピーした共有識別子を貼り付けて、**受信者を作成および追加**をクリックします。
![Screenshot 2023-12-27 at 18.44.56.png](https://qiita-image-store.s3.ap-northeast-1.amazonaws.com/0/1168882/a2ea7e0c-d943-b9b9-aefc-6d2752b2bd4c.png)

**追加**をクリックします。
![Screenshot 2023-12-27 at 18.46.19.png](https://qiita-image-store.s3.ap-northeast-1.amazonaws.com/0/1168882/0caa8dee-4800-bc1b-7d1b-6a45ddbee653.png)

これで提供者側の作業は完了です。
![Screenshot 2023-12-27 at 18.46.41.png](https://qiita-image-store.s3.ap-northeast-1.amazonaws.com/0/1168882/e1859d29-6ea4-7767-4298-027dbe673843.png)

# 受信者側の作業

以降はワークスペースBでの作業となります。

## ファイル(ボリューム)へのアクセス

ワークスペースBのカタログエクスプローラにアクセスし、**Delta Sharing > 自分と共有**にアクセスします。プロバイダーに提供者一覧が表示されます。提供者名は提供側に確認してください。
![Screenshot 2023-12-27 at 18.50.08.png](https://qiita-image-store.s3.ap-northeast-1.amazonaws.com/0/1168882/7a90f612-e3ba-0579-3f13-69e98733717d.png)

提供側の共有が表示されます。カタログを作成をクリックします。
![Screenshot 2023-12-27 at 18.51.56.png](https://qiita-image-store.s3.ap-northeast-1.amazonaws.com/0/1168882/c01f3663-54f1-9d74-d2a0-3efeee6dcd30.png)

受信者側におけるカタログ名を入力し、作成をクリックします。
![Screenshot 2023-12-27 at 18.52.19.png](https://qiita-image-store.s3.ap-northeast-1.amazonaws.com/0/1168882/27f68922-e6d1-c9d1-3306-3c2bb2b538f4.png)

これで共有されたテーブルやボリュームにアクセスできるようになります。
![Screenshot 2023-12-27 at 18.53.18.png](https://qiita-image-store.s3.ap-northeast-1.amazonaws.com/0/1168882/41f68da6-769c-85dd-b89e-b344605007d2.png)
![Screenshot 2023-12-27 at 18.53.09.png](https://qiita-image-store.s3.ap-northeast-1.amazonaws.com/0/1168882/3ec94dfb-35d7-f002-d8bd-3742a33a9b36.png)

## ノートブックへのアクセス

上で作成したカタログにアクセスし、**その他のアセット**をクリックすると共有されているノートブックが表示されます。
![Screenshot 2023-12-27 at 18.55.17.png](https://qiita-image-store.s3.ap-northeast-1.amazonaws.com/0/1168882/c4d0d6a2-9b41-deb3-4144-eacb60322531.png)

クリックすることで内容を確認することができ、右上の**クローン作成**で任意の場所にコピーすることができます。
![Screenshot 2023-12-27 at 18.56.06.png](https://qiita-image-store.s3.ap-northeast-1.amazonaws.com/0/1168882/aef7adfb-02ec-d097-9ca4-87910cb8bc44.png)

テーブルやファイルを共有する際には、サンプルとなるノートブックがあった方がコラボレーションが円滑になるかと思います。是非ご活用ください!

### Databricksクイックスタートガイド

[Databricksクイックスタートガイド](https://www.amazon.co.jp/dp/B09V1YXFVQ/)


### Databricks無料トライアル

[Databricks無料トライアル](https://databricks.com/jp/try-databricks)

### 要約:

データをランダムにトレーニングデータとテストデータに分割します

train_df, eval_df = instruction_df.randomSplit([0.9,0.1], seed=42)
print(train_df.count(), eval_df.count())
95 5

トレーニングデータと評価データを別々のテーブルに書き込む

train_data_path = f"{CATALOG}.{USER_SCHEMA}.{OUTPUT_TRAIN_TABLE}"
eval_data_path = f"{CATALOG}.{USER_SCHEMA}.{OUTPUT_EVAL_TABLE}"

train_df.write.mode("overwrite").saveAsTable(train_data_path)
eval_df.write.mode("overwrite").saveAsTable(eval_data_path)

以下のコマンドを実行して表示されるカタログエクスプローラでテーブルを確認してみてください。依存関係タブでテーブル間のリネージを確認することもできます。

displayHTML(f"<a href='/explore/data/{CATALOG}/{USER_SCHEMA}/'>カタログエクスプローラ</a>")

Screenshot 2024-10-10 at 17.32.46.png
Screenshot 2024-10-10 at 17.33.01.png

2: インストラクションのファインチューニング

Notebooks/02_Instruction_Fine_Tuningを実行します。

このノートブックでは、事前訓練された言語モデルに対してインストラクションのファインチューニング(IFT)を実行する方法を示します。

目標:

  1. 指定されたハイパーパラメータで単一のIFT実行をトリガーする
%pip install databricks-genai==1.0.1
dbutils.library.restartPython()

環境をセットアップして、必要な変数とデータセットをロードします。

%run ../Includes/Config

ファインチューニングのランの作成

from databricks.model_training import foundation_model as fm

register_to = f"{CATALOG}.{USER_SCHEMA}.{UC_MODEL_NAME}" # ファインチューニングしたモデルの登録先
training_duration = "3ep" # エポック数
learning_rate = "3e-06" # 学習率
data_prep_cluster_id = spark.conf.get("spark.databricks.clusterUsageTags.clusterId") # データ準備に使用するクラスターID

run = fm.create(
  model=BASE_MODEL,
  train_data_path=f"{CATALOG}.{USER_SCHEMA}.{OUTPUT_TRAIN_TABLE}",
  eval_data_path=f"{CATALOG}.{USER_SCHEMA}.{OUTPUT_EVAL_TABLE}",
  data_prep_cluster_id=data_prep_cluster_id,
  register_to=register_to,
  training_duration=training_duration,
  learning_rate=learning_rate,
)
run

Screenshot 2024-10-10 at 17.34.41.png

ファインチューニングの処理は非同期で行われます。以下のコマンドを実行して進捗を確認します。

# 進捗の確認
run.get_events()

Screenshot 2024-10-10 at 17.35.26.png

ファインチューニングされたモデルは、以下を実行して表示されるMLflowエクスペリメントで管理されます。エクスペリメントでトレーニングの進捗を確認することもできます。

displayHTML(f"<a href='/ml/experiments/{run.experiment_id}'>エクスペリメント</a>")

StatusCompletedになったことを確認して次のノートブックに進んでください。

3: プロビジョニングされたスループット サービング エンドポイントの作成

このノートブックでは、プロビジョニングされたスループット ファウンデーション モデル API を作成します。プロビジョニングされたスループットは、本番ワークロードのパフォーマンス保証を備えたファウンデーション モデルの最適化された推論を提供します。

サポートされているモデルアーキテクチャのリストについては、プロビジョニングされたスループット ファウンデーション モデル API を参照してください。

このノートブックでは以下を行います:

  1. デプロイするモデルを定義します。これは、Unity Catalogに登録されたファインチューニングされたモデルになります
  2. 登録されたモデルの最適化情報を取得します
  3. エンドポイントを設定して作成します
  4. エンドポイントをクエリします
%pip install databricks-sdk==0.31.1
dbutils.library.restartPython()

環境を設定して、必要な変数とデータセットを読み込みます。

%run ../Includes/Config
import json
import requests

import mlflow
import mlflow.deployments
from mlflow.deployments import get_deploy_client
from mlflow.exceptions import MlflowException
from databricks.sdk import WorkspaceClient
from databricks.sdk.service.serving import (
    ServedEntityInput,
    EndpointCoreConfigInput,
    AutoCaptureConfigInput,
)

設定

Unity Catalogの設定変数をセットアップします。登録されたモデル名とデプロイするモデルバージョンを定義します。また、エンドポイントの名前と、推論テーブルの名前も定義します。この推論テーブルは、エンドポイントのリクエストとレスポンスのペイロードを記録するために作成されます。

UCモデルのバージョン

# 前のノートブックでUCに登録したモデルのバージョンを指定します
MODEL_VERSION = 1

モデルの最適化情報を取得する

モデル名とモデルバージョンを指定することで、モデルの最適化情報を取得することができます。これは、特定のモデルに対して1つのスループットユニットに対応するトークン/秒の数です。

API_ROOT = mlflow.utils.databricks_utils.get_databricks_host_creds().host
API_TOKEN = mlflow.utils.databricks_utils.get_databricks_host_creds().token

def get_model_optimization_info(full_model_name: str, model_version: int):
    """指定された登録済みモデルとバージョンのモデル最適化情報を取得します。"""
    headers = {"Context-Type": "text/json", "Authorization": f"Bearer {API_TOKEN}"}
    url = f"{API_ROOT}/api/2.0/serving-endpoints/get-model-optimization-info/{full_model_name}/{model_version}"
    response = requests.get(url=url, headers=headers)
    return response.json()


# 登録済みモデルを指定して最適化情報を取得する
model_optimization_info = get_model_optimization_info(
    full_model_name=f"{CATALOG}.{USER_SCHEMA}.{UC_MODEL_NAME}", model_version=MODEL_VERSION
)
print("model_optimization_info: ", model_optimization_info)
min_provisioned_throughput = model_optimization_info["throughput_chunk_size"]
# コスト削減のために最小値と同じ値を最大値に設定します。予想されるリクエストの負荷に基づいてより高い数値を選択することもできます。
max_provisioned_throughput = model_optimization_info["throughput_chunk_size"]
model_optimization_info:  {'optimizable': True, 'model_type': 'llama', 'throughput_chunk_size': 3600, 'dbus': 106}

GPUモデルサービングエンドポイントの設定と作成

エンドポイントの作成APIを呼び出した後、ログされたモデルは自動的に最適化されたLLMサービングで展開されます。

w = WorkspaceClient()

_ = spark.sql(f"DROP TABLE IF EXISTS {CATALOG}.{USER_SCHEMA}.{INFERENCE_TABLE_PREFIX}_payload")

print("エンドポイントの作成中..")
w.serving_endpoints.create_and_wait(
    name=MODEL_ENDPOINT_NAME,
    config=EndpointCoreConfigInput(
        name=MODEL_ENDPOINT_NAME,
        served_entities=[
            ServedEntityInput(
                entity_name=f"{CATALOG}.{USER_SCHEMA}.{UC_MODEL_NAME}", # サービングエンドポイントにデプロイされるUC登録モデル
                entity_version=MODEL_VERSION, # モデルのバージョン
                max_provisioned_throughput=max_provisioned_throughput, # スループットの最大値
                min_provisioned_throughput=0, # スループットの最小値
                scale_to_zero_enabled=True # ゼロノードへのスケーリングを有効にする
            )
        ],
        # エンドポイントの入出力をキャプチャする推論テーブル
        auto_capture_config=AutoCaptureConfigInput(
            catalog_name=CATALOG,
            schema_name=USER_SCHEMA,
            enabled=True,
            table_name_prefix=INFERENCE_TABLE_PREFIX,
        ),
    ),
)

上を実行した際、TimeoutError: timed out after 0:20:00: current status: EndpointStateConfigUpdate.IN_PROGRESSというエラーが出たとしても、以下のセルを実行して表示されるリンク先で処理が進んでいれば問題ありません。

displayHTML(f"<a href='/ml/endpoints/{MODEL_ENDPOINT_NAME}' target='_blank'>エンドポイント{MODEL_ENDPOINT_NAME}の詳細</a>")

エンドポイントを表示

エンドポイントに関する詳細情報を表示するには、左側のナビゲーションバーでServingを選択し、エンドポイント名を検索してください。

モデルのサイズや複雑さによっては、エンドポイントが準備完了するまでに30分以上かかることがあります。

エンドポイントが稼働していることを確認して、次のノートブックに進みましょう。

4: エンドポイントのクエリーとバッチ推論

Notebooks/04_Query_Endpoint_and_Batch_Inferenceを実行します。

このノートブックでは、以下の操作を行います:

  1. 単一のリクエストでモデルサービングエンドポイントをクエリします

環境をセットアップして、必要な変数とデータセットをロードします。

%run ../Includes/Config
from bs4 import BeautifulSoup
import json
import requests

import mlflow.deployments
from mlflow.deployments import get_deploy_client
from mlflow.exceptions import MlflowException

ブログ記事のURLを提供してください

まず、ブログのURLを指定して、単一のブログ記事のテキストを取得しましょう。https://qiita.com/taka_yayoi にアクセスし、いずれかのブログ投稿のURLを選択してください。

以下の例では、こちらの記事を指定しています。

def get_single_blog_post(url: str) -> str:
    """
    ブログのURLを指定して、単一のブログ記事のテキストを取得します。

    Args:
        url (str): ブログ記事のURL。

    Returns:
        str: クリーンされたブログ記事のテキスト。
    """
    response = requests.get(url)
    soup = BeautifulSoup(response.text, "html.parser")
    
    # ブログ記事のテキストコンテナを見つける(サイトによって構成が異なります)
    blog_text_container = soup.find("div", class_="p-items_main")
    
    if blog_text_container:
        # HTMLタグを削除してテキストを抽出する
        blog_text = " ".join(blog_text_container.stripped_strings)
        
        # テキストをクリーンアップする
        blog_text = blog_text.replace("\\'", "'")
        blog_text = blog_text.replace(" ,", ",")
        blog_text = blog_text.replace(" .", ".")
        
        return blog_text
    else:
        print(f"URL {url} のブログ記事が見つかりませんでした。")
        return ""

url = "https://qiita.com/taka_yayoi/items/9ee9ba97ebdb82704244"
blog_post_text = get_single_blog_post(url)

blog_post_text
'@ taka_yayoi ( Takaaki Yayoi ) in データブリックス・ジャパン株式会社 Databricks Apps(アプリ)がやってきました! Flask Dash Databricks Streamlit gradio Last updated at 2024-10-10 Posted at 2024-10-08 とうとう、DatabricksでWebアプリが組めるように…。 Databricks Apps を使用すると、開発者は Databricks プラットフォーム上で安全なデータアプリケーションと AI アプリケーションを作成し、それらのアプリをユーザーと共有することができます。 アプリでは、ガバナンスのための Unity Catalog、データのクエリを実行するための Databricks SQL、モデルサービングなどの AI 機能、ETL 用の Databricks ジョブ、ワークスペース内の既に設定されているセキュリティルール (アプリで使用されるデータへのアクセスを制御するルールなど) など、Databricks プラットフォームのリソースと機能を使用できます。 認証と承認では、Databricks OAuthやサービス プリンシパルなどの既存の 機能を使用します。 プレビュー 本機能は パブリックプレビュー です。利用できるリージョンに関しては こちら をご確認ください。本記事の執筆時点では日本リージョンでは利用できません。 ということで、利用できるリージョンで試してみます。 Hello Worldアプリ サイドメニューの クラスター にアクセスすると、 アプリ タブが追加されています! Create app をクリックします。 テンプレートかカスタムか、使用するフレームワークを選択します。テンプレートでは、チャットbot、データアプリなどを選択できます。フレームワークには Streamlit を使用し、一番下の Hellow world を選択して、 Next をクリックします。 アプリ名、説明文を入力して、 Create and deploy app をクリックします。 デプロイメントがスタートします。 数分後にデプロイが完了しました。 実行中 に表示されているURLをクリックします。 アプリにアクセスできました! ソースコードの変更 上の デプロイメント に表示されているのがアプリのソースコードです。こちらをクリックして app.py を開きます。 import streamlit as st import pandas as pd st. set_page_config ( layout = " wide " ) st. header ( " Hello world!!! " ) apps = st. slider ( " Number of apps ", max_value = 60, value = 10 ) chart_data = pd. DataFrame ({ \' y \' :[ 2 ** x for x in range ( apps )]}) st. bar_chart ( chart_data, height = 500, width = min ( 100 + 50 * apps, 1000 ), use_container_width = False, x_label = " Apps ", y_label = " Fun with data " ) 以下のように修正します。 import streamlit as st import pandas as pd st. set_page_config ( layout = " wide " ) st. header ( " Hello world!!!はじめてのDatabricks App!!! " ) apps = st. slider ( " Number of apps ", max_value = 60, value = 10 ) chart_data = pd. DataFrame ({ \' y \' :[ 2 ** x for x in range ( apps )]}) st. bar_chart ( chart_data, height = 500, width = min ( 100 + 50 * apps, 1000 ), use_container_width = False, x_label = " Apps ", y_label = " Fun with data " ) 右上の デプロイ をクリックすると、修正されたコードを用いて再度デプロイが実行されます。 変更が反映されています。 データアプリ 今度はテンプレートで、 Data app を選択します。 データアプリではSQLウェアハウスが必要になるので、 アプリのリソース でSQLウェアハウスを選択します。 こちらは、Unity Catalogで管理されているサンプルのタクシー乗降記録データにアクセスするアプリです。 右上のボックスにZipコードを入力すると、移動料金を予測することもできます。 早く日本に来て欲しいです。 はじめてのDatabricks はじめてのDatabricks Databricks無料トライアル Databricks無料トライアル 8 Go to list of users who liked 5 comment 2 Go to list of comments Register as a new user and use Qiita more conveniently You get articles that match your needs You can efficiently read back useful information You can use dark theme What you can do with signing up Sign up Login'

1つ目のノートブックで取り組んだテンプレートを用いてプロンプトを構成

class PromptTemplate:
    """Promptテンプレートを表すクラス。データセット生成用。"""

    def __init__(self, instruction: str, blog_key: str, response_key: str) -> None:
        self.instruction = instruction
        self.blog_key = blog_key
        self.response_key = response_key

    def generate_prompt(self, blog_text: str) -> str:
        """
        テンプレートと与えられたブログテキストを使用してプロンプトを生成します。

        Args:
            blog_text: ブログのテキスト。

        Returns:
            プロンプトテンプレート。
        """
        return f"""{self.instruction}
{self.blog_key}
{blog_text}
{self.response_key}
"""

blog_title_generation_template = PromptTemplate(
    instruction="以下はDatabricksのブログ記事のテキストです。提供されたブログ記事にタイトルを作成してください。",
    blog_key="### ブログ:",
    response_key="### タイトル:"
)

prompt = blog_title_generation_template.generate_prompt(blog_post_text)
print(prompt)
以下はDatabricksのブログ記事のテキストです。提供されたブログ記事にタイトルを作成してください。
### ブログ:
@ taka_yayoi ( Takaaki Yayoi ) in データブリックス・ジャパン株式会社 Databricks Apps(アプリ)がやってきました! Flask Dash Databricks Streamlit gradio Last updated at 2024-10-10 Posted at 2024-10-08 とうとう、DatabricksでWebアプリが組めるように…。 Databricks Apps を使用すると、開発者は Databricks プラットフォーム上で安全なデータアプリケーションと AI アプリケーションを作成し、それらのアプリをユーザーと共有することができます。 アプリでは、ガバナンスのための Unity Catalog、データのクエリを実行するための Databricks SQL、モデルサービングなどの AI 機能、ETL 用の Databricks ジョブ、ワークスペース内の既に設定されているセキュリティルール (アプリで使用されるデータへのアクセスを制御するルールなど) など、Databricks プラットフォームのリソースと機能を使用できます。 認証と承認では、Databricks OAuthやサービス プリンシパルなどの既存の 機能を使用します。 プレビュー 本機能は パブリックプレビュー です。利用できるリージョンに関しては こちら をご確認ください。本記事の執筆時点では日本リージョンでは利用できません。 ということで、利用できるリージョンで試してみます。 Hello Worldアプリ サイドメニューの クラスター にアクセスすると、 アプリ タブが追加されています! Create app をクリックします。 テンプレートかカスタムか、使用するフレームワークを選択します。テンプレートでは、チャットbot、データアプリなどを選択できます。フレームワークには Streamlit を使用し、一番下の Hellow world を選択して、 Next をクリックします。 アプリ名、説明文を入力して、 Create and deploy app をクリックします。 デプロイメントがスタートします。 数分後にデプロイが完了しました。 実行中 に表示されているURLをクリックします。 アプリにアクセスできました! ソースコードの変更 上の デプロイメント に表示されているのがアプリのソースコードです。こちらをクリックして app.py を開きます。 import streamlit as st import pandas as pd st. set_page_config ( layout = " wide " ) st. header ( " Hello world!!! " ) apps = st. slider ( " Number of apps ", max_value = 60, value = 10 ) chart_data = pd. DataFrame ({ ' y ' :[ 2 ** x for x in range ( apps )]}) st. bar_chart ( chart_data, height = 500, width = min ( 100 + 50 * apps, 1000 ), use_container_width = False, x_label = " Apps ", y_label = " Fun with data " ) 以下のように修正します。 import streamlit as st import pandas as pd st. set_page_config ( layout = " wide " ) st. header ( " Hello world!!!はじめてのDatabricks App!!! " ) apps = st. slider ( " Number of apps ", max_value = 60, value = 10 ) chart_data = pd. DataFrame ({ ' y ' :[ 2 ** x for x in range ( apps )]}) st. bar_chart ( chart_data, height = 500, width = min ( 100 + 50 * apps, 1000 ), use_container_width = False, x_label = " Apps ", y_label = " Fun with data " ) 右上の デプロイ をクリックすると、修正されたコードを用いて再度デプロイが実行されます。 変更が反映されています。 データアプリ 今度はテンプレートで、 Data app を選択します。 データアプリではSQLウェアハウスが必要になるので、 アプリのリソース でSQLウェアハウスを選択します。 こちらは、Unity Catalogで管理されているサンプルのタクシー乗降記録データにアクセスするアプリです。 右上のボックスにZipコードを入力すると、移動料金を予測することもできます。 早く日本に来て欲しいです。 はじめてのDatabricks はじめてのDatabricks Databricks無料トライアル Databricks無料トライアル 8 Go to list of users who liked 5 comment 2 Go to list of comments Register as a new user and use Qiita more conveniently You get articles that match your needs You can efficiently read back useful information You can use dark theme What you can do with signing up Sign up Login
### タイトル:

エンドポイントへのクエリー

from mlflow.utils.databricks_utils import get_databricks_env_vars

mlflow_db_creds = get_databricks_env_vars("databricks")
API_TOKEN = mlflow_db_creds["DATABRICKS_TOKEN"]
WORKSPACE_URL = mlflow_db_creds["_DATABRICKS_WORKSPACE_HOST"]

# モデルのパラメーター
max_tokens = 128
temperature = 0.9

payload = {
    "prompt": [prompt], "max_tokens": max_tokens, "temperature": temperature
}

headers = {
    "Content-Type": "application/json",
    "Authorization": f"Bearer {API_TOKEN}"
}

response = requests.post(
    url=f"{WORKSPACE_URL}/serving-endpoints/{MODEL_ENDPOINT_NAME}/invocations",
    json=payload,
    headers=headers
)

predictions = response.json().get("choices")
print(predictions[0]["text"])

きちんとした結果が返ってきています。

Databricks AppsによるWebアプリケーションの作成

モデルからレスポンスが得られたことを確認して、最後のノートブックに進みましょう。

5: オフライン評価

Notebooks/05_Offline_Evaluationを実行します。

このノートブックでは:

  1. 単一のリクエストでモデル提供エンドポイントをクエリします
  2. 一連のリクエストでエンドポイントをクエリします
    • デモでは、pandas UDFを使用して一連のリクエストを送信しました。このラボでは、リクエスト構造を単純に変更して、モデル提供エンドポイントに対して送信できるプロンプトのリストを受け入れるようにします
%pip install databricks-sdk==0.31.1 textstat==0.7.3
dbutils.library.restartPython()

環境をセットアップして、必要な変数とデータセットをロードします。

%run ../Includes/Config
from databricks.sdk import WorkspaceClient
import json
import mlflow
from mlflow import MlflowClient
from mlflow.metrics.genai.metric_definitions import answer_similarity
import pandas as pd
from typing import Iterator
import pyspark.sql.functions as F
import requests
# 評価データ
table_name = f"{CATALOG}.{USER_SCHEMA}.{OUTPUT_EVAL_TABLE}"

API_TOKEN = mlflow.utils.databricks_utils.get_databricks_host_creds().host
WORKSPACE_URL = mlflow.utils.databricks_utils.get_databricks_host_creds().token

# モデルのパラメーター
max_tokens = 128
temperature = 0.9

データフレームを適切な形式に整形する

mlflow.evaluate では、評価用のデータフレームを "inputs" と "ground_truth" のカラムを持つ pandas のデータフレームとして整形する必要があります。

eval_df = (spark.table(table_name)
           .select("prompt", "response")
           .withColumnRenamed("prompt", "inputs")
           .withColumnRenamed("response", "ground_truth") 
           )
eval_pdf = eval_df.toPandas()
    
print(eval_pdf.count())
display(eval_pdf)
inputs          5
ground_truth    5
dtype: int64

Screenshot 2024-10-10 at 17.48.30.png

エンドポイントに対してプロンプトのバッチを送信する

ここでは、Databricks Python SDKやpandas UDFの代わりに、requestsライブラリを使用してエンドポイントに対してプロンプトのバッチを送信します。最初に、単一のプロンプトでテストを行います。

def get_predictions(prompts, max_tokens, temperature, model_serving_endpoint):
    from mlflow.utils.databricks_utils import get_databricks_env_vars
    import requests

    # Databricksの環境変数から認証情報を取得
    mlflow_db_creds = get_databricks_env_vars("databricks")
    API_TOKEN = mlflow_db_creds["DATABRICKS_TOKEN"]
    WORKSPACE_URL = mlflow_db_creds["_DATABRICKS_WORKSPACE_HOST"]

    # モデルに送るデータの準備
    payload = {"prompt": prompts, "max_tokens": max_tokens, "temperature": temperature}
    
    # 認証情報をヘッダーに設定
    headers = {"Content-Type": "application/json",
               "Authorization": f"Bearer {API_TOKEN}"
               }
    
    # モデルサービングエンドポイントにリクエストを送信
    response = requests.post(url=f"{WORKSPACE_URL}/serving-endpoints/{model_serving_endpoint}/invocations",
                             json=payload,
                             headers=headers
                             )
    # レスポンスから予測結果を取得
    predictions = response.json().get("choices")
    return predictions
def make_prediction_udf(model_serving_endpoint):
    @F.pandas_udf("string")
    def get_prediction_udf(batch_prompt: Iterator[pd.Series]) -> Iterator[pd.Series]:

        import mlflow

        max_tokens = 100  # 最大トークン数を設定
        temperature = 1.0  # 温度パラメータを設定
        api_root = mlflow.utils.databricks_utils.get_databricks_host_creds().host  # APIのルートURLを取得
        api_token = mlflow.utils.databricks_utils.get_databricks_host_creds().token  # APIトークンを取得

        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {api_token}"  # 認証ヘッダーを設定
        }
        
        for batch in batch_prompt:
            
            result = []  # 結果を格納するリスト
            for prompt, max_tokens, temperature in batch[["prompt", "max_tokens", "temperature"]].itertuples(index=False):  
                data = {"prompt": prompt, "max_tokens": max_tokens, "temperature": temperature}  # リクエストデータを準備
                response = requests.post(
                    url=f"{api_root}/serving-endpoints/{model_serving_endpoint}/invocations",
                    json=data,
                    headers=headers  # リクエストを送信
                )
                if response.status_code == 200:  # レスポンスが成功した場合
                    endpoint_output = json.dumps(response.json())  # レスポンスデータをJSON文字列に変換
                    data = json.loads(endpoint_output)  # JSON文字列を辞書に変換
                    prediction = data.get("choices")  # 予測結果を取得
                    try:
                        predicted_docs = prediction[0]["text"]  # 予測テキストを取得
                        result.append(predicted_docs)  # 結果リストに追加
                    except IndexError as e:  # 予測テキストが存在しない場合
                        result.append("null")  # nullを追加
                else:  # レスポンスが失敗した場合
                    result.append(str(response.raise_for_status()))  # エラーメッセージを追加

        yield pd.Series(result)  # 結果のシリーズを生成
    return get_prediction_udf  # UDFを返す

get_prediction_udf = make_prediction_udf(MODEL_ENDPOINT_NAME)  # UDFを作成
predictions = get_predictions(prompts=eval_pdf["inputs"][0], 
                max_tokens=max_tokens, 
                temperature=temperature,
                model_serving_endpoint=MODEL_ENDPOINT_NAME)

print(predictions[0]["text"])
DatabricksでconnpassのJEDAIユーザー会のメンバー数を取得してみる

上記のコードでは、単一のプロンプトでエンドポイントをクエリできることがわかります。今度は、eval_dfのバッチのプロンプトを受け入れるために、プロンプトの構造を変更する番です。eval_dfのレコード数が多くなるほど処理に時間を要することになります。

predictions_df = eval_df.withColumn(
    "generated_title",
    get_prediction_udf(
        F.struct(
            F.col("inputs").alias("prompt"),  # 入力プロンプト
            F.lit(max_tokens).alias("max_tokens"),  # 最大トークン数
            F.lit(temperature).alias("temperature"),  # 温度パラメータ
        )
    ),
)
# 評価データフレームを保存する
predictions_df.write.mode("overwrite").saveAsTable(f"{CATALOG}.{USER_SCHEMA}.eval_blog_df")
display(spark.table(f"{CATALOG}.{USER_SCHEMA}.eval_blog_df"))

入力データと正解データ、そして、生成された結果を持つデータが得られました。

Screenshot 2024-10-10 at 17.50.15.png

LLMで生成されたタイトルを評価する

LLMで生成されたタイトルの品質を評価しましょう!

eval_pdf = spark.table(f"{CATALOG}.{USER_SCHEMA}.eval_blog_df").toPandas()
w = WorkspaceClient()
model_name = w.serving_endpoints.get(name=MODEL_ENDPOINT_NAME).config.served_entities[0].entity_name
model_version = 1
mlflow_client = MlflowClient(registry_uri="databricks-uc")

# 登録済みモデルのモデルバージョンオブジェクトを取得する
# 注意:UCの適切な権限がない場合は失敗します
mv = mlflow_client.get_model_version(name=model_name, version=model_version)
training_run_id = mv.run_id
with mlflow.start_run(run_id=training_run_id) as run: 
    # MLflowを使用してモデルの評価を実行
    results = mlflow.evaluate(data=eval_pdf, 
                              targets="ground_truth",
                              predictions="generated_title",
                              model_type="text",
                             )
    
    # 評価結果のメトリックをJSON形式で出力
    print(json.dumps(results.metrics, indent=2))
{
  "toxicity/v1/mean": 0.003505172091536224,
  "toxicity/v1/variance": 1.009375156297804e-05,
  "toxicity/v1/p90": 0.007169587071985006,
  "toxicity/v1/ratio": 0.0,
  "flesch_kincaid_grade_level/v1/mean": 7.040000000000001,
  "flesch_kincaid_grade_level/v1/variance": 64.43440000000001,
  "flesch_kincaid_grade_level/v1/p90": 15.64,
  "ari_grade_level/v1/mean": 98.64,
  "ari_grade_level/v1/variance": 2476.9864,
  "ari_grade_level/v1/p90": 144.84
}

LLMをジャッジとして使用する

上記のメトリクスに加えて、LLMをジャッジとして使用してさらにメトリクスを生成しましょう。既にデフォルトのメトリクスを生成しているため、model_type引数を削除して、LLM判定のメトリクスのみを生成します。

llm_judge = "endpoints:/databricks-dbrx-instruct"
# モデルの回答類似度メトリックを定義
answer_similarity_metric = answer_similarity(model=llm_judge)

with mlflow.start_run(run_id=training_run_id) as run: 
    # MLflowを使用してモデルの評価を実行し、追加メトリックを含める
    results = mlflow.evaluate(data=eval_pdf, 
                              targets="ground_truth",
                              predictions="generated_title",
                              extra_metrics=[answer_similarity_metric]
                             )
    # 評価結果のメトリックをJSON形式で出力
    print(json.dumps(results.metrics, indent=2))
{
  "answer_similarity/v1/mean": 3.0,
  "answer_similarity/v1/variance": 0.0,
  "answer_similarity/v1/p90": 3.0
}

エクスペリメント画面のEvaluationタブで正解データと生成された結果や評価指標を比較することができます。

お疲れ様でした!色々なデータやプロンプトでトライしてみてください!

モデルサービングエンドポイントを使わない場合には停止しておきましょう。

最終的には以下のアセットを作成しました。

Screenshot 2024-10-10 at 17.53.51.png

関連資料

はじめてのDatabricks

はじめてのDatabricks

Databricks無料トライアル

Databricks無料トライアル

3
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
3
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?