LoginSignup
1
0

More than 1 year has passed since last update.

MLflow tracking 備忘録

Posted at

MLflow tracking 概要

機械学習を行う際のハイパーパラメータや評価指標のログをとり、管理を楽にするツール

基本となる使い方

以下ではpython APIについての説明をする。

インストール

pip install mlflow

code

import mlflow 

# 処理を記述
lr = ~
model = ~
loss = ~

# MLflowでのログ 
with mlflow.sart_run():
   mlflow.log_param('lr', lr)
   mlflow.log_metrics('loss', loss)
   mlflow.pytorch.log_model(model, 'model')

code解説

  • with mlflow.start_run():
    • このコードブロック内でログの記述をする
  • mlflow.log_param('lr', lr)
    • パラメータの保存(学習率等のハイパラ)
  • mlflow.log_metrics('loss', loss)
    • メトリクスの保存(損失関数,精度等)
  • mlflow.pytorch.log_model(model, 'loss')
    • モデルの保存

プログラム実行後、ディレクトリ直下のmlruns/以下にログが保存される。
mlruns/のあるディレクトリで

mlflow ui

とするとローカルサーバーが立ち上がるりurlが表示されるのでアクセス
(デフォルトはhttp://127.0.0.1:5000)
ログとして保存したものをグラフで確認できる。

MLflow with pytorch lightning

pytorch lightningと組み合わせて使うとさらに簡単にログをとることができる。
(最初にMLflowを使ったのがこちらだったため簡単すぎて逆に仕組みが分かりにくかった)

Lightning Moduleの定義

class Model(pl.LightningModule):
    def __init__():
        ~
    def training_step(self, batch, batch_idx):
        loss=~
        return loss
    def training_epoch_end(self, losses):
        self.log('train_loss', np.mean(losses))
    def ~

code

import mlflow
import pytorch_lightning as pl
with mlflow.start_run():
    mlflow.pytorch.autolog()
    model=Model()
    dataloader=~
    trainer=pl.trainer()
    trainer.fit(model,data_loader)
  • mlflow.pytorch.autolog()
    • Lightning Moduleのログ(この場合train_loss),trainerのパラメータ,モデルの保存が自動で行われる。

その他便利なメソッド(追記予定)

  • mlflow.set_experiment(mlflow_exp_name)
    • 実験に名前を付けられる。(あとで見るとき分かりやすい)
1
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
0