
[Azure AutoML] How to get a trained model's metrics with Python SDK v1 / v2


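Introduction

In this article I retrieve the metrics of a model trained with Azure AutoML in two ways: with Python SDK v1 and with Python SDK v2.

The model itself was trained with SDK v2.
Training is covered in a previous article, so please take a look at that as well.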

Task

Environment

  • OS: Windows 10 (NVIDIA GTX 1650 Ti, 16 GB RAM, i5-10300H CPU)
  • Visual Studio Code 1.73.1
  • Python 3.8 / 3.9
  • AzureML Python SDK v1 / v2

Retrieving metrics with Python SDK v1

Here I retrieve the metrics from the model that was trained with SDK v2.

With v1, the get_metrics method makes this easy.

get_metrics_v1.py
from azureml.core import Workspace
from azureml.train.automl.run import AutoMLRun

# Connect to the workspace described in config.json
ws = Workspace.from_config()

experiment = ws.experiments['<experiment name>']
automl_run = AutoMLRun(experiment, run_id='<job ID>')

# get_output() returns (best child run, fitted model); the metrics live on the run
best_run, _ = automl_run.get_output()
print(best_run.get_metrics())

Result

Perhaps because I am forcing v1 to read metrics from a model that was built with v2, a huge number of warnings appear, but the metrics did come back as follows.

{'mean_absolute_percentage_error': 17.41204188234799, 'explained_variance': 0.738272704815586, 'mean_absolute_error': 2126.4896950718503, 'normalized_median_absolute_error': 0.03536353076485475, 'r2_score': 0.719627939723515, 'root_mean_squared_error': 2978.0824945509294, 'median_absolute_error': 1424.513746269879, 'normalized_root_mean_squared_log_error': 0.10514941989299778, 'normalized_root_mean_squared_error': 0.0739308498721744, 'root_mean_squared_log_error': 0.22949649762141444, 'normalized_mean_absolute_error': 0.05279007236661164, 'spearman_correlation': 0.8992768077701008, 'predicted_true': 'aml://artifactId/ExperimentRun/dcid.gray_iron_drsmn87sp0_0/predicted_true', 'residuals': 'aml://artifactId/ExperimentRun/dcid.gray_iron_drsmn87sp0_0/residuals'}
Exception ignored in: <function _Win32Helper.__del__ at 0x000002671874F4C0>
Traceback (most recent call last):
  File "C:\Users\anaconda3\envs\azml38v1\lib\site-packages\azureml\automl\runtime\shared\win32_helper.py", line 246, in __del__
TypeError: catching classes that do not inherit from BaseException is not allowed

A TypeError also shows up at the end, but I managed to read the metrics!
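
The dict above mixes numeric scores with artifact references (the `aml://...` strings for `predicted_true` and `residuals`). Below is a minimal sketch of how the two could be separated before further processing. The helper name `split_metrics` is my own and is not part of the SDK, and the sample holds only a few entries copied from the output above.

```python
def split_metrics(metrics):
    """Separate numeric scores from artifact references (strings such as 'aml://...')."""
    scalars = {}
    artifacts = {}
    for name, value in metrics.items():
        # bool is a subclass of int, so exclude it explicitly
        if isinstance(value, (int, float)) and not isinstance(value, bool):
            scalars[name] = value
        else:
            artifacts[name] = value
    return scalars, artifacts


# A few entries copied from the output above (not the full dict)
sample = {
    "r2_score": 0.719627939723515,
    "root_mean_squared_error": 2978.0824945509294,
    "spearman_correlation": 0.8992768077701008,
    "residuals": "aml://artifactId/ExperimentRun/dcid.gray_iron_drsmn87sp0_0/residuals",
}

scalars, artifacts = split_metrics(sample)
for name in sorted(scalars):
    print(f"{name}: {scalars[name]:.4f}")
print("artifact entries:", sorted(artifacts))
```

In the actual script, the dict returned by get_metrics() would be passed in place of `sample`.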

Next, let's get the metrics the straightforward way, from v2.

With Python SDK v2

With SDK v2, you use MLflow.
Reference site:

First, install mlflow and azureml-mlflow.

pip install mlflow azureml-mlflow

First, connect to the workspace and get the tracking URI so that MLflow can track the experiment.
Then get the best model from the training job (the parent job) and read its metrics.

get_metrics_v2.py
import mlflow
from azure.identity import DefaultAzureCredential
from azure.ai.ml import MLClient
from mlflow.tracking.client import MlflowClient

# Connect to the workspace
credential = DefaultAzureCredential(exclude_shared_token_cache_credential=True)
subscription_id = "<subscription ID>"
resource_group = "<resource group>"
workspace_name = "<workspace name>"

ml_client = MLClient(credential, subscription_id, resource_group, workspace_name)

# Obtain the tracking URL from MLClient
MLFLOW_TRACKING_URI = ml_client.workspaces.get(
    name=ml_client.workspace_name
).mlflow_tracking_uri
print(MLFLOW_TRACKING_URI)

mlflow.set_tracking_uri(MLFLOW_TRACKING_URI)
print("\nCurrent tracking uri: {}".format(mlflow.get_tracking_uri()))

# Initialize MLFlow client
mlflow_client = MlflowClient()

job_name = '<job ID>'

# Get the parent run
mlflow_parent_run = mlflow_client.get_run(job_name)

print("Parent Run: ")
print(mlflow_parent_run)
print(mlflow_parent_run.data.tags)

best_child_run_id = mlflow_parent_run.data.tags["automl_best_child_run_id"]
print("Found best child run id: ", best_child_run_id)

best_run = mlflow_client.get_run(best_child_run_id)

print("Best child run: ")
print(best_run)

All you need to fill in is the workspace connection details and the job ID.

Result

The metrics of the parent job and of the best model came back.

Parent Run:
<Run: data=<RunData: metrics={'AUC_macro': 0.5537367792751098,
 'AUC_micro': 0.8053648205456106,
 'AUC_weighted': 0.5537367792751098,
 'accuracy': 0.7873917952578096,
 'average_precision_score_macro': 0.5301345560931308,
 'average_precision_score_micro': 0.7764476136917231,
 'average_precision_score_weighted': 0.6946791650615354,
 'balanced_accuracy': 0.5,
 'f1_score_macro': 0.44051441816730164,
 'f1_score_micro': 0.7873917952578096,
 'f1_score_weighted': 0.6937547541810158,
 'log_loss': 0.5174726757985704,
 'matthews_correlation': 0.0,
 'norm_macro_recall': 0.0,
 'precision_score_macro': 0.3936958976289048,
 'precision_score_micro': 0.7873917952578096,
 'precision_score_weighted': 0.6200497715962933,
 'recall_score_macro': 0.5,
 'recall_score_micro': 0.7873917952578096,
 'recall_score_weighted': 0.7873917952578096,
 'weighted_accuracy': 0.9318802438597572}, params={}, tags={'automl_best_child_run_id': 'hungry_chain_dyhsn2rnyj_3',
 'fit_time_000': '0.5580559999999999;4.332179999999999;3;8',
 'iteration_000': '0;1;2;3',
 'mlflow.rootRunId': 'hungry_chain_dyhsn2rnyj',
 'mlflow.runName': 'hungry_chain_dyhsn2rnyj',
 'mlflow.user': '',
 'model_explain_best_run_child_id': 'hungry_chain_dyhsn2rnyj_3',
 'model_explain_run': 'best_run',
 'predicted_cost_000': '0;0;0;0',
 'run_algorithm_000': 'LightGBM;XGBoostClassifier;VotingEnsemble;StackEnsemble',
 'run_preprocessor_000': 'MaxAbsScaler;MaxAbsScaler;;',
 'score_000': '0.7861874294316898;0.7775310500564546;0.7863379751599548;0.7873917952578096',
 'training_percent_000': '100;100;100;100'}>, info=<RunInfo: artifact_uri='azureml://japaneast.api.azureml.ms/mlflow/v2.0/subscriptions/XXX/resourceGroups/XXX/providers/Microsoft.MachineLearningServices/workspaces/XXX/experiments/febdb527-437a-4d38-8d26-7cb8528691ae/runs/hungry_chain_dyhsn2rnyj/artifacts', end_time=1683513330904, experiment_id='febdb527-437a-4d38-8d26-7cb8528691ae', lifecycle_stage='active', run_id='hungry_chain_dyhsn2rnyj', run_name='hungry_chain_dyhsn2rnyj', run_uuid='hungry_chain_dyhsn2rnyj', start_time=1683512458707, status='FINISHED', user_id='XXX'>>
{'model_explain_run': 'best_run', 'score_000': '0.7861874294316898;0.7775310500564546;0.7863379751599548;0.7873917952578096', 'predicted_cost_000': '0;0;0;0', 'fit_time_000': '0.5580559999999999;4.332179999999999;3;8', 'training_percent_000': '100;100;100;100', 'iteration_000': '0;1;2;3', 'run_preprocessor_000': 'MaxAbsScaler;MaxAbsScaler;;', 'run_algorithm_000': 'LightGBM;XGBoostClassifier;VotingEnsemble;StackEnsemble', 'automl_best_child_run_id': 'hungry_chain_dyhsn2rnyj_3', 'model_explain_best_run_child_id': 'hungry_chain_dyhsn2rnyj_3', 'mlflow.rootRunId': 'hungry_chain_dyhsn2rnyj', 'mlflow.runName': 'hungry_chain_dyhsn2rnyj', 'mlflow.user': 'XXX'}
Found best child run id:  hungry_chain_dyhsn2rnyj_3
Best child run:
<Run: data=<RunData: metrics={'AUC_macro': 0.5537367792751098,
 'AUC_micro': 0.8053648205456106,
 'AUC_weighted': 0.5537367792751098,
 'accuracy': 0.7873917952578096,
 'average_precision_score_macro': 0.5301345560931308,
 'average_precision_score_micro': 0.7764476136917231,
 'average_precision_score_weighted': 0.6946791650615354,
 'balanced_accuracy': 0.5,
 'f1_score_macro': 0.44051441816730164,
 'f1_score_micro': 0.7873917952578096,
 'f1_score_weighted': 0.6937547541810158,
 'log_loss': 0.5174726757985704,
 'matthews_correlation': 0.0,
 'norm_macro_recall': 0.0,
 'precision_score_macro': 0.3936958976289048,
 'precision_score_micro': 0.7873917952578096,
 'precision_score_weighted': 0.6200497715962933,
 'recall_score_macro': 0.5,
 'recall_score_micro': 0.7873917952578096,
 'recall_score_weighted': 0.7873917952578096,
 'weighted_accuracy': 0.9318802438597572}, params={}, tags={'mlflow.parentRunId': 'hungry_chain_dyhsn2rnyj',
 'mlflow.rootRunId': 'hungry_chain_dyhsn2rnyj',
 'mlflow.runName': 'kind_snake_g0p4zj5j',
 'mlflow.source.name': 'automl_driver.py',
 'mlflow.source.type': 'JOB',
 'mlflow.user': 'XXX',
 'model_explain_run_id': 'hungry_chain_dyhsn2rnyj_ModelExplain',
 'model_explanation': 'True'}>, info=<RunInfo: artifact_uri='azureml://japaneast.api.azureml.ms/mlflow/v2.0/subscriptions/XXX/resourceGroups/XXX/providers/Microsoft.MachineLearningServices/workspaces/XXX/experiments/febdb527-437a-4d38-8d26-7cb8528691ae/runs/hungry_chain_dyhsn2rnyj_3/artifacts', end_time=1683513329701, experiment_id='febdb527-437a-4d38-8d26-7cb8528691ae', lifecycle_stage='active', run_id='hungry_chain_dyhsn2rnyj_3', run_name='kind_snake_g0p4zj5j', run_uuid='hungry_chain_dyhsn2rnyj_3', start_time=1683513268867, status='FINISHED', user_id='XXX'>>
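
The parent run's tags in the output above pack per-iteration information into semicolon-separated strings (`iteration_000`, `run_algorithm_000`, `score_000`). Here is a rough sketch of how they could be unpacked into one row per iteration. The helper name `parse_iteration_tags` is my own, the meaning of each tag is inferred from the output rather than from documentation, and picking the best row with `max` assumes that a higher score is better.

```python
def parse_iteration_tags(tags):
    """Turn the semicolon-separated *_000 tags into one dict per iteration."""
    iterations = tags["iteration_000"].split(";")
    algorithms = tags["run_algorithm_000"].split(";")
    scores = tags["score_000"].split(";")
    rows = []
    for iteration, algorithm, score in zip(iterations, algorithms, scores):
        rows.append(
            {"iteration": int(iteration), "algorithm": algorithm, "score": float(score)}
        )
    return rows


# Tag values copied from the parent run output above
tags = {
    "iteration_000": "0;1;2;3",
    "run_algorithm_000": "LightGBM;XGBoostClassifier;VotingEnsemble;StackEnsemble",
    "score_000": "0.7861874294316898;0.7775310500564546;0.7863379751599548;0.7873917952578096",
}

rows = parse_iteration_tags(tags)
for row in rows:
    print(f"{row['iteration']}  {row['algorithm']:<20} {row['score']:.4f}")

# Assumption: higher score is better
best = max(rows, key=lambda r: r["score"])
print("best iteration:", best["iteration"], best["algorithm"])
```

In the actual script, `mlflow_parent_run.data.tags` would be passed in place of the hand-copied `tags` dict. With the values above, the highest score is on the same iteration that `automl_best_child_run_id` points to.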

I was able to get the metrics for both the parent job and the best model.
To pull out just the metrics, paste the following at the end of the code.

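import json  # needed for json.dumps below

# best_run.data.metrics is a plain dict of metric name -> value
data = best_run.data.metrics
res = json.dumps(data)
print(res)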

Here is the result.

{
  "AUC_weighted": 0.5537367792751098,
  "AUC_micro": 0.8053648205456106,
  "average_precision_score_macro": 0.5301345560931308,
  "matthews_correlation": 0,
  "weighted_accuracy": 0.9318802438597572,
  "norm_macro_recall": 0,
  "log_loss": 0.5174726757985704,
  "f1_score_weighted": 0.6937547541810158,
  "balanced_accuracy": 0.5,
  "recall_score_macro": 0.5,
  "precision_score_weighted": 0.6200497715962933,
  "f1_score_macro": 0.44051441816730164,
  "average_precision_score_micro": 0.7764476136917231,
  "precision_score_macro": 0.3936958976289048,
  "f1_score_micro": 0.7873917952578096,
  "recall_score_weighted": 0.7873917952578096,
  "AUC_macro": 0.5537367792751098,
  "average_precision_score_weighted": 0.6946791650615354,
  "accuracy": 0.7873917952578096,
  "precision_score_micro": 0.7873917952578096,
  "recall_score_micro": 0.7873917952578096
}
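
Many of the metric names above come in macro / micro / weighted variants, which are easier to compare side by side. The following is a minimal sketch of one way to group them. The helper name `group_by_average` is my own, and the sample contains only some of the values from the result above.

```python
from collections import defaultdict

AVERAGES = ("macro", "micro", "weighted")


def group_by_average(metrics):
    """Group names like 'AUC_macro' under their base name ('AUC')."""
    grouped = defaultdict(dict)
    for name, value in metrics.items():
        base, _, suffix = name.rpartition("_")
        # Names without an averaging suffix (e.g. 'accuracy') are skipped
        if base and suffix in AVERAGES:
            grouped[base][suffix] = value
    return dict(grouped)


# Some of the values from the result above
metrics = {
    "AUC_macro": 0.5537367792751098,
    "AUC_micro": 0.8053648205456106,
    "AUC_weighted": 0.5537367792751098,
    "f1_score_macro": 0.44051441816730164,
    "f1_score_micro": 0.7873917952578096,
    "f1_score_weighted": 0.6937547541810158,
    "accuracy": 0.7873917952578096,
    "weighted_accuracy": 0.9318802438597572,
}

grouped = group_by_average(metrics)
print(f"{'metric':<10}" + "".join(f"{a:>10}" for a in AVERAGES))
for base in sorted(grouped):
    row = grouped[base]
    print(f"{base:<10}" + "".join(f"{row[a]:>10.4f}" for a in AVERAGES))
```

In the actual script, `best_run.data.metrics` would be passed in place of the hand-copied `metrics` dict.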