LoginSignup
0
2

More than 3 years have passed since last update.

XGBoostのcallbackを使ってearly_stoppingがかかったときのnum_boost_roundの回数を表示させたい(未達成)

Last updated at Posted at 2020-09-07

前提

筆者超絶弱者なので備忘録として残してます。
間違いがあったら優しく指摘してくださいメンタル豆腐なので
参考サイトのコードを自分なりにわかりやすくした備忘録
環境はazuremlでハイパラ探しにoptuna回してます

前提知識

実装

最低限の実装

def return_callback():
    def print_num_boost_round(env):
        iteration = env.iteration
        msg = '\t'.join([str(x) for x in env.evaluation_result_list])
        print(iteration, msg)

得られる結果として

0  ('validation_0-mae', 2657.650391)
1  ('validation_0-mae', 2657.609375)
0  ('validation_0-mae', 2624.649658)
2  ('validation_0-mae', 2657.425049)
1  ('validation_0-mae', 2624.609131)

のようなものが得られる
次にコードを以下に変更する

def return_callback():
    def print_num_boost_round(env):
        print(env)
XGBoostCallbackEnv(model=<xgboost.core.Booster object at 0x7fa972703208>, cvfolds=None, iteration=0, begin_iteration=0, end_iteration=100, rank=0, evaluation_result_list=[('validation_0-mae', 2657.623047)])
XGBoostCallbackEnv(model=<xgboost.core.Booster object at 0x7fa972703208>, cvfolds=None, iteration=1, begin_iteration=0, end_iteration=100, rank=0, evaluation_result_list=[('validation_0-mae', 2657.463379)])
XGBoostCallbackEnv(model=<xgboost.core.Booster object at 0x7f7a8224c208>, cvfolds=None, iteration=0, begin_iteration=0, end_iteration=100, rank=0, evaluation_result_list=[('validation_0-mae', 2624.622314)])
XGBoostCallbackEnv(model=<xgboost.core.Booster object at 0x7fa972703208>, cvfolds=None, iteration=2, begin_iteration=0, end_iteration=100, rank=0, evaluation_result_list=[('validation_0-mae', 2657.411377)])
XGBoostCallbackEnv(model=<xgboost.core.Booster object at 0x7f7a8224c208>, cvfolds=None, iteration=1, begin_iteration=0, end_iteration=100, rank=0, evaluation_result_list=[('validation_0-mae', 2624.467285)])
XGBoostCallbackEnv(model=<xgboost.core.Booster object at 0x7fa972703208>, cvfolds=None, iteration=3, begin_iteration=0, end_iteration=100, rank=0, evaluation_result_list=[('validation_0-mae', 2657.355957)])
XGBoostCallbackEnv(model=<xgboost.core.Booster object at 0x7f0ced02c208>, cvfolds=None, iteration=0, begin_iteration=0, end_iteration=100, rank=0, evaluation_result_list=[('validation_0-mae', 2639.834229)])
XGBoostCallbackEnv(model=<xgboost.core.Booster object at 0x7f7a8224c208>, cvfolds=None, iteration=2, begin_iteration=0, end_iteration=100, rank=0, evaluation_result_list=[('validation_0-mae', 2624.416016)])

iterationの値をenv.iterationで取得していることがわかった

参考(https://kunsen.net/2020/05/02/post-3199/)

Optunaでnum_boost_roundを回してみて判断する

param_list['num_boost_round'] = trial.suggest_int("num_boost_round", 100, 500)

まずこれで初期値の100から500くらいでnum_boost_roundを回してみる

指定済みのパラメータ

  • "objective": "reg:gamma",
  • "eval_metric": "mae",
  • "verbosity": 0,
  • "booster": "gbtree",
  • "subsample": 1,
  • "subsample_freq": 0,
  • "early_stopping": 5,
  • "colsample_bylevel": 1,

Optunaで指定するパラメータ一覧

  • "min_child_weight": ""
  • "eta": "",
  • "lambda": "",
  • "alpha": "",
  • "num_leaves": "",
  • "colsample_bytree": "",
  • "num_boost_round": "",

このまま回すと


{
 'max_depth': 20,
 'eta': 0.22613771945050443,
 'num_leaves': 2560,
 'lambda': 6.0425529841148486e-05,
 'alpha': 6.69043393720362e-07,
 'num_boost_round': 236,
 'colsample_bytree': 0.9727432424922707,
 'min_child_weight': 239.6173703091301
}

num_boost_roundは236となる(Optunaの気まぐれなので毎回同じにはならない)
では何が236なのか...
そもそも236回回っているのか(ちなみに再度実行したら253だった)
結果の出力として

0 ('validation_0-mae', 2657.650391)
1  ('validation_0-mae', 2657.609375)
0  ('validation_0-mae', 2624.649658)
2  ('validation_0-mae', 2657.425049)
1  ('validation_0-mae', 2624.609131)

のように出力されるのだが、iterationはend_iterationが示しているように100までしか回らない
次に値が最小になるところを探した(手作業)
135.56956というのが最小の値だったので、その値が出た行数をカウントした
結果は482

結論

よく観察するとiterationが同じだからといって値が同じとは限らない
XGBoostの論文とか読んで前提知識として持ってるとわかりやすかったのかもしれない...
今はゴリ押しでいくしかないのか...??

0
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
2