12
7

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

MLflowでkerasモデルなどを管理してみた

Posted at

はじめに

会社の業務でKerasを使って、様々なデータセットを様々なモデルで試行することになりました。

モデルのバージョンや、パラメータ設定・メトリクスなど一括で試行毎に管理できないものかと考えていたところ、先輩にOSSの「MLflow」を教えていただきました!

今回は、環境構築実装結果の管理までを文書化したいと思います。

お試し環境(仕事ではWindows, 今回は自宅MacBookPro)

MacOS Mojave 10.14.5

Anaconda

環境構築

まず、MLflowお試し用仮想環境をAnacondaで作成します。

$ conda create -n mlflow_example python=3.7.0
$ source activate mlflow_example

仮想環境が作成でき、アクティブになったら必要なライブラリをインストールします。

(mlflow_example)$ pip install tensorflow
(mlflow_example)$ pip install keras
(mlflow_example)$ pip install mlflow
(mlflow_example)$ pip install sklearn
(mlflow_example)$ pip install matplotlib

お試し最小限の環境がこれで作ることができました。

実装の前に・・・

前提条件

今回は以下の条件をコマンドライン引数で指定できるようにします。

  • 実験名
    • MLflowに与える実験名(例:mnist_cnn とか fashion_fully)
  • データセット
    • MNIST
    • Fashion-MNIST
  • モデル
    • 全結合モデル
    • CNNモデル
  • パラメータ(今回は適当に・・・)
    • 初期学習率
    • Epoch数
    • バッチサイズ

MLflowで管理したいもの

以下に挙げたものをMLflowで管理することにします。

  • パラメータ値
    • 初期学習率
    • Epoch数
    • バッチサイズ
  • メトリクス
    • accuracy (検証データ、テストデータ)
    • presicion (検証データ、テストデータ)
    • recall (検証データ、テストデータ)
    • f1 (検証データ、テストデータ)
  • その他保存したいもの
    • 最もval_lossが低いモデル
    • 学習時のacc, lossのログファイル
    • 学習の推移画像

実装

コードを以下の機能に分割します。

  • 処理管理
    • コマンドライン引数管理
    • データセット読み込み
    • モデル作成
    • メトリクス計算

分量の問題で、ここでは処理管理のコードのみを記載します。(MLflowに関するコードは処理管理部にしか今回は実装していません)

全体コードはこちらにありますので、よければご確認ください。

処理管理

MLflow部分の説明などはコード内に記載しております。

mlflow_example.py
def main():
    """処理を管理する
    """

    # MLflowの実験を開始する(時間計測を開始する)
    mlflow.set_experiment(args.exp_name)
    with mlflow.start_run() as run:

        # ~~~~ 省略 ~~~~~

        # 学習を開始
        history = model.fit(train_x, train_y, batch_size=args.b, epochs=args.e,
                            callbacks=[cb_csv, cb_checkpoint],
                            validation_data=[val_x, val_y])
        plot_history(history)

        # 学習済みモデルを読み込む
        if args.model == 'fully':
            model = build_fully_model(
                reshape_num, class_num, weight_path=model_path)
        elif args.model == 'cnn':
            model = build_cnn_model(
                width, height, ch, class_num, weight_path=model_path)

        # メトリクスを計算
        val_acc, val_prec, val_recall, val_f1 = calc_metrics(
            model, val_x, val_y)
        test_acc, test_prec, test_recall, test_f1 = calc_metrics(
            model, test_x, test_y)

        # MLflowにパラメータを記録する
        mlflow.log_param("learning_rate", args.lr)
        mlflow.log_param("epoch_num", args.e)
        mlflow.log_param("batch_size", args.b)

        # MLflowにメトリクスを記録する
        mlflow.log_metric("val_acc", val_acc)
        mlflow.log_metric("val_precision", val_prec)
        mlflow.log_metric("val_recall", val_recall)
        mlflow.log_metric("val_f1", val_f1)
        # 辞書でまとめて記録することも可能!(ちなみにパラメータも同様のことができる)
        mlflow.log_metrics({"test_acc": test_acc,
                            "test_presicion": test_prec,
                            "test_recall": test_recall,
                            "test_f1": test_f1})

        # その他に保存したいものを記録する
        mlflow.log_artifact(csv_path)  # ファイルパスを与えてやる
        mlflow.log_artifact("./outputs/acc.png")
        mlflow.log_artifact("./outputs/loss.png")
        # kerasの場合、モデルをmlflow.kerasで保存できる
        mlflow.keras.log_model(model, "models")

お試し

実際に試してみた結果を記載します。


(mlflow_example)$ git clone https://github.com/T-Sumida/mlflow_keras_example.git
(mlflow_example)$ cd mlflow_keras_examples

~~~~ 環境構築などは省略 ~~~~~

# 全結合モデルでMNISTを識別してみる
# 実験名:mnist_fully で実験を開始
(mlflow_example)$ python mlflow_example.py mnist_fully --dataset mnist --model fully -e 10
# 少し条件を変えて、実験名:mnist_fully で実験を開始
(mlflow_example)$ python mlflow_example.py mnist_fully --dataset mnist --model fully -e 15 -b 32


# CNNモデルでFashion-MNISTを識別してみる
# 実験名:fashion_cnn で実験を開始
(mlflow_example)$ python mlflow_example.py fashion_cnn --dataset fashion --model cnn -e 10
# 少し条件を変えて、実験名:fashion_cnn で実験を開始
(mlflow_example)$ python mlflow_example.py fashion_cnn --dataset fashion --model cnn -e 10

結果の確認

MLflowはブラウザ上で結果を確認することができます。

実験を行うと、"mlruns"ディレクトリが作成されると思います。そのディレクトリがあるところで以下コマンドを実行します。

(mlflow_example)$ls
README.md         mlflow_example.py mlruns            outputs           qiita.md          requirements.txt
(mlflow_example)$mlflow ui
[2019-08-24 18:01:16 +0900] [78187] [INFO] Starting gunicorn 19.9.0
[2019-08-24 18:01:16 +0900] [78187] [INFO] Listening at: http://127.0.0.1:5000 (78187)
[2019-08-24 18:01:16 +0900] [78187] [INFO] Using worker: sync
[2019-08-24 18:01:16 +0900] [78190] [INFO] Booting worker with pid: 78190

これがうまく起動できれば、http://127.0.0.1:5000にアクセスすると以下画面が表示されます。

pic1

先程実行した実験名が左に並んでいますね。

試しに、"fashion_cnn"を見ると以下のような画面が表示されます。

pic2

そのうちの一つ"2019-08-24 18:57:20"の実験をクリックすると、以下のように記録しておいたパラメータやメトリクス、保存しておいた学習推移の画像を確認することができます。

pic3

もちろん、同じ実験名内でのメトリクスやパラメータの比較なども簡単に確認することができます。

保存したモデルや画像などは一応以下のファイルパスに格納されているので、いつでもファイルをどこかに移すことができます。(Githubのリポジトリにディレクトリ構造のテキストを置いておきました。)

ソースコード

今回のソースコードですが一応公開しております。こちら

おわりに

機械学習モデルのバージョン管理が面倒くさくなってきたので、OSSの「MLflow」を使ってみました。

割と簡単にコード内に埋め込むことができ、結構楽だなぁと感じました。モデル以外にも画像やCSVなども一緒に管理することができるので、後々確認するときに管理が楽になりました。(私はこういうのが非常に苦手なので・・・)
また、ブラウザ上で簡単にパラメータやメトリクスを比較することができるのは非常にありがたいです。不満もちょろちょろありますが、概ね私個人としては満足です!

今回紹介したのは、MLflow Trackingの機能のみです。他にも、手動でコマンドを打たずに自動で環境構築->学習処理ができる MLflow Projects や、学習で作成されたモデル使ったAPIサーバを立ててくれる MLflow Models などが存在します。

今後は、まだ使っていない機能も活用して楽して仕事していこうと考えております。

12
7
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
12
7

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?