1
4

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

MLflowの便利さを徹底解説:実際のコード例付き

Posted at

はじめに

機械学習のプロジェクトを進める中で、モデルの実験管理や再現性の確保に課題を感じたことはありませんか?MLflowは、そんな課題を解決するために開発されたオープンソースのプラットフォームです。本記事では、MLflowの便利な機能を詳細に解説し、実際のコード例を交えながらその活用法を紹介します。

MLflowとは?

MLflowはDatabricks社によって開発されたオープンソースの機械学習ライフサイクル管理ツールです。以下の4つの主要な機能を提供します。

  1. MLflow Tracking:実験の記録と比較を容易にします。
  2. MLflow Projects:プロジェクトの再現性を確保します。
  3. MLflow Models:様々な形式でモデルをパッケージングできます。
  4. MLflow Registry:モデルのライフサイクル管理をサポートします。

なぜMLflowが便利なのか?

1. 実験の記録と比較が簡単(MLflow Tracking)

機械学習の実験では、ハイパーパラメータやモデルのバージョン、評価指標などを記録する必要があります。MLflow Trackingを使えば、これらを簡単に記録し、後で比較できます。

コード例

import mlflow
import mlflow.sklearn
from sklearn.ensemble import RandomForestClassifier
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score

# データの準備
iris = load_iris()
X_train, X_test, y_train, y_test = train_test_split(
    iris.data, iris.target, test_size=0.2, random_state=42
)

# MLflowで実験開始
with mlflow.start_run():
    # モデルの訓練
    clf = RandomForestClassifier(n_estimators=100, max_depth=3, random_state=42)
    clf.fit(X_train, y_train)

    # 予測と評価
    y_pred = clf.predict(X_test)
    acc = accuracy_score(y_test, y_pred)

    # パラメータとメトリクスの記録
    mlflow.log_param("n_estimators", 100)
    mlflow.log_param("max_depth", 3)
    mlflow.log_metric("accuracy", acc)

    # モデルの保存
    mlflow.sklearn.log_model(clf, "random_forest_model")

print(f"Accuracy: {acc}")

2. 再現性の確保(MLflow Projects)

MLflow Projectsを利用すると、依存関係や環境を含めたプロジェクト全体を再現可能な形で定義できます。

MLproject ファイルの例

name: RandomForestExample

conda_env: conda.yaml

entry_points:
  main:
    parameters:
      n_estimators: {type: int, default: 100}
      max_depth: {type: int, default: 3}
    command: >
      python train.py --n_estimators {n_estimators} --max_depth {max_depth}

3. モデルのパッケージング(MLflow Models)

MLflow Modelsを使えば、訓練したモデルを異なるフレームワークに対応した形式で保存・配布できます。

モデルの保存とロード

# モデルの保存
mlflow.sklearn.save_model(clf, "./model")

# モデルのロード
loaded_model = mlflow.sklearn.load_model("./model")
y_loaded_pred = loaded_model.predict(X_test)

4. モデルのライフサイクル管理(MLflow Registry)

MLflow Registryを使うと、モデルのバージョン管理やステージ管理(例えば「Staging」「Production」)ができます。

モデルの登録

import mlflow.pyfunc

# モデルの登録
result = mlflow.register_model(
    "runs:/<run_id>/random_forest_model", "IrisClassifier"
)

# モデルのステージ変更
from mlflow.tracking import MlflowClient

client = MlflowClient()
client.transition_model_version_stage(
    name="IrisClassifier", version=result.version, stage="Production"
)

まとめ

MLflowを活用することで、機械学習プロジェクトの管理が飛躍的に効率化します。実験の追跡、再現性の確保、モデルのパッケージング、そしてライフサイクル管理と、機械学習の各プロセスを一元的に管理できる点が最大の魅力です。

これからMLflowを導入しようと考えている方は、ぜひ今回紹介したコード例を参考にしてみてください。

1
4
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
4

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?