LoginSignup
1
2

More than 1 year has passed since last update.

MLflowチュートリアル

Posted at

はじめに

MLflow という実験管理用ツールを試す。まずはどのような操作で何ができるのか、概要をつかみたいので MLflow チュートリアルを一通り行ってみる。

Training the Model

Training the Model にあるソースコードを順に読み解いていく。前半は一般的な処理であるので、簡潔な説明のみとする。

まず必要なパッケージのインポートとロガーの設定、メトリクス計算のための関数を定義。

import os
import warnings
import sys

import pandas as pd
import numpy as np
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
from sklearn.model_selection import train_test_split
from sklearn.linear_model import ElasticNet
from urllib.parse import urlparse
import mlflow
import mlflow.sklearn

import logging

logging.basicConfig(level=logging.WARN)
logger = logging.getLogger(__name__)


def eval_metrics(actual, pred):
    rmse = np.sqrt(mean_squared_error(actual, pred))
    mae = mean_absolute_error(actual, pred)
    r2 = r2_score(actual, pred)
    return rmse, mae, r2

学習データの読み込み、訓練データとテストデータの分割といった基本的な処理を実行。パラメータである alphal1_ratio はスクリプト実行時に変数を与えることができ、ない場合はデフォルト値として 0.5 となっている。

if __name__ == "__main__":
    warnings.filterwarnings("ignore")
    np.random.seed(40)

    # Read the wine-quality csv file from the URL
    csv_url = (
        "http://archive.ics.uci.edu/ml/machine-learning-databases/wine-quality/winequality-red.csv"
    )
    try:
        data = pd.read_csv(csv_url, sep=";")
    except Exception as e:
        logger.exception(
            "Unable to download training & test CSV, check your internet connection. Error: %s", e
        )

    # Split the data into training and test sets. (0.75, 0.25) split.
    train, test = train_test_split(data)

    # The predicted column is "quality" which is a scalar from [3, 9]
    train_x = train.drop(["quality"], axis=1)
    test_x = test.drop(["quality"], axis=1)
    train_y = train[["quality"]]
    test_y = test[["quality"]]

    alpha = float(sys.argv[1]) if len(sys.argv) > 1 else 0.5
    l1_ratio = float(sys.argv[2]) if len(sys.argv) > 2 else 0.5

ここから mlflow の処理となる。mlflow.start_run() は新しい MLflow run を開始し、run がアクティブな状態であるときにメトリクスやパラメータを記録することができる。with 文を用いない場合は、明示的に mlflow.end_run() を呼び出し、run を終了させる必要がある。
その他は学習を行い、予測結果から各種メトリクスを計算している。

    with mlflow.start_run():
        lr = ElasticNet(alpha=alpha, l1_ratio=l1_ratio, random_state=42)
        lr.fit(train_x, train_y)

        predicted_qualities = lr.predict(test_x)

        (rmse, mae, r2) = eval_metrics(test_y, predicted_qualities)

        print("Elasticnet model (alpha=%f, l1_ratio=%f):" % (alpha, l1_ratio))
        print("  RMSE: %s" % rmse)
        print("  MAE: %s" % mae)
        print("  R2: %s" % r2)

上記で計算したメトリクスやパラメータをログとして残していく。log_param() は引数に keyvalue を持ち、パラメータ名とパラメータ値を記録する。以下では、alphal1_ratio の2つが記録されている。
同様に log_metric() はメトリクス名と値を記録し、以下では rmser2mae の3つが記録されている。

        mlflow.log_param("alpha", alpha)
        mlflow.log_param("l1_ratio", l1_ratio)
        mlflow.log_metric("rmse", rmse)
        mlflow.log_metric("r2", r2)
        mlflow.log_metric("mae", mae)

上記は log_params()mlflow.log_metrics() を使えば、以下のようにしてまとめて書くこともできる。

(other pattern)
params = {"alpha": alpha, "l1_ratio": l1_ratio}
mlflow.log_params(params)
metrics = {"rmse": rmse, "r2": r2, "mae": mae}
mlflow.log_metrics(metrics)

get_tracking_uri() は tracking_uri(file:///home/user/mlflow_test/mlruns のような形式)を取得している。得られた tracking_uri を urlparse() を用いて解析し、スキーマ(上記例だと file に相当)が得られる。
これが file かどうかで log_model() のモデルの記録方法を変更している。(変更している意図は不明)

        tracking_url_type_store = urlparse(mlflow.get_tracking_uri()).scheme

        # Model registry does not work with file store
        if tracking_url_type_store != "file":

            # Register the model
            # There are other ways to use the Model Registry, which depends on the use case,
            # please refer to the doc for more information:
            # https://mlflow.org/docs/latest/model-registry.html#api-workflow
            mlflow.sklearn.log_model(lr, "model", registered_model_name="ElasticnetWineModel")
        else:
            mlflow.sklearn.log_model(lr, "model")

Comparing the Models

alphal1_ratio を変更したモデルをいくつか作成した後、mlflow ui をターミナルで実行し、http://localhost:5000 にアクセスすると以下のような画面に遷移する。

mlflow_ui.png

各モデルに関して、先ほど記録したメトリクスやパラメータを見ることができる。また複数のモデルを選択して比較したり、フィルタをかけることなどもできる。

おわりに

MLflow チュートリアルに則って、シンプルな使い方や GUI を用いたモデルの比較などを行った。業務レベルに落としこんで、さまざまなことを試していきたい。

1
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
2