1
0

More than 1 year has passed since last update.

DatabricksにおけるMLflowクイックスタートのウォークスルー

Posted at

こちらのクイックスタートをDatabricksでウォークスルーします。Databricksでは、簡単にMLflowを活用できるように色々工夫されていますので、その点についても触れていきます。

MLflowのインストール

Databricks機械学習ランタイムには最初からMLflowが入っていますので、MLflowのインストールは不要です。

コードにMLflowトラッキングを追加

Databricksではデフォルトでオートロギングが有効化されているので、一般的な機械学習ライブラリを用いてトレーニングを行う場合には、自動的にモデルやパラメーターが記録されます。

import mlflow

from sklearn.model_selection import train_test_split
from sklearn.datasets import load_diabetes
from sklearn.ensemble import RandomForestRegressor

#mlflow.autolog() # Databricksでは最初から有効化されています

db = load_diabetes()
X_train, X_test, y_train, y_test = train_test_split(db.data, db.target)

# モデルの作成およびトレーニング
rf = RandomForestRegressor(n_estimators=100, max_depth=6, max_features=3)
rf.fit(X_train, y_train)

# テストデータセットに対する予測を行うためにモデルを使用
predictions = rf.predict(X_test)

記録された旨のメッセージが表示されます。
Screenshot 2023-09-01 at 20.02.23.png

画面右のフラスコアイコンをクリックすると、記録されたモデルをクイックに確認することができます。
Screenshot 2023-09-01 at 20.03.06.png

MLflowランとエクスペリメントの参照

Databricksでは、最初からMLflowのランやエクスペリメントを参照するGUIが提供されています。上の例では、able-jay-32(ランダムに割り振られるMLflowランの名称です)をクリックします。

モデルのメタデータやモデル本体にアクセスすることができます。
Screenshot 2023-09-01 at 20.04.41.png
Screenshot 2023-09-01 at 20.04.53.png

MLflowランとエクスペリメントの共有

Databricksでは最初からMLflowトラッキングサーバーが稼働しています。これによって、他のユーザーと簡単にトレーニングしたモデルを共有することができます。

そして、MLflowモデルレジストリにモデルを登録することで、モデルのバージョン管理やステージ変更を行うこともできます。

MLflowにおけるモデルの格納

以下の例では明示的にモデルのみを記録しています。

from mlflow.models import infer_signature

from sklearn.model_selection import train_test_split
from sklearn.datasets import load_diabetes
from sklearn.ensemble import RandomForestRegressor

with mlflow.start_run() as run:
    # 肥満データセットのロード
    db = load_diabetes()
    X_train, X_test, y_train, y_test = train_test_split(db.data, db.target)

    # モデルの作成およびトレーニング
    rf = RandomForestRegressor(n_estimators=100, max_depth=6, max_features=3)
    rf.fit(X_train, y_train)

    # テストデータセットに対する予測を行うためにモデルを使用
    predictions = rf.predict(X_test)
    print(predictions)

    signature = infer_signature(X_test, predictions)
    mlflow.sklearn.log_model(rf, "model", signature=signature)

    print("Run ID: {}".format(run.info.run_id))

予測結果とMLflowランのIDが表示されます。MLflowランとは一回の機械学習モデルトレーニングの単位となります。トレーニングを一度実行すると一つのMLflowランが作成されます。これに、モデルのメタデータやモデル本体が格納されます。このランを識別するのがランIDとなります。ランの中身はMLflowランとエクスペリメントの参照のスクリーンショットとなります。

166.65036953  98.03364649 161.58286012 231.04610204 153.21889821
 197.44826792  81.95439573 142.87658726 100.44504301 104.3191207
 244.29602895 164.09347765 178.62838511 201.18453629  86.76278193
 100.54268893  99.54048543 216.0454191  256.27544385 221.85699435
 228.91260635 160.35332129 226.86258835 147.56799572 119.3539991
  80.79713601 106.47927539 166.81398074  96.27388411 178.25882309
  87.55807753 183.94572357 116.11188938  88.11050328  96.24554151
  90.79430986 117.67192536 123.59356521 243.40493915 227.25405076
 113.43119586 215.48205066 146.30803545 102.21284592 178.38787909
  87.80294451 199.44103452 127.74580398 165.77781199 104.60040632
 224.67158042]
Run ID: 4e4fcb33a9674b2f96a4b3f89b298060

推論のために特定のトレーニングランのモデルをロード

上で取得したランIDを用いることで、記録された機械学習モデルを簡単にロードすることができます。

import mlflow

from sklearn.model_selection import train_test_split
from sklearn.datasets import load_diabetes

db = load_diabetes()
X_train, X_test, y_train, y_test = train_test_split(db.data, db.target)

model = mlflow.sklearn.load_model("runs:/4e4fcb33a9674b2f96a4b3f89b298060/model")
predictions = model.predict(X_test)
print(predictions)
[ 80.79713601 204.84289248 220.81657508 149.7813926  115.19312512
 180.74948643 192.32297658 242.49986613 100.24462388 102.69450998
 173.67547489 103.35198154 199.44103452  97.50758223 126.34563398
 263.17895411 192.46613324 172.4212659   88.5341338  119.03151311
 176.9826351  127.49339136 105.87572449 154.42964913  83.59904458
  99.2188995  197.76728548 155.69246889  92.23599226 101.03448389
 168.70722873 156.92065338 194.38561928 234.0162017   87.65881457
 100.44504301 160.35332129 109.10922676 267.66252006 269.88405211
 119.88418895 237.24865348 132.85057795  76.59845554 222.92935138
 147.91253963 224.67158042  99.23902848 158.80030757 105.2951715
 213.88045428 141.82039546 211.92238472 142.87658726 201.18453629
 120.45003206 153.90813032 228.91260635 162.04875195 132.45666152
 109.62532645 106.29439161  91.60861258  87.26780318 119.40719688
 135.39628263  82.57655631 141.50305867 213.50566206 147.4065048
 185.79176333 191.88049929 261.60705085 231.7422972  195.24882286
  86.83676848  83.12819316  89.07706368  83.11144491 197.02787232
 162.70305232 215.48205066 257.88537273  92.16962614 217.54390241
 103.26284085 166.65036953 203.1664986  174.34916461 147.56799572
 231.04610204 168.5504616  104.81574879 195.9918726   97.98698814
 197.05206448  93.33294332 264.98027582 109.37871312 295.16904602
 109.45869195 253.71902442  78.33563923 155.11245165 124.11237821
 231.43037629 229.56000598 175.05611851 194.59560339 212.55498504
 186.9967145 ]

まとめ

上で説明したように、Databricksでは最初からMLflowがインテグレーションされているので、簡単に機械学習モデルの管理を実現することができます。私はMLflowを使い始めて以来、Jupyterノートブック+Excelでのモデル管理には戻れなくなりました。是非体験してみてください!

以下のDatabricks漫画シリーズのデータサイエンティスト編でもMLflowのメリットをご紹介しています。なお、私が出演しています。

Databricksクイックスタートガイド

Databricksクイックスタートガイド

Databricks無料トライアル

Databricks無料トライアル

1
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
0