6
3

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

はじめての記事投稿

Optunaでパラメーター探索中に常にベストモデルを保存しておく方法

Posted at

はじめに

Optunaで機械学習モデルのハイパーパラメーターチューニングを行う際に、最終的に導き出されたベストなパラメーターで改めてモデル構築を行うのが手間(+モデルによっては再度構築するのに時間がかかる)と思い、ベストパラメーターで構築したモデルを常に保存する仕組みを作成してみました。

使用したバージョン

  • python 3.9.7
  • optuna 3.2.0
  • lightgbm 3.3.5
  • scikit-learn 1.2.2

コード

  • ModelRegistry: 機械学習モデルの保存読み込みの機能を持つクラス。
  • StudyCallback: Optunaの学習(study)に対するコールバックとして利用されるクラス。学習が進行すると、このコールバックが呼び出されます。
model_registry.py
import pathlib

from joblib import dump, load

class ModelRegistry:
    def __init__(self, study_name):
        self._models = {}
        save_model_dir = pathlib.Path(__file__).resolve().parent / f".{study_name}_best_model/"
        if not pathlib.Path(save_model_dir).exists():
            pathlib.Path(save_model_dir).mkdir(parents=True, exist_ok=True)
        self.best_model_file = pathlib.Path(save_model_dir) / "model.joblib"

    def register(self, trial_id, model):
        self._models[trial_id] = model

    def retrieve(self, trial_id):
        return self._models.get(trial_id)

    def save_best_model(self, best_trial_id):
        best_model = self.retrieve(best_trial_id)
        if best_model is not None:
            dump(best_model, self.best_model_file)

    def load_best_model(self):
        if pathlib.Path(self.best_model_file).exists():
            return load(self.best_model_file)
        else:
            return None
study_callback.py
class StudyCallback:
    def __init__(self, model_registry):
        self._model_registry = model_registry

    def __call__(self, study, trial):
        if study.best_trial.number == trial.number:
            self._model_registry.save_best_model(study.best_trial.number)

LGBMClassifierでの実装例

main.py
import optuna
import pandas as pd
from lightgbm import LGBMClassifier
from sklearn.metrics import roc_auc_score

from model_registry import ModelRegistry
from study_callback import StudyCallback

class Objective:
    def __init__(self, X_train, y_train, X_valid, y_valid, model_registry):
        self.X_train = X_train
        self.y_train = y_train
        self.X_valid = X_valid
        self.y_valid = y_valid
        self.model_registry = model_registry

    def __call__(self, trial):
        static_params = {
            "random_state": 123,
            "boosting_type": "gbdt",
            "objective": "binary",
        }

        tune_params = {
            "n_estimators": trial.suggest_int("n_estimators", 1, 10000),
            "num_leaves": trial.suggest_int("num_leaves", 2, 256),
            "min_child_samples": trial.suggest_int("min_child_samples", 3, 200),
            "max_depth": trial.suggest_int("max_depth", 1, 8),
            "learning_rate": trial.suggest_float("learning_rate", 0.01, 0.3),
        }

        all_params = {
            **tune_params,
            **static_params,
        }

        model = LGBMClassifier(**all_params)

        # 学習
        model.fit(self.X_train, self.y_train, eval_metric="auc")

        # モデルの評価
        y_pred = model.predict_proba(self.X_valid)[:, 1]
        score = roc_auc_score(self.y_valid, y_pred)

        # モデルの登録
        self.model_registry.register(trial.number, model)

        return score

if __name__ == "__main__":
    #データの読み込み
    #省略

    study_name = "test"
    model_registry = ModelRegistry(study_name=study_name)
    callback = StudyCallback(model_registry)
    
    objective = Objective(X_train, y_train, X_valid, y_valid, model_registry)
    
    study = optuna.create_study(
        direction="maximize",
        storage="sqlite:///optuna_study.db",
        study_name=study_name,
        load_if_exists=True,
    )
    
    # callbackに今回作成したStudyCallbackを設定する
    study.optimize(objective, n_trials=100, callbacks=[callback])
    
    # 全トライアル終了後、最も精度の良かったモデルを取得できる
    best_model = model_registry.retrieve(study.best_trial.number)

おわりに

これを実装することでoptunaでパラメーター探索完了後に再びfitする手間・時間を省くことができました。
もっと良いやり方があるかもしれないので、もしあればぜひコメントください。

6
3
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
6
3

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?