16
8

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

LightGBMのdartモードでEarlyStoppingを使う

Posted at

LightGBMでboosting="dart"を指定すると、early_stopping_roundsを指定してもdartモードでは使えないという趣旨のwarningが表示され、使うことはできない。
dartでは、学習中に過去に作った木も変更されるため、スコアが悪化したときに以前のベストイテレーションで止めても、そのときのスコアを再現できないためである。
今回は、callback機能を利用して無理やりEarlyStoppingを実現させる。

callbackクラス

lightgbmのcallbackのソースコードを参考に、次のようなクラスを作成した。

import lightgbm as lgb
import pickle

class DartEarlyStopping(object):
    """DartEarlyStopping"""

    def __init__(self, data_name, monitor_metric, stopping_round):
        self.data_name = data_name
        self.monitor_metric = monitor_metric
        self.stopping_round = stopping_round
        self.best_score = None
        self.best_model = None
        self.best_score_list = []
        self.best_iter = 0

    def _is_higher_score(self, metric_score, is_higher_better):
        if self.best_score is None:
            return True
        return (self.best_score < metric_score) if is_higher_better else (self.best_score > metric_score)

    def _deepcopy(self, x):
        # copy.deepcopyではlightgbmのモデルは完全にコピーされないためpickleを使用
        return pickle.loads(pickle.dumps(x))

    def __call__(self, env):
        evals = env.evaluation_result_list
        for data, metric, score, is_higher_better in evals:
            if data != self.data_name or metric != self.monitor_metric:
                continue
            if not self._is_higher_score(score, is_higher_better):
                if env.iteration - self.best_iter > self.stopping_round:
                    # 終了させる
                    eval_result_str = '\t'.join([lgb.callback._format_eval_result(x) for x in self.best_score_list])
                    lgb.basic._log_info(f"Early stopping, best iteration is:\n[{self.best_iter+1}]\t{eval_result_str}") 
                    lgb.basic._log_info(f"You can get best model by \"DartEarlyStopping.best_model\"")
                    raise lgb.callback.EarlyStopException(self.best_iter, self.best_score_list)
                return
            # dartでは過去の木も更新されてしまうため、deepcopyしておく
            self.best_model = self._deepcopy(env.model)
            self.best_iter = env.iteration
            self.best_score_list = evals
            self.best_score = score
            return
        raise ValueError("monitoring metric not found")

各イテレーションで、スコア改善があればモデルをdeepcopyして保存しておくようにした。
ただし、copy.deepcopyではlightgbmのモデルが完全にコピーされない問題があるため、pickleでエンコード/デコードさせた。
学習後に、DartEarlyStopping.best_modelからベストモデルを取得して使うようにする。

利用例

import pandas as pd
import numpy as np
from sklearn.metrics import mean_squared_error
# データ
np.random.seed(17)
X = np.random.rand(100,2)
Y = np.ravel(np.random.rand(100,1))
eval_X = np.random.rand(100,2)
eval_Y = np.ravel(np.random.rand(100,1))
data = lgb.Dataset(X, label=Y)
eval_data=lgb.Dataset(eval_X, label=eval_Y, reference=data)
# 学習
params = {
    "boosting": "dart",
    'objective': 'rmse',
    'metric': 'rmse',
    "seed":1,
}
des = DartEarlyStopping("valid_1", "rmse", 5)
model = lgb.train(
    params, data,
    valid_sets=[data, eval_data], 
    num_boost_round=100,
    callbacks=[des],
    verbose_eval=1,
)
model = des.best_model
print(f"{np.sqrt(mean_squared_error(eval_Y, model.predict(eval_X))):.6f}")

[1] training's rmse: 0.290185 valid_1's rmse: 0.273062
[2] training's rmse: 0.288899 valid_1's rmse: 0.273025
[3] training's rmse: 0.287789 valid_1's rmse: 0.272813
[4] training's rmse: 0.286382 valid_1's rmse: 0.272106
[5] training's rmse: 0.285237 valid_1's rmse: 0.271591
[6] training's rmse: 0.284481 valid_1's rmse: 0.27179
[7] training's rmse: 0.283597 valid_1's rmse: 0.271512
[8] training's rmse: 0.282869 valid_1's rmse: 0.271293
[9] training's rmse: 0.282278 valid_1's rmse: 0.271212
[10] training's rmse: 0.282335 valid_1's rmse: 0.271227
[11] training's rmse: 0.282359 valid_1's rmse: 0.271237
[12] training's rmse: 0.281777 valid_1's rmse: 0.271276
[13] training's rmse: 0.28188 valid_1's rmse: 0.271512
[14] training's rmse: 0.281742 valid_1's rmse: 0.271591
Early stopping, best iteration is:
[9] training's rmse: 0.282278 valid_1's rmse: 0.271212
You can get best model by "DartEarlyStopping.best_model"
0.271212

当然ですがmodel = des.best_modelとすることで、ベストスコアと一致しました。
ちなみに、model = des.best_modelを書かない場合のスコアは、一般に、学習中に出力されたスコアのどれとも一致しません。
なぜなら、14 stepまで学習する過程で過去の木が変更されており、9 stepまでしか使わない予測をするためです。
このスコアがベストスコアを上回る場合も下回る場合もあります。

16
8
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
16
8

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?