0
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

以前、こちらの記事を書きました。

機械学習モデルトレーニングの管理においては、トレーニングでどのデータを使ったのかは再現性確保の観点で重要です。しかし、手動でそのような情報を管理するのにも限界があります。

上述の記事では、MLflowでモデルをトラッキングする際に、トレーニングに用いたテーブルをバージョンとともに記録できる機能を説明しました。1年以上前の記事ということもあるので、機能を再度確認します。

こちらでは、以下の観点でウォークスルーします。

  1. MLflowエクスペリメントにおけるモデルとデータのトラッキング
  2. Unity Catalogにおけるモデルとデータのトラッキング

MLflowのエクスペリメント(実験)とはMLflowの用語ですが、MLflowにおいて機械学習モデルトレーニングをトラッキングする際、個々のトレーニングはMLflowラン(実行)として管理します。エクスペリメント複数のMLflowランをまとめて管理するための箱のようなものです。これらは、名前の通り、実験を試行錯誤している段階でモデルに関連する情報を記録するためのものです。

そして、Unity Catalogでのトラッキングに言及しているのは、実験から本格運用に移行する際にUnity Catalogでモデルのライフサイクルを管理することになるためです。以前は、この機能はMLflowのモデルレジストリ(ワークスペースモデルレジストリ)が担っていましたが、DatabricksではUnity Catalogのモデルレジストリに相当する機能(Models in UC)を使うことを推奨するようになっています。

MLflowエクスペリメントにおけるモデルとデータのトラッキング

最新のランタイムを使えば、前回の記事で言及した機能を活用するためのMLflowは最初から利用できます。
Screenshot 2024-07-17 at 8.35.22.png

Auto Loggingによるトレーニングデータのトラッキング

こちらのサンプルを一部変更しながら機能を確認します。

from sklearn import datasets
from sklearn.ensemble import RandomForestClassifier

# irisデータセットでsklearnモデルをトレーニング
X, y = datasets.load_iris(return_X_y=True, as_frame=True)
clf = RandomForestClassifier(max_depth=7)
clf.fit(X, y)

Auto Loggingによってscikit-learnのモデルなどは自動でロギングされます。
Screenshot 2024-07-17 at 8.38.10.png

エクスペリメント配下のラン詳細を確認します。
Screenshot 2024-07-17 at 8.38.36.png

この時点で使用されたデータセットからトレーニングデータを確認することができます。
Screenshot 2024-07-17 at 8.39.42.png

mlflow.log_inputによるトレーニングデータテーブルのトラッキング

上の例でも、トレーニングデータ自体はトラッキングできていますが、より厳密にトレーニングデータを管理するのであれば、Unity Catalogのテーブルでトレーニングデータを保持すべきです。耐障害性、バージョン管理、アクセスコントロールなどさまざまな面でのメリットを享受することができます。

mlflow.log_inputを用いることで、ラン(トレーニング)で用いたデータセットをモデルとともに記録することができます。以下の例では、上と同じirisデータセットをUnity Catalog上のDeltaテーブルtakaakiyayoi_catalog.mlflow.irisとして保存し、これをトレーニングデータとして使用しています。

import mlflow
import pandas as pd
import pyspark.pandas as ps
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestRegressor

# Unity Catalogへのテーブルの書き込み
iris = load_iris()
iris_df = pd.DataFrame(iris.data, columns=iris.feature_names)
# 列名の変更
iris_df.rename(
  columns = {
    'sepal length (cm)':'sepal_length',
    'sepal width (cm)':'sepal_width',
    'petal length (cm)':'petal_length',
    'petal width (cm)':'petal_width'},
  inplace = True
)
iris_df['species'] = iris.target
ps.from_pandas(iris_df).to_table("takaakiyayoi_catalog.mlflow.iris", mode="overwrite")

# Unity Catalogテーブルのロード
dataset = mlflow.data.load_delta(table_name="takaakiyayoi_catalog.mlflow.iris", version="0")
pd_df = dataset.df.toPandas()
X = pd_df.drop("species", axis=1)
y = pd_df["species"]

with mlflow.start_run():
    clf = RandomForestRegressor(n_estimators=100)
    # モデルのトレーニング
    clf.fit(X, y)
    # 入力テーブルの記録
    mlflow.log_input(dataset, "training")

結果として作成されるMLflowランを確認すると、使用されたデータセットtakaakiyayoi_catalog.mlflow.iris@v0と表示されています。テーブル名とバージョン番号v0です。Deltaはデータのバージョン管理ができるので、mlflow.data.load_deltaでデータロードで指定したバージョン番号がそのまま記録されることになります。
Screenshot 2024-07-17 at 8.54.47.png

テーブルも確認できます。
Screenshot 2024-07-17 at 8.55.15.png

このようにモデルとトレーニングをきちんと管理しながら実験を繰り返すことができます。そして、ある程度のKPI(制度やレーテンシー)をクリアしたら、本格運用に投入という流れになります。そこで、Unity Catalogのモデル管理機能を使うことになります。

Unity Catalogにおけるモデルとデータのトラッキング

上の例では、モデル自体を記録していなかったので、以下の例ではmlflow.sklearn.log_modelを用いてモデルを記録すると同時に、registered_model_name="takaakiyayoi_catalog.mlflow.iris_model"を指定してUnity Catalogに登録するようにします。また、この際に入力サンプルを指定しています。これらはモデルのデプロイに必要になる情報です。詳細に関してはこちらのマニュアルをご覧ください。

# Unity Catalogにモデルを登録するように設定
mlflow.set_registry_uri("databricks-uc")

# Unity Catalogテーブルのロード
dataset = mlflow.data.load_delta(table_name="takaakiyayoi_catalog.mlflow.iris", version="0")
pd_df = dataset.df.toPandas()
X = pd_df.drop("species", axis=1)
y = pd_df["species"]

# モデル入力のサンプルとしてトレーニングデータセットの最初の行を取得
input_example = X.iloc[[0]]

with mlflow.start_run():
    clf = RandomForestRegressor(n_estimators=100)
    # モデルのトレーニング
    clf.fit(X, y)
    # 入力テーブルの記録
    mlflow.log_input(dataset, "training")

    # モデルを記録し、UCの新バージョンとしてモデルを登録
    mlflow.sklearn.log_model(
        sk_model=clf,
        artifact_path="model",
        # 入力サンプルとその予測アウトプットから自動でシグネチャを推定
        input_example=input_example,
        registered_model_name="takaakiyayoi_catalog.mlflow.iris_model",
    )

すると、使用されたデータセットに加えて、登録済みモデルにモデル名とモデルバージョンv1が表示されます。

Screenshot 2024-07-17 at 9.07.58.png

Unity Catalog配下のモデルも確認できます。
Screenshot 2024-07-17 at 9.09.20.png

依存関係(リネージ)も確認できます。
Screenshot 2024-07-17 at 9.09.55.png
Screenshot 2024-07-17 at 9.10.32.png

ご活用ください!

はじめてのDatabricks

はじめてのDatabricks

Databricks無料トライアル

Databricks無料トライアル

0
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?