こちらのクイックスタートをDatabricksでウォークスルーします。Databricksでは、簡単にMLflowを活用できるように色々工夫されていますので、その点についても触れていきます。
MLflowのインストール
Databricks機械学習ランタイムには最初からMLflowが入っていますので、MLflowのインストールは不要です。
コードにMLflowトラッキングを追加
Databricksではデフォルトでオートロギングが有効化されているので、一般的な機械学習ライブラリを用いてトレーニングを行う場合には、自動的にモデルやパラメーターが記録されます。
import mlflow
from sklearn.model_selection import train_test_split
from sklearn.datasets import load_diabetes
from sklearn.ensemble import RandomForestRegressor
#mlflow.autolog() # Databricksでは最初から有効化されています
db = load_diabetes()
X_train, X_test, y_train, y_test = train_test_split(db.data, db.target)
# モデルの作成およびトレーニング
rf = RandomForestRegressor(n_estimators=100, max_depth=6, max_features=3)
rf.fit(X_train, y_train)
# テストデータセットに対する予測を行うためにモデルを使用
predictions = rf.predict(X_test)
画面右のフラスコアイコンをクリックすると、記録されたモデルをクイックに確認することができます。
MLflowランとエクスペリメントの参照
Databricksでは、最初からMLflowのランやエクスペリメントを参照するGUIが提供されています。上の例では、able-jay-32(ランダムに割り振られるMLflowランの名称です)をクリックします。
モデルのメタデータやモデル本体にアクセスすることができます。
MLflowランとエクスペリメントの共有
Databricksでは最初からMLflowトラッキングサーバーが稼働しています。これによって、他のユーザーと簡単にトレーニングしたモデルを共有することができます。
そして、MLflowモデルレジストリにモデルを登録することで、モデルのバージョン管理やステージ変更を行うこともできます。
MLflowにおけるモデルの格納
以下の例では明示的にモデルのみを記録しています。
from mlflow.models import infer_signature
from sklearn.model_selection import train_test_split
from sklearn.datasets import load_diabetes
from sklearn.ensemble import RandomForestRegressor
with mlflow.start_run() as run:
# 肥満データセットのロード
db = load_diabetes()
X_train, X_test, y_train, y_test = train_test_split(db.data, db.target)
# モデルの作成およびトレーニング
rf = RandomForestRegressor(n_estimators=100, max_depth=6, max_features=3)
rf.fit(X_train, y_train)
# テストデータセットに対する予測を行うためにモデルを使用
predictions = rf.predict(X_test)
print(predictions)
signature = infer_signature(X_test, predictions)
mlflow.sklearn.log_model(rf, "model", signature=signature)
print("Run ID: {}".format(run.info.run_id))
予測結果とMLflowランのIDが表示されます。MLflowランとは一回の機械学習モデルトレーニングの単位となります。トレーニングを一度実行すると一つのMLflowランが作成されます。これに、モデルのメタデータやモデル本体が格納されます。このランを識別するのがランIDとなります。ランの中身はMLflowランとエクスペリメントの参照のスクリーンショットとなります。
166.65036953 98.03364649 161.58286012 231.04610204 153.21889821
197.44826792 81.95439573 142.87658726 100.44504301 104.3191207
244.29602895 164.09347765 178.62838511 201.18453629 86.76278193
100.54268893 99.54048543 216.0454191 256.27544385 221.85699435
228.91260635 160.35332129 226.86258835 147.56799572 119.3539991
80.79713601 106.47927539 166.81398074 96.27388411 178.25882309
87.55807753 183.94572357 116.11188938 88.11050328 96.24554151
90.79430986 117.67192536 123.59356521 243.40493915 227.25405076
113.43119586 215.48205066 146.30803545 102.21284592 178.38787909
87.80294451 199.44103452 127.74580398 165.77781199 104.60040632
224.67158042]
Run ID: 4e4fcb33a9674b2f96a4b3f89b298060
推論のために特定のトレーニングランのモデルをロード
上で取得したランIDを用いることで、記録された機械学習モデルを簡単にロードすることができます。
import mlflow
from sklearn.model_selection import train_test_split
from sklearn.datasets import load_diabetes
db = load_diabetes()
X_train, X_test, y_train, y_test = train_test_split(db.data, db.target)
model = mlflow.sklearn.load_model("runs:/4e4fcb33a9674b2f96a4b3f89b298060/model")
predictions = model.predict(X_test)
print(predictions)
[ 80.79713601 204.84289248 220.81657508 149.7813926 115.19312512
180.74948643 192.32297658 242.49986613 100.24462388 102.69450998
173.67547489 103.35198154 199.44103452 97.50758223 126.34563398
263.17895411 192.46613324 172.4212659 88.5341338 119.03151311
176.9826351 127.49339136 105.87572449 154.42964913 83.59904458
99.2188995 197.76728548 155.69246889 92.23599226 101.03448389
168.70722873 156.92065338 194.38561928 234.0162017 87.65881457
100.44504301 160.35332129 109.10922676 267.66252006 269.88405211
119.88418895 237.24865348 132.85057795 76.59845554 222.92935138
147.91253963 224.67158042 99.23902848 158.80030757 105.2951715
213.88045428 141.82039546 211.92238472 142.87658726 201.18453629
120.45003206 153.90813032 228.91260635 162.04875195 132.45666152
109.62532645 106.29439161 91.60861258 87.26780318 119.40719688
135.39628263 82.57655631 141.50305867 213.50566206 147.4065048
185.79176333 191.88049929 261.60705085 231.7422972 195.24882286
86.83676848 83.12819316 89.07706368 83.11144491 197.02787232
162.70305232 215.48205066 257.88537273 92.16962614 217.54390241
103.26284085 166.65036953 203.1664986 174.34916461 147.56799572
231.04610204 168.5504616 104.81574879 195.9918726 97.98698814
197.05206448 93.33294332 264.98027582 109.37871312 295.16904602
109.45869195 253.71902442 78.33563923 155.11245165 124.11237821
231.43037629 229.56000598 175.05611851 194.59560339 212.55498504
186.9967145 ]
まとめ
上で説明したように、Databricksでは最初からMLflowがインテグレーションされているので、簡単に機械学習モデルの管理を実現することができます。私はMLflowを使い始めて以来、Jupyterノートブック+Excelでのモデル管理には戻れなくなりました。是非体験してみてください!
以下のDatabricks漫画シリーズのデータサイエンティスト編でもMLflowのメリットをご紹介しています。なお、私が出演しています。