こちらのサンプルノートブックをウォークスルーします。
翻訳版のノートブックはこちらです。
MLflowクイックスタート: トラッキング
このノートブックでは、シンプルなデータセットに対してランダムフォレストモデルを作成し、モデル、選択されたパラメーター、評価メトリクス、その他のアーティファクトを記録するためにMlflowトラッキングAPIを使用します。
ライブラリのインストール
このノートブックでは分散処理を行いませんので、ドライバーノードのみにパッケージをインストールするために、Rのinstall.packages()
関数を使うことができます。
分散処理の利点を活用するには、クラスターライブラリを作成することで、クラスターの全ノードにパッケージをインストールする必要があります。
install.packages("mlflow")
library(mlflow)
ライブラリのインポート
必要なライブラリをインポートします。
このノートブックでは、あとでモデルをロードし直せる様に、トレーニングしたモデルのpredictメソッドをシリアライズするために、Rライブラリのcarrier
を使用します。詳細に関しては、carrier
github repoをご覧ください。
install.packages("carrier")
install.packages("e1071")
library(MASS)
library(caret)
library(e1071)
library(randomForest)
library(SparkR)
library(carrier)
トレーニングおよびトラッキング
with(mlflow_start_run(), {
# モデルパラメーターの設定
ntree <- 100
mtry <- 3
# モデルの作成およびトレーニング
rf <- randomForest(type ~ ., data=Pima.tr, ntree=ntree, mtry=mtry)
# テストデータセットに対する予測にモデルを使用
pred <- predict(rf, newdata=Pima.te[,1:7])
# このランで使用されたモデルパラメーターを記録
mlflow_log_param("ntree", ntree)
mlflow_log_param("mtry", mtry)
# モデルを評価するためのメトリクスの定義
cm <- confusionMatrix(pred, reference = Pima.te[,8])
sensitivity <- cm[["byClass"]]["Sensitivity"]
specificity <- cm[["byClass"]]["Specificity"]
# メトリクスの値を記録
mlflow_log_metric("sensitivity", sensitivity)
mlflow_log_metric("specificity", specificity)
# モデルの記録
# 関数としてモデルを格納するRパッケージ "carrier" の crate() 関数
predictor <- crate(function(x) predict(rf,.x))
mlflow_log_model(predictor, "model")
# コンフュージョンマトリクス(混同行列)の作成およびプロット
png(filename="confusion_matrix_plot.png")
barplot(as.matrix(cm), main="Results",
xlab="Observed", ylim=c(0,200), col=c("green","blue"),
legend=rownames(cm), beside=TRUE)
dev.off()
# アーティファクトとしてプロットを保存
mlflow_log_artifact("confusion_matrix_plot.png")
})
トラッキング結果の確認
結果を参照するには、このページの右上にあるフラスコアイコンをクリックします。エクスペリメントのサイドバーが表示されます。エクスペリメントのサイドバーには、このノートブックにおけるそれぞれのランのパラメーターとメトリクスが表示されます。最新のランを表示する様にするには円形の矢印アイコンをクリックします。
ランダムに割り当てられたランの名称(skillfull-bat-xxx
など)をクリックすると、ランのページが新規タブに表示されます。このページでは、ランとして記録されたすべての情報を確認することができます。記録されたモデルやプロットを参照するために、アーティファクトセクションまで下にスクロールします。
詳細に関しては、"View notebook experiment" (AWS|Azure|GCP)をご覧ください。
まとめ
DatabricksではPythonだけではなく、Rを活用して機械学習モデルをトレーニングして、MLflowを用いてトラッキングすることもできます。Rを得意とされる方も是非試してみてください!