1
1

DatabricksマーケットプレイスからWhisper V3 Modelを試してみる

Last updated at Posted at 2023-12-12

久々にマーケットプレイスを覗いたらモデルが増えてました。Whisper V3 Modelとな。

音声テキスト変換のモデルなんですね。

モデルとノートブックの取得

Screenshot 2023-12-12 at 11.41.16.png

即時アクセス権を取得をクリックして、その他のオプションを展開します。モデルは新規の共有カタログ配下に格納されるので、カタログ名を変更します。
Screenshot 2023-12-12 at 11.41.37.png
Screenshot 2023-12-12 at 11.45.20.png

カタログエクスプローラでモデルにアクセスできます。
Screenshot 2023-12-12 at 11.46.15.png

サンプルノートブックもインポートします。
Screenshot 2023-12-12 at 11.43.31.png

ノートブックのウォークスルー

ノートブックの実行自体はCPUクラスターで問題ありません。Databricks SDKをインストールします。

# Upgrade to use the newest Databricks SDK
%pip install --upgrade databricks-sdk
dbutils.library.restartPython()
# Select the model from the dropdown list
model_names = ['whisper_large_v3']
dbutils.widgets.dropdown("model_name", model_names[0], model_names)

カタログ名は上で設定したものに変更します。必要に応じてモデルサービングエンドポイント名を変更します。

# Default catalog name when installing the model from Databricks Marketplace.
# Replace with the name of the catalog containing this model
# You can also specify a different model version to load for inference
catalog_name = "databricks_whisper_v3_model_taka"
version = "1"
model_name = dbutils.widgets.get("model_name")
model_uc_path = f"{catalog_name}.models.{model_name}"
endpoint_name = f'{model_name}_marketplace_taka'

モデルのサービングにはGPUが必要となります。

# Choose the right workload types based on the model size
workload_type = "GPU_MEDIUM"

SDKでモデルサービングエンドポイントをデプロイします。30分くらいかかります。

import datetime

from databricks.sdk import WorkspaceClient
from databricks.sdk.service.serving import EndpointCoreConfigInput
w = WorkspaceClient()

config = EndpointCoreConfigInput.from_dict({
    "served_models": [
        {
            "name": endpoint_name,
            "model_name": model_uc_path,
            "model_version": version,
            "workload_type": workload_type,
            "workload_size": "Small",
            "scale_to_zero_enabled": "False",
        }
    ]
})
model_details = w.serving_endpoints.create(name=endpoint_name, config=config)
model_details.result(timeout=datetime.timedelta(minutes=30))

Screenshot 2023-12-12 at 13.27.27.png

音声ファイル自体にはこちらからアクセスできます。

2つの音声ファイルがあります。聞いてみるとバイデン首相に関するニュース記事のようです。

データセットとしてロードしてモデルサービングエンドポイントにリクエストします。

from datasets import load_dataset
import pandas as pd
import base64
import json

from databricks.sdk import WorkspaceClient

dataset = load_dataset("Nexdata/accented_english", split="train")
sample_path = dataset[0]["audio"]["path"]

# Change it to your own input file name
with open(sample_path, 'rb') as audio_file:
    audio_bytes = audio_file.read()
    audio_b64 = base64.b64encode(audio_bytes).decode('ascii')

dataframe_records = [audio_b64]

w = WorkspaceClient()
response = w.serving_endpoints.query(
    name=endpoint_name,
    dataframe_records=dataframe_records,
)
print(response.predictions)
["The news forced the state to move immediately from planning the site's operational processes to building software, Baden said."]

(私の英語力では)合っています。

SparkのUDF(ユーザー定義関数)でモデルをラッピングすることで、バッチ推論で活用する小tゴアできます。

import mlflow
mlflow.set_registry_uri("databricks-uc")

catalog_name = "databricks_whisper_v3_models_taka"
transcribe = mlflow.pyfunc.spark_udf(spark, f"models:/{model_uc_path}/{version}", "string")
import pandas as pd
from datasets import load_dataset
import base64
import json

dataset = load_dataset("Nexdata/accented_english", split="train")
sample_path = dataset[0]["audio"]["path"]

with open(sample_path, 'rb') as audio_file:
    audio_bytes = audio_file.read()
    dataset = pd.DataFrame(pd.Series([audio_bytes]))

df = spark.createDataFrame(dataset)

# You can use the UDF directly on a text column
transcribed_df = df.select(transcribe(df["0"]).alias('transcription'))

display(transcribed_df)

データフレームに格納されている音声データからテキストに起こしてくれます。
Screenshot 2023-12-12 at 13.24.39.png

複数行であっても同じように処理してくれます。

dataset = load_dataset("Nexdata/accented_english", split="train")

sample_path_1 = dataset[0]["audio"]["path"]
sample_path_2 = dataset[1]["audio"]["path"]

with open(sample_path_1, 'rb') as audio_file_1:
    audio_bytes_1 = audio_file_1.read()

with open(sample_path_2, 'rb') as audio_file_2:
    audio_bytes_2 = audio_file_2.read()

dataset = pd.DataFrame(pd.Series([audio_bytes_1, audio_bytes_2]))

df = spark.createDataFrame(dataset)

# You can use the UDF directly on a text column
transcribed_df = df.select(transcribe(df["0"]).alias('transcription'))

display(transcribed_df)

これはこれで夢が広がりますね。
Screenshot 2023-12-12 at 13.25.44.png

Databricksクイックスタートガイド

Databricksクイックスタートガイド

Databricks無料トライアル

Databricks無料トライアル

1
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
1