1
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

MLflowで実験管理を自動化(サンプルコード)

Posted at

MLflowとは

以下の記事で細かく解説されています。
こちらの記事は、備忘録として最低限動作確認するためのソースコードを中心に記載します。
また、あくまでもMLFlowがメインなので、作成するモデルに関する説明は省略させていただきます。

環境

項目 情報
Python 3.9.13
mlflow 2.17.1

ソースコード

事前準備

# ライブラリのインポート
import mlflow
import mlflow.sklearn
import pandas as pd
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import roc_auc_score, precision_score, recall_score
from sklearn.model_selection import train_test_split
from sklearn.datasets import load_breast_cancer
from mlflow.models.signature import infer_signature

# データのロード(乳がんデータセット:二値分類用)
data = load_breast_cancer()
X = pd.DataFrame(data.data, columns=data.feature_names).astype(float)
y = data.target

# データ分割
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=97)

# 新しい実験名を設定
mlflow.set_experiment("qiita sample")

関数定義

# 学習用関数
def fit_rfn(n_estimators, max_depth):
    # MLflowで実験を記録するためのセッション開始
    with mlflow.start_run():
        # モデルのインスタンス化とトレーニング
        model = RandomForestClassifier(n_estimators=n_estimators, max_depth=max_depth, random_state=42)
        model.fit(X_train, y_train)

        # 予測と評価
        predictions = model.predict(X_test)
        probabilities = model.predict_proba(X_test)[:, 1]

        # メトリクスの計算
        auc = roc_auc_score(y_test, probabilities)
        precision = precision_score(y_test, predictions)
        recall = recall_score(y_test, predictions)

        # ハイパーパラメータの記録
        mlflow.log_param("n_estimators", n_estimators)
        mlflow.log_param("max_depth", max_depth)

        # メトリクスの記録
        mlflow.log_metric("AUC", auc)
        mlflow.log_metric("Precision", precision)
        mlflow.log_metric("Recall", recall)

        # サンプル入力(input_example)の設定
        input_example = X_test.iloc[0:1]  # テストデータの最初の行を使用

        # スキーマの設定
        signature = infer_signature(X_test, predictions)

        # モデルの保存
        mlflow.sklearn.log_model(model, "model", signature=signature, input_example=input_example)

        print(f"Logged Model - AUC: {auc}, Precision: {precision}, Recall: {recall}")

モデル作成

# ハイパーパラメータを変更して複数回実行
fit_rfn(n_estimators=100, max_depth=10)
fit_rfn(n_estimators=200, max_depth=20)
fit_rfn(n_estimators=50, max_depth=5)

結果確認

以下のコマンドを入力すると、ブラウザから作成したモデルの情報を確認することができます。

mlflow ui

image.png

1
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?