こちらのサンプルノートブックの一つ目をウォークスルーします。
オリジナルのノートブックはこちらです。
これは、Databricksでchronosモデルの使い方を説明するサンプルノートブックです。このノートブックでは、モデルをロードし、推論を分散処理し、モデルを登録し、モデルをデプロイしてオンライン予測を行います。
クラスターのセットアップ
Databricks Runtime 14.3 LTS for ML以降のクラスターを使うことをお勧めします。このクラスターは、それぞれのワーカーごとに1つ以上のGPUを持つシングルノードあるいはマルチノードで構いません。インスタンスタイプはAWSならg5.12xlarge [A10G]、AzureならStandard_NV72ads_A10_v5などになります。このノートブックでは、推論タスクを分散し、利用できる全てのリソースを活用するようにPandas UDFを活用します。
パッケージのインストール
%pip install git+https://github.com/amazon-science/chronos-forecasting.git --quiet
dbutils.library.restartPython()
データの準備
M4データをダウンロードするために datasetsforecast
パッケージを活用します。M4データセットには、テストのために使用する時系列データセットが含まれています。M4時系列を期待するフォーマットに変換するために記述したいくつかのカスタム関数については、 data_preparation
ノートブックをご覧ください。
カタログとスキーマがすでに作成済みであることを確認してください。
catalog = "users" # アセットを管理するために使用するカタログの名前
db = "takaaki_yayoi" # アセットを管理するために使用するスキーマの名前 (例: datasets)
n = 100 # サンプリングする時系列の数
# このセルではノートブック ../data_preparation を実行し、M4データを格納する以下のテーブルを作成します:
# 1. {catalog}.{db}.m4_daily_train
# 2. {catalog}.{db}.m4_monthly_train
dbutils.notebook.run("./99_data_preparation", timeout_seconds=0, arguments={"catalog": catalog, "db": db, "n": n})
from pyspark.sql.functions import collect_list
# データが存在することを確認
df = spark.table(f'{catalog}.{db}.m4_daily_train')
df = df.groupBy('unique_id').agg(collect_list('ds').alias('ds'), collect_list('y').alias('y'))
display(df)
推論の分散処理
推論を分散するためにPandas UDFを活用します。
import pandas as pd
import numpy as np
import torch
from typing import Iterator
from pyspark.sql.functions import pandas_udf
# ホライゾンのタイムスタンプを生成するためのPandas UDFを作成する関数
def create_get_horizon_timestamps(freq, prediction_length):
"""
指定された頻度と予測期間に基づいて未来のタイムスタンプを生成するPandas UDFを作成します。
パラメーター:
freq (str): タイムスタンプの頻度 ('M' は月末、それ以外は日時)
prediction_length (int): 生成する未来のタイムスタンプの数
戻り値:
関数: 入力されたそれぞれの時系列に対して未来のタイムスタンプの配列を生成するPandas UDF
"""
@pandas_udf('array<timestamp>')
def get_horizon_timestamps(batch_iterator: Iterator[pd.Series]) -> Iterator[pd.Series]:
# 頻度に基づいて次のタイムスタンプのオフセットを特定
one_ts_offset = pd.offsets.MonthEnd(1) if freq == "M" else pd.DateOffset(days=1)
barch_horizon_timestamps = [] # 未来のタイムスタンプの配列を保持するリスト
# 入力時系列のバッチに対するイテレーション
for batch in batch_iterator:
for series in batch:
timestamp = last = series.max() # シリーズの最後のタイムスタンプを取得
horizon_timestamps = [] # 現在のシリーズの未来のタイムスタンプを保持するリスト
# 未来のタイムスタンプを生成
for i in range(prediction_length):
timestamp = timestamp + one_ts_offset
horizon_timestamps.append(timestamp.to_numpy())
barch_horizon_timestamps.append(np.array(horizon_timestamps))
yield pd.Series(barch_horizon_timestamps) # Pandasのシリーズとして結果を生成
return get_horizon_timestamps
# 予測を生成するPandas UDFを作成する関数
def create_forecast_udf(repository, prediction_length, num_samples, batch_size):
"""
指定されたリポジトリの事前学習済みモデルを用いて予測を生成するPandas UDFを作成します。
パラメーター:
repository (str): モデルリポジトリのパスあるいは識別子
prediction_length (int): 予測する未来の値の数
num_samples (int): それぞれの予測結果に対して生成するサンプルの数
batch_size (int): それぞれのバッチで処理する時系列の数
戻り値:
関数: それぞれの入力時系列の予測値の配列を生成するPandas UDF
"""
@pandas_udf('array<double>')
def forecast_udf(bulk_iterator: Iterator[pd.Series]) -> Iterator[pd.Series]:
# 初期化ステップ
import numpy as np
import pandas as pd
import torch
from chronos import ChronosPipeline
# リポジトリから事前学習済みモデルのロード
pipeline = ChronosPipeline.from_pretrained(repository, device_map="auto", torch_dtype=torch.bfloat16)
# 推論ステップ
for bulk in bulk_iterator:
median = [] # それぞれのシリーズの予測値の中央値を保持するリスト
# バッチにある時系列を処理
for i in range(0, len(bulk), batch_size):
batch = bulk[i:i+batch_size]
contexts = [torch.tensor(list(series)) for series in batch] # series を tensors に変換
# 事前学習モデルを用いた予測の生成
forecasts = pipeline.predict(context=contexts, prediction_length=prediction_length, num_samples=num_samples)
# それぞれのシリーズの予測の中央値を計算
median.extend([np.median(forecast, axis=0) for forecast in forecasts])
yield pd.Series(median) # Pandas Seriesとして結果を生成
return forecast_udf
予測のパラメーターを指定します。
chronos_model = "chronos-t5-tiny" # 他の選択肢: chronos-t5-mini, chronos-t5-small, chronos-t5-base, chronos-t5-large
prediction_length = 10 # 予測の時間ホライゾン
num_samples = 10 # 生成する予測の数。最終的な予測値として中央値を取ります。
batch_size = 4 # 同時に処理する時系列の数
freq = "D" # 時系列の頻度
device_count = torch.cuda.device_count() # 利用できるGPUの数
それでは予測結果を生成しましょう。
# 指定された頻度と予測長を用いて未来のタイムスタンプを生成するPandas UDFの作成
get_horizon_timestamps = create_get_horizon_timestamps(freq=freq, prediction_length=prediction_length)
# 指定したリポジトリの学習済みモデルを用いて予測結果を生成するPandas UDFの作成
forecast_udf = create_forecast_udf(
repository=f"amazon/{chronos_model}", # モデルリポジトリのあるいは識別子
prediction_length=prediction_length, # 予測する未来の値の数
num_samples=num_samples, # それぞれの予測結果に対して生成するサンプルの数
batch_size=batch_size, # それぞれのバッチで処理する時系列の数
)
# データフレームにUDFを適用し、適切な列を選択
forecasts = df.repartition(device_count).select(
df.unique_id, # それぞれの時系列の固有のIDを選択
get_horizon_timestamps(df.ds).alias("ds"), # それぞれの時系列の未来のタイムスタンプを生成し、別名を付与
forecast_udf(df.y).alias("forecast") # それぞれの時系列の予測値を生成し、別名を付与
)
# 予測値を含む結果のデータフレームを表示
display(forecasts)
モデルの登録
mlflow.pyfunc.PythonModel
を用いてモデルをパッケージングし、Unity Catalogに登録します。
import mlflow
import torch
import numpy as np
from mlflow.models.signature import ModelSignature
from mlflow.types import DataType, Schema, TensorSpec
# Databricks Unity Catalogを使うようにMLflowレジストリURIを設定
mlflow.set_registry_uri("databricks-uc")
experiment_name = "/Workspace/Users/takaaki.yayoi@databricks.com/chronos/"
# ChronosパイプラインのカスタムMLflowモデルクラスを定義
class ChronosModel(mlflow.pyfunc.PythonModel):
def __init__(self, repository):
import torch
from chronos import ChronosPipeline
# 指定したリポジトリの事前学習済みモデルを用いてChronosPipelineを初期化
self.pipeline = ChronosPipeline.from_pretrained(
repository,
device_map="cuda", # 推論にGPUを使用
torch_dtype=torch.bfloat16, # bfloat16精度を使用
)
def predict(self, context, input_data, params=None):
# 入力データをPyTorchのtensorのリストに変換
history = [torch.tensor(list(series)) for series in input_data]
# ChronosPipelineを用いた予測の生成
forecast = self.pipeline.predict(
context=history,
prediction_length=10, # 予測ホライゾンの長さ
num_samples=10, # 生成するサンプルの数
)
return forecast.numpy() # 予測結果をNumPy配列に変換
# 指定したリポジトリを用いてカスタムモデルのインスタンスを作成
pipeline = ChronosModel(f"amazon/{chronos_model}")
# モデルシグネチャのための入力と出力のスキーマを定義
input_schema = Schema([TensorSpec(np.dtype(np.double), (-1, -1))]) # Input: doubleの2D array
output_schema = Schema([TensorSpec(np.dtype(np.uint8), (-1, -1, -1))]) # Output: unsigned 8-bit integerの3D array
signature = ModelSignature(inputs=input_schema, outputs=output_schema)
# モデルのサンプル入力の作成 (1 sample, 52 features)
input_example = np.random.rand(1, 52)
# catalog.database.model_name のフォーマットで登録モデル名を定義
registered_model_name = f"{catalog}.{db}.{chronos_model}"
# 現在のエクスペリメントを設定
mlflow.set_experiment(experiment_name)
# MLflowによるモデルの記録と登録
with mlflow.start_run() as run:
mlflow.pyfunc.log_model(
"model", # モデルアーティファクトのパス
python_model=pipeline, # カスタムモデルクラスのインスタンス
registered_model_name=registered_model_name, # モデルの登録名
signature=signature, # モデルシグネチャ
input_example=input_example, # 入力サンプル
pip_requirements=[ # pip要件のリスト
f"git+https://github.com/amazon-science/chronos-forecasting.git",
],
)
モデルのリロード
登録が完了したら、モデルをリロードして予測を生成します。
from mlflow import MlflowClient
client = MlflowClient()
# 登録モデルの最新バージョンを取得する関数
def get_latest_model_version(client, registered_model_name):
latest_version = 1 # 最新バージョン1で初期化
# 指定したモデル名の全てのモデルバージョンに対するループ
for mv in client.search_model_versions(f"name='{registered_model_name}'"):
version_int = int(mv.version) # バージョンの文字列を整数に変換
# 大きいバージョンがある場合には最新バージョンを更新
if version_int > latest_version:
latest_version = version_int
return latest_version # 最新バージョンを返却
# 指定した登録モデルの最新バージョンを取得
model_version = get_latest_model_version(client, registered_model_name)
# 登録モデルと最新バージョンを用いてモデルのURIを構築
logged_model = f"models:/{registered_model_name}/{model_version}"
# 指定したURIでPyFuncModelとしてモデルをロード
loaded_model = mlflow.pyfunc.load_model(logged_model)
# ランダムな入力データを作成 (5サンプル、それぞれは52データポイント)
input_data = np.random.rand(5, 52) # (batch, series)
# ロードしたモデルを用いて予測結果を生成
loaded_model.predict(input_data)
予測結果
array([[[7.20715874e-01, 6.79424899e-01, 6.75671174e-01, 7.88282985e-01,
8.10805397e-01, 6.53148763e-01, 5.74320477e-01, 6.00596552e-01,
6.45641313e-01, 6.64409938e-01],
[7.24469599e-01, 6.56902488e-01, 6.90686074e-01, 7.54499460e-01,
7.28223324e-01, 5.78074202e-01, 5.70566752e-01, 4.91738465e-01,
5.70566752e-01, 4.57954910e-01],
[6.49395038e-01, 6.86932349e-01, 6.11857788e-01, 7.58253185e-01,
7.05700974e-01, 6.19365238e-01, 7.39484560e-01, 5.70566752e-01,
6.15611513e-01, 6.00596552e-01],
[8.14559122e-01, 5.89335377e-01, 7.01947249e-01, 6.98193524e-01,
7.58253185e-01, 7.58253185e-01, 7.65760635e-01, 7.46992010e-01,
6.56902488e-01, 7.16962149e-01],
[7.01947249e-01, 6.41887588e-01, 6.49395038e-01, 7.84529260e-01,
6.49395038e-01, 7.69514360e-01, 7.31977110e-01, 6.23118963e-01,
5.55551851e-01, 4.95492190e-01],
[8.29574022e-01, 5.74320477e-01, 7.15267724e-09, 8.44588922e-01,
7.73268085e-01, 6.53148763e-01, 5.70566752e-01, 3.67865418e-01,
6.45641313e-01, 5.81827927e-01],
[8.18312847e-01, 6.19365238e-01, 6.56902488e-01, 6.64409938e-01,
7.01947249e-01, 5.85581652e-01, 6.60656213e-01, 6.41887588e-01,
6.90686074e-01, 6.49395038e-01],
[6.19365238e-01, 7.09454699e-01, 6.75671174e-01, 6.19365238e-01,
7.05700974e-01, 6.41887588e-01, 6.98193524e-01, 6.34380138e-01,
6.00596552e-01, 6.86932349e-01],
[8.18312847e-01, 6.11857788e-01, 7.50745735e-01, 7.99544222e-01,
6.04350338e-01, 7.09454699e-01, 7.01947249e-01, 6.15611513e-01,
6.79424899e-01, 5.21768265e-01],
[5.93089102e-01, 5.59305576e-01, 5.59305576e-01, 6.64409938e-01,
6.41887588e-01, 7.09454699e-01, 5.78074202e-01, 5.55551851e-01,
6.60656213e-01, 6.34380138e-01]],
[[4.52397114e-01, 4.45934304e-01, 3.45760672e-01, 1.93884494e-01,
1.29256330e-01, 9.69422471e-02, 1.55107595e-01, 1.61570405e-01,
9.04794307e-02, 4.20083104e-02],
[4.94405431e-01, 3.26372216e-01, 3.52223482e-01, 4.81479811e-01,
3.45760672e-01, 2.10041532e-01, 1.71264633e-01, 1.16330697e-01,
3.87769021e-02, 9.69423046e-03],
[6.20430329e-01, 4.52397114e-01, 4.75017001e-01, 6.20430329e-01,
4.58859924e-01, 6.17198924e-01, 3.10215165e-01, 3.06983760e-01,
3.58686291e-01, 3.65149101e-01],
[4.97636836e-01, 4.75017001e-01, 4.49165709e-01, 4.75017001e-01,
6.01041899e-01, 6.23661734e-01, 6.72132861e-01, 7.14141178e-01,
7.30298203e-01, 7.36761013e-01],
[4.78248406e-01, 4.65322786e-01, 4.91174026e-01, 5.17025266e-01,
4.75017001e-01, 4.65322786e-01, 5.04099646e-01, 5.23488128e-01,
5.33182343e-01, 5.75190660e-01],
[6.39818811e-01, 4.36240089e-01, 4.71785596e-01, 4.84711216e-01,
3.97463203e-01, 3.74843342e-01, 3.81306152e-01, 4.39471494e-01,
4.10388823e-01, 4.23314469e-01],
[6.30124597e-01, 2.90826774e-02, 6.26893139e-01, 2.52049823e-01,
2.90826774e-02, 2.03578709e-01, 5.17025351e-02, 8.07852092e-02,
2.06810114e-01, 2.55281228e-01],
[6.43050216e-01, 4.91174026e-01, 1.64801810e-01, 1.71264633e-01,
4.91174026e-01, 6.33356001e-01, 6.59207241e-01, 6.85058533e-01,
6.43050216e-01, 5.81653470e-01],
[3.45760672e-01, 2.16504342e-01, 2.35892798e-01, 1.90653076e-01,
2.32661393e-01, 2.03578709e-01, 1.68033228e-01, 1.61570405e-01,
1.97115899e-01, 2.00347304e-01],
[4.42702899e-01, 4.55628519e-01, 4.81479811e-01, 4.91174026e-01,
4.26545874e-01, 4.10388823e-01, 3.71611937e-01, 3.68380533e-01,
3.71611937e-01, 3.39297862e-01]],
[[5.47164089e-01, 5.16766054e-01, 4.78768555e-01, 5.66162823e-01,
5.81361810e-01, 6.42157820e-01, 6.64956301e-01, 6.38358073e-01,
6.19359340e-01, 5.81361810e-01],
[3.41977548e-01, 5.01567066e-01, 5.01567066e-01, 4.59769790e-01,
4.36971310e-01, 4.63569568e-01, 4.67369315e-01, 3.83774793e-01,
3.95174034e-01, 3.83774793e-01],
[5.09166560e-01, 4.78768555e-01, 5.20565800e-01, 5.88961304e-01,
5.01567066e-01, 4.59769790e-01, 5.73762317e-01, 6.00360544e-01,
6.23159086e-01, 5.96560797e-01],
[3.30578277e-01, 8.20746103e-01, 5.31965041e-01, 5.73762317e-01,
6.04160291e-01, 6.83955096e-01, 5.77562063e-01, 6.00360544e-01,
6.42157820e-01, 6.53557060e-01],
[3.41977548e-01, 4.86368049e-01, 4.21772292e-01, 4.67369315e-01,
5.62363076e-01, 5.96560797e-01, 6.57356807e-01, 6.19359340e-01,
5.92761050e-01, 5.96560797e-01],
[5.50963836e-01, 5.54763583e-01, 6.15559593e-01, 5.12966307e-01,
5.69962570e-01, 5.73762317e-01, 6.11759846e-01, 6.15559593e-01,
6.45957567e-01, 5.28165294e-01],
[6.34558327e-01, 6.64956301e-01, 5.43364281e-01, 6.23159086e-01,
7.25752311e-01, 7.18152817e-01, 7.21952564e-01, 7.40951359e-01,
6.95354337e-01, 7.14353070e-01],
[5.09166560e-01, 5.31965041e-01, 4.93967542e-01, 5.69962570e-01,
6.23159086e-01, 6.72555794e-01, 6.04160291e-01, 5.39564534e-01,
5.62363076e-01, 5.43364281e-01],
[3.41977548e-01, 3.64776029e-01, 3.64776029e-01, 3.68575806e-01,
3.34378055e-01, 3.03980050e-01, 3.03980050e-01, 2.69782298e-01,
3.22978783e-01, 3.30578277e-01],
[4.90167795e-01, 6.30758580e-01, 5.66162823e-01, 5.28165294e-01,
5.92761050e-01, 6.64956301e-01, 6.76355603e-01, 6.07960038e-01,
6.80155350e-01, 6.30758580e-01]],
[[5.41081127e-01, 5.47930249e-01, 5.13684640e-01, 6.16421593e-02,
5.13684640e-01, 5.51354866e-01, 5.17109201e-01, 6.50667272e-02,
7.19158560e-02, 9.93123712e-02],
[5.47930249e-01, 5.03410957e-01, 4.52042487e-01, 4.48617927e-01,
4.58891609e-01, 5.03410957e-01, 4.38344216e-01, 5.47930249e-01,
4.65740731e-01, 5.27382883e-01],
[4.41768777e-01, 3.04786229e-01, 5.13684695e-02, 4.79439051e-02,
2.97937079e-01, 2.05473916e-02, 2.05473916e-02, 7.53404169e-02,
6.50667272e-02, 6.50667272e-02],
[3.76702119e-02, 7.53404169e-02, 8.21895457e-02, 8.56141136e-02,
2.73965205e-02, 8.56141136e-02, 7.87649848e-02, 7.19158560e-02,
3.08210831e-02, 2.73965205e-02],
[3.15059911e-01, 5.06835518e-01, 5.06835518e-01, 4.72589853e-01,
5.10260079e-01, 3.04786229e-01, 6.16421593e-02, 2.73965205e-02,
2.80814274e-01, 1.36982628e-02],
[4.69165292e-01, 5.20533762e-01, 5.47930249e-01, 5.10260079e-01,
5.47930249e-01, 5.58203987e-01, 4.52042487e-01, 5.34232005e-01,
4.34919655e-01, 4.89712658e-01],
[5.41081127e-01, 2.60266881e-01, 2.19172108e-01, 1.91775607e-01,
1.67803652e-01, 1.98624728e-01, 2.29445805e-01, 1.36982576e-01,
7.19158560e-02, 3.76702119e-02],
[4.93137274e-01, 4.76014414e-01, 5.03410957e-01, 2.49993198e-01,
5.03410957e-01, 5.23958323e-01, 5.03410957e-01, 5.65053109e-01,
6.09572457e-01, 5.85600475e-01],
[5.20533762e-01, 4.93137274e-01, 4.96561835e-01, 5.17109201e-01,
4.86288097e-01, 4.93137274e-01, 5.37656566e-01, 5.37656566e-01,
4.52042487e-01, 4.48617927e-01],
[8.21895457e-02, 1.16435190e-01, 4.10947763e-02, 1.19859758e-01,
9.24632424e-02, 5.47930339e-02, 3.42456475e-02, 5.82175984e-02,
1.60954517e-01, 2.73965205e-02]],
[[4.91494607e-01, 5.13667274e-01, 5.72794447e-01, 6.24530731e-01,
5.83880780e-01, 6.31921620e-01, 6.13444397e-01, 8.01912250e-01,
8.72125756e-01, 8.90602979e-01],
[3.47372119e-01, 5.43230890e-01, 6.06053508e-01, 5.94967174e-01,
7.02135126e-01, 7.39089632e-01, 7.35394188e-01, 7.79739522e-01,
8.35171251e-01, 7.05830571e-01],
[4.95190051e-01, 5.72794447e-01, 6.06053508e-01, 6.54094287e-01,
6.28226175e-01, 6.46703398e-01, 8.16694028e-01, 7.39089632e-01,
8.16694028e-01, 7.39089632e-01],
[3.76935705e-01, 6.79962459e-01, 6.94744237e-01, 5.83880780e-01,
6.72571570e-01, 6.39312509e-01, 7.28003298e-01, 7.09526016e-01,
8.72125756e-01, 8.38866695e-01],
[4.02803847e-01, 6.65180681e-01, 8.09303139e-01, 6.91048793e-01,
6.68876125e-01, 8.09303139e-01, 6.87353348e-01, 7.46480521e-01,
7.31698743e-01, 8.20389472e-01],
[3.39981230e-01, 4.80408243e-01, 6.68876125e-01, 7.53871410e-01,
6.91048793e-01, 7.57566855e-01, 6.94744237e-01, 6.91048793e-01,
8.20389472e-01, 8.38866695e-01],
[4.36062878e-01, 5.28449052e-01, 6.02358063e-01, 5.76489891e-01,
6.28226175e-01, 7.20612409e-01, 6.79962459e-01, 7.72348633e-01,
8.53648533e-01, 7.98216805e-01],
[4.65626465e-01, 5.39535446e-01, 4.61931020e-01, 6.06053508e-01,
6.09748952e-01, 7.46480521e-01, 7.68653189e-01, 8.79516645e-01,
8.72125756e-01, 8.64734867e-01],
[3.76935705e-01, 6.09748952e-01, 7.20612409e-01, 6.28226175e-01,
7.24307854e-01, 6.94744237e-01, 7.31698743e-01, 7.57566855e-01,
8.75821201e-01, 8.61039423e-01],
[4.32367434e-01, 4.69321909e-01, 6.46703398e-01, 7.20612409e-01,
6.31921620e-01, 6.24530731e-01, 7.28003298e-01, 8.31475806e-01,
7.98216805e-01, 8.31475806e-01]]])
モデルのデプロイ
Databricks Mosaic AI Model Servingのリアルタイムエンドポイントにモデルをデプロイします。
# トークンを用いることで以降のREST呼び出しの認証ヘッダーを作成することができます
token = dbutils.notebook.entry_point.getDbutils().notebook().getContext().apiToken().getOrElse(None)
headers = {"Authorization": f"Bearer {token}", "Content-Type": "application/json"}
# 次に、あなたのリクエストを処理するエンドポイントが必要となりますが、これはノートブックのtagsコレクションから取得できます
java_tags = dbutils.notebook.entry_point.getDbutils().notebook().getContext().tags()
# このオブジェクトはJava CMと共に提供されます - Java MapオブジェクトをPythonのディクショナリーに変換します
tags = sc._jvm.scala.collection.JavaConversions.mapAsJavaMap(java_tags)
# 最後に、ディクショナリーからDatabricksインスタンス(ドメイン名)を抽出します
instance = tags["browserHostName"]
import requests
model_serving_endpoint_name = chronos_model
# auto_capture_configでは、推論ログをどこに書き込むのかを指定します
my_json = {
"name": model_serving_endpoint_name,
"config": {
"served_models": [
{
"model_name": registered_model_name,
"model_version": model_version,
"workload_type": "GPU_SMALL",
"workload_size": "Small",
"scale_to_zero_enabled": "true",
}
],
"auto_capture_config": {
"catalog_name": catalog,
"schema_name": db,
"table_name_prefix": model_serving_endpoint_name,
},
},
}
# 推論テーブルがある場合には削除します
_ = spark.sql(
f"DROP TABLE IF EXISTS {catalog}.{db}.`{model_serving_endpoint_name}_payload`"
)
# モデルサービングにエンドポイントを作成し、モデルをデプロイする関数
def func_create_endpoint(model_serving_endpoint_name):
# エンドポイントのステータスの取得
endpoint_url = f"https://{instance}/api/2.0/serving-endpoints"
url = f"{endpoint_url}/{model_serving_endpoint_name}"
r = requests.get(url, headers=headers)
if "RESOURCE_DOES_NOT_EXIST" in r.text:
print(
"Creating this new endpoint: ",
f"https://{instance}/serving-endpoints/{model_serving_endpoint_name}/invocations",
)
re = requests.post(endpoint_url, headers=headers, json=my_json)
else:
new_model_version = (my_json["config"])["served_models"][0]["model_version"]
print(
"This endpoint existed previously! We are updating it to a new config with new model version: ",
new_model_version,
)
# 設定の更新
url = f"{endpoint_url}/{model_serving_endpoint_name}/config"
re = requests.put(url, headers=headers, json=my_json["config"])
# 新規設定ファイルが配置されるのを待ちます
import time, json
# エンドポイントのステータスの取得
url = f"https://{instance}/api/2.0/serving-endpoints/{model_serving_endpoint_name}"
retry = True
total_wait = 0
while retry:
r = requests.get(url, headers=headers)
assert (
r.status_code == 200
), f"Expected an HTTP 200 response when accessing endpoint info, received {r.status_code}"
endpoint = json.loads(r.text)
if "pending_config" in endpoint.keys():
seconds = 10
print("New config still pending")
if total_wait < 6000:
# 10分待っていない場合には、待機を継続
print(f"Wait for {seconds} seconds")
print(f"Total waiting time so far: {total_wait} seconds")
time.sleep(10)
total_wait += seconds
else:
print(f"Stopping, waited for {total_wait} seconds")
retry = False
else:
print("New config in place now!")
retry = False
assert (
re.status_code == 200
), f"Expected an HTTP 200 response, received {re.status_code}"
# モデルサービングからエンドポイントを削除する関数
def func_delete_model_serving_endpoint(model_serving_endpoint_name):
endpoint_url = f"https://{instance}/api/2.0/serving-endpoints"
url = f"{endpoint_url}/{model_serving_endpoint_name}"
response = requests.delete(url, headers=headers)
if response.status_code != 200:
raise Exception(
f"Request failed with status {response.status_code}, {response.text}"
)
else:
print(model_serving_endpoint_name, "endpoint is deleted!")
return response.json()
# エンドポイントを作成。これにはある程度の時間を要します
func_create_endpoint(model_serving_endpoint_name)
import time, mlflow
def wait_for_endpoint():
# サービングエンドポイントAPIのベースURLの構成
endpoint_url = f"https://{instance}/api/2.0/serving-endpoints"
while True:
# 特定のモデルサービングエンドポイントのURLを構成
url = f"{endpoint_url}/{model_serving_endpoint_name}"
# エンドポイントURLにGETリクエストを送信
response = requests.get(url, headers=headers)
# レスポンスのステータスコードが 200 (OK) であることを確認
assert (
response.status_code == 200
), f"Expected an HTTP 200 response, received {response.status_code}\n{response.text}"
# レスポンスからエンドポイントのステータスを抽出
status = response.json().get("state", {}).get("ready", {})
# エンドポイントの準備ができたら、ステータスを表示して返却
if status == "READY":
print(status)
print("-" * 80)
return
else:
# エンドポイントが準備中の場合、ステータスを表示してさらに5分待機
print(f"Endpoint not ready ({status}), waiting 5 minutes")
time.sleep(300) # 300 seconds (5 minutes) 待機
# DatabricksインスタンスのAPI URLを取得
api_url = mlflow.utils.databricks_utils.get_webapp_url()
# エンドポイントの準備ができるまで待機する関数の呼び出し
wait_for_endpoint()
30分程度でデプロイされました。
Endpoint not ready (NOT_READY), waiting 5 minutes
Endpoint not ready (NOT_READY), waiting 5 minutes
Endpoint not ready (NOT_READY), waiting 5 minutes
Endpoint not ready (NOT_READY), waiting 5 minutes
Endpoint not ready (NOT_READY), waiting 5 minutes
READY
--------------------------------------------------------------------------------
オンライン予測
エンドポイントの準備ができたら、モデルにリクエストを行い、オンライン予測結果を生成しましょう。
import os
import requests
import pandas as pd
import json
import matplotlib.pyplot as plt
# モデルサービングページで取得できるエンドポイント呼び出しURLで置き換え
endpoint_url = f"https://{instance}/serving-endpoints/{model_serving_endpoint_name}/invocations"
# Databricks APIトークンの取得
token = dbutils.notebook.entry_point.getDbutils().notebook().getContext().apiToken().get()
# モデルサービングエンドポイントに入力データを送信し、予測結果を取得する関数の定義
def forecast(input_data, url=endpoint_url, databricks_token=token):
# 認証トークンを含むPOSTリクエストのヘッダーのセットアップ
headers = {
"Authorization": f"Bearer {databricks_token}",
"Content-Type": "application/json",
}
# 入力データを持つリクエストボディの準備
body = {"inputs": input_data.tolist()}
# ボディをJSON文字列に変換
data = json.dumps(body)
# モデルサービングエンドポイントにPOSTリクエストを送信
response = requests.request(method="POST", headers=headers, url=url, data=data)
# レスポンスのステータスコードが 200 (OK) でないかチェック
if response.status_code != 200:
# リクエストが失敗したら例外
raise Exception(
f"Request failed with status {response.status_code}, {response.text}"
)
# PythonディクショナリーとしてレスポンスのJSONを返却
return response.json()
# エンドポイントにリクエストを送信
input_data = np.random.rand(5, 52) # (batch, series)
forecast(input_data)
予測結果
{'predictions': [[[0.646409909692345,
0.6587616798929878,
0.7122860843763702,
0.6052373423568693,
0.6587616798929878,
0.6134718558239645,
0.6669962603073469,
0.6052373423568693,
0.5187748505314743,
0.5722992215412247],
[0.6052373423568693,
0.6669962603073469,
0.5764164782747723,
0.7575759753926575,
0.6711135170408945,
0.6422926529587974,
0.6834652872415372,
0.6011200856233218,
0.617589112557512,
0.7328724349913721],
[0.6505271664258926,
0.5270093639985695,
0.5475956811399393,
0.559947451340582,
0.48995401992300924,
0.47348499298881896,
0.3252635832129463,
0.3417326436207686,
0.4405469056468064,
0.5023058235972839],
[0.6340581394917023,
0.7411069484584673,
0.6752307737744421,
0.7328724349913721,
0.5475956811399393,
0.5146575937979267,
0.5228921072650219,
0.5517129378734869,
0.6628790035737994,
0.7122860843763702],
[0.5928855721562266,
0.5228921072650219,
0.5640647080741296,
0.7781622590603954,
0.5105403370643791,
0.7452242051920148,
0.5640647080741296,
0.5558301946070344,
0.5722992215412247,
0.5764164782747723],
[0.47760224972236653,
0.6958170574421799,
0.4611331893145442,
0.6258236260246072,
0.5764164782747723,
0.6093545990904169,
0.4405469056468064,
0.617589112557512,
0.646409909692345,
0.4611331893145442],
[0.7493414619255624,
0.7987485427281333,
0.4611331893145442,
0.5475956811399393,
0.5023058235972839,
0.4817195064559141,
0.559947451340582,
0.4405469056468064,
0.531126620732117,
0.42407784523898406],
[0.49818853339010444,
0.5558301946070344,
0.4817195064559141,
0.49818853339010444,
0.5434784244063917,
0.48995401992300924,
0.3211463264793987,
0.5475956811399393,
0.7205205978434654,
0.7287551113105606],
[0.5970028288897742,
0.7369896917249197,
0.6505271664258926,
0.531126620732117,
0.6258236260246072,
0.5805337350083198,
0.5558301946070344,
0.6669962603073469,
0.5887682484754151,
0.5270093639985695],
[0.5270093639985695,
0.5764164782747723,
0.5640647080741296,
0.7369896917249197,
0.5105403370643791,
0.5064230803308315,
0.47760224972236653,
0.42407784523898406,
0.4405469056468064,
0.5928855721562266]],
[[0.5607117454064976,
0.5714259725930186,
0.42142667583833965,
0.5464261091578029,
0.024999894285802227,
0.19999911073264875,
0.20357051979482244,
0.007142831734900204,
0.007142831734900204,
0.003571419270088309],
[0.25357030473694614,
0.2714273500478146,
0.032142717854370735,
0.40714103958964487,
0.48928350609133175,
0.01785706890249334,
0.0142856562108389,
0.06428542844977995,
0.024999894285802227,
0.06785684477091515],
[0.40714103958964487,
0.27857016817216196,
0.13928509860400398,
0.2749987591099883,
0.2749987591099883,
0.2642845319234672,
0.20714192885699612,
0.25357030473694614,
0.2607131228612935,
0.1678563856193165],
[0.20714192885699612,
0.18928486902820468,
0.032142717854370735,
0.5607117454064976,
0.41785526677616597,
0.024999894285802227,
0.03571413054602517,
0.04999978131264294,
0.1678563856193165,
0.3392841803008067],
[0.31785572592776457,
0.20714192885699612,
0.15714215843279544,
0.31785572592776457,
0.35714125464752117,
0.3214271349899383,
0.19285627809037836,
0.22499898868578758,
0.22857039774796128,
0.11428520613294214],
[0.41785526677616597,
0.41428385771399223,
0.19285627809037836,
0.007142831734900204,
0.042856955929334056,
0.1678563856193165,
0.28928442439452906,
0.39999819242945145,
0.3428555893629804,
0.032142717854370735],
[0.5035691423400265,
0.40714103958964487,
0.18571345996603097,
0.41071244865181855,
0.3928553743051041,
0.010714244426554645,
0.007142831734900204,
0.24999886663892643,
0.3499984365231738,
0.48928350609133175],
[0.33214136217645934,
0.3392841803008067,
0.33214136217645934,
0.33214136217645934,
0.3392841803008067,
0.20357051979482244,
0.1821420363859343,
0.15357074937062176,
0.15357074937062176,
0.2214275796236139],
[0.5142834275982396,
0.4035696014916252,
6.805276414202911e-09,
0.032142717854370735,
0.41428385771399223,
0.5999973031621002,
0.3107128787675712,
0.33214136217645934,
0.4321409030248607,
0.2999986515810501],
[0.5142834275982396,
0.20357051979482244,
0.36071266370969485,
0.38214114711858305,
0.4392837501850541,
0.4678550517182896,
0.49285491515350544,
0.4678550517182896,
0.5178548366604133,
0.5178548366604133]],
[[0.7935291572924906,
0.36624422389773725,
0.36624422389773725,
0.36624422389773725,
0.28485660020692927,
0.33775855726017245,
7.75414676426538e-09,
0.9156105101178038,
0.012208148309497884,
0.05697132913958472],
[0.40286864297907504,
0.3581054714539643,
0.321481019288267,
0.3987992667571886,
0.3418279334820589,
0.3540360621477183,
0.05697132913958472,
0.8627086192332797,
0.016277527633043062,
0.4272849333947534],
[0.7935291572924906,
0.7772515862362257,
0.7894597149018852,
0.20346900960048075,
0.7406272002392474,
0.19533025715670782,
0.03255504906276871,
0.8627086192332797,
0.8423616719551283,
0.8260841670675825],
[0.411007395422848,
0.3743829763415102,
0.4069380192009615,
0.6388926623546476,
0.30113413817883466,
0.6226151574671017,
0.3458973097039454,
0.30113413817883466,
0.3092728906226076,
0.822014790845696],
[0.40286864297907504,
0.6714477382984584,
0.3621748476758508,
0.3703136001196237,
0.30520351440072113,
0.048832568424721896,
0.03255504906276871,
0.8627086192332797,
0.2726484715412698,
0.4557705669479587],
[0.37845235256339665,
0.411007395422848,
0.37845235256339665,
0.31334226684449407,
0.28485660020692927,
0.3621748476758508,
0.2970647619569482,
0.8545698006207877,
0.8342229195113554,
0.8342229195113554],
[0.7813209624581122,
0.325550428594513,
0.7528353289049069,
0.2970647619569482,
0.3947298905353021,
0.31741164306638053,
0.04476318806729048,
0.8098066621800365,
0.3092728906226076,
0.40286864297907504],
[0.3581054714539643,
0.38659110500716964,
0.36624422389773725,
0.040693807709859066,
0.37845235256339665,
0.004069387853106618,
0.012208148309497884,
0.8708473716770526,
0.020346907990474474,
0.4069380192009615],
[0.7691128337924528,
0.3947298905353021,
0.411007395422848,
0.38659110500716964,
0.6592396096327989,
0.0284856687053373,
0.3987992667571886,
0.008138767952066472,
0.7813209624581122,
0.8098066621800365],
[0.8057372859581501,
0.744696576461134,
0.7040027480735501,
0.34996668592583186,
0.30113413817883466,
0.3418279334820589,
0.325550428594513,
0.08545699164160461,
0.8789861241208257,
0.8749167478989391]],
[[0.5190475862349077,
0.20326341208167345,
0.21052281227857972,
0.22504162742727066,
0.22141192732881754,
0.03629704156044702,
0.03629704156044702,
0.0036297104729770373,
0.007259414029604802,
0.007259414029604802],
[0.8856476502957563,
0.1851148820796509,
0.22141192732881754,
0.2359307424775085,
0.8820179501973032,
0.8856476502957563,
0.23956044257596162,
0.032667337773274276,
0.25044954287132104,
0.23956044257596162],
[0.9037962098075358,
0.15607726653714743,
0.03992674534761975,
0.032667337773274276,
0.8892773503942095,
0.81305358930718,
0.13429903643667176,
0.032667337773274276,
0.8856476502957563,
0.8566100495081314],
[1.0090575716821903,
0.9037962098075358,
0.2286713275257238,
0.9219447102998014,
0.9437229108905202,
0.9146853101028952,
0.9292041104967077,
0.9001665097090826,
0.9437229108905202,
0.8711288499019438],
[0.9400932107920671,
0.9037962098075358,
0.17785548188274466,
0.17785548188274466,
0.8711288499019438,
0.9437229108905202,
0.15607726653714743,
0.7077922274325255,
0.7114219275309787,
0.6932733680191991],
[0.8820179501973032,
0.8856476502957563,
0.9074259099059889,
0.911055610004442,
0.9146853101028952,
0.9074259099059889,
0.8929071095121764,
0.9074259099059889,
0.9328338105951608,
0.9219447102998014],
[1.0054278715837373,
0.5263070454513277,
0.2722277729717967,
1.027206072174456,
1.0163169718790965,
0.998168471386831,
0.22141192732881754,
0.061704968070656195,
0.972760570697659,
0.8166832894056332],
[0.9509823110874265,
0.23956044257596162,
0.22504162742727066,
0.8892773503942095,
0.8892773503942095,
0.8638694497050377,
0.16696636683250682,
0.047186152921965235,
0.9655011705007528,
0.9292041104967077],
[1.0017981714852842,
0.9219447102998014,
0.23230102762417693,
0.9328338105951608,
0.9800199708945653,
0.9655011705007528,
0.20689311218012657,
0.23230102762417693,
1.0054278715837373,
0.8929071095121764],
[0.9146853101028952,
0.25770897257798414,
0.23230102762417693,
0.9473526109889734,
0.9691308705992059,
0.9400932107920671,
0.9400932107920671,
0.9400932107920671,
0.9328338105951608,
0.9400932107920671]],
[[0.8311700617297251,
0.850866080594575,
0.8784404173327253,
0.8587444625197608,
0.8587444625197608,
0.8626836534823537,
0.8626836534823537,
0.850866080594575,
0.850866080594575,
0.941467536786097],
[0.8784404173327253,
0.9572243006364685,
0.9375283458235041,
0.9690419375761327,
0.886318799257911,
0.9375283458235041,
0.9375283458235041,
1.0005554652768758,
0.9611634915990613,
0.9020756271601681],
[0.9375283458235041,
0.9375283458235041,
0.9572243006364685,
0.941467536786097,
0.8823796082953181,
0.9335891548609112,
0.9611634915990613,
0.8981363721456896,
0.9335891548609112,
0.9572243006364685],
[0.850866080594575,
0.799656534028982,
0.8351093167442035,
0.803595724991575,
0.8745012263701324,
0.748446987463389,
0.8469268896319821,
0.799656534028982,
0.748446987463389,
0.7523861784259819],
[0.8941971811830968,
0.906014818122761,
0.9769203195013184,
0.9493459187112827,
0.9493459187112827,
0.9690419375761327,
1.0044946562394685,
0.8666228444449466,
0.9808595104639113,
0.8981363721456896],
[0.8784404173327253,
0.9217715819731325,
0.8587444625197608,
0.8429876986693893,
0.8626836534823537,
0.8626836534823537,
0.8941971811830968,
0.8351093167442035,
0.8390485077067964,
0.8587444625197608],
[0.8075349159541678,
0.803595724991575,
0.8902579902205039,
0.8311700617297251,
0.8784404173327253,
0.886318799257911,
0.8784404173327253,
0.8429876986693893,
0.850866080594575,
0.8075349159541678],
[0.8587444625197608,
0.9375283458235041,
0.9532851096738756,
0.8626836534823537,
0.9847987014265042,
0.8232916798045393,
0.9454067277486898,
0.9808595104639113,
0.850866080594575,
0.8666228444449466],
[0.8745012263701324,
0.9375283458235041,
0.9138932000479467,
0.8469268896319821,
0.8548052715571679,
0.8390485077067964,
0.9375283458235041,
0.9454067277486898,
0.8469268896319821,
0.8666228444449466],
[0.8823796082953181,
0.8784404173327253,
0.768143006328239,
0.8784404173327253,
0.8272308707671322,
0.8193524888419464,
0.7917781521037962,
0.7917781521037962,
0.8075349159541678,
0.768143006328239]]]}
簡単に可視化してみます。後でもう少し可視化のロジックは精査します。
df = pd.DataFrame(input_data.transpose(), columns=["value"])
df['No'] = range(1, len(df.index) + 1)
display(df)
forecasted_df = pd.DataFrame(forecasted["predictions"][0]).T
forecasted_df['No'] = range(1, len(forecasted_df.index) + 1)
display(forecasted_df)
# サービングエンドポイントの削除
func_delete_model_serving_endpoint(model_serving_endpoint_name)