LoginSignup
3
6

More than 3 years have passed since last update.

【メモ】効率的なパラメータ探索方法(ブースティング)

Last updated at Posted at 2021-01-30
  • 製造業出身のデータサイエンティストがお送りする記事
  • 今回はハイパーパラメータの探索で、無駄な事を実施していたことが分かったのでメモとして記事に残しておきます。

はじめに

普段、機械学習モデルを構築する際、ハイパーパラメータの探索をしているかと思います。ちなみに、普段、私はグリッドサーチやベイズ最適化を使用しております。
その際、回帰問題や分類問題で勾配ブースティング木のモデルを使用する事が多いのですが、n_estimator(構築する決定木の数)のパラメータを対象に探索していて無駄をしていることが分かったので自分のメモとして残しておきます。

なぜ、n_estimatorのパラメータ探索は無駄か

勾配ブースティング木において、n_estimatorは構築する決定木の数を表しております。イメージは下記です。

スクリーンショット 2021-01-30 11.50.17.png

勾配ブースティング木は、弱分類器の予測値の誤差を新しく作った弱学習器がどんどん引き継いでいきながら誤差を小さくしていく方法です。
つまり、n_estimatorのパラメータを100、500、1000とグリッドサーチで探索する際、決定木は合計で「100+500+1000=1600」構築することになります。
しかし、勾配ブースティング木の場合は、「n_estimator=1000」の時に「n_estimator=100」や「n_estimator=500」のスコアも算出されているのでわざわざパラメータ探索をする必要がないということです。

n_estimatorのパラメータ設定方法はどうするべきか

結論から言うと、n_estimatorのパラメータを大きめに設定し、EarlyStoppingを活用して最適な決定木の数を決める事が良いのではないかと思います。

EarlyStoppingとは

モデルの学習の収束を判定するための方法です。具体的なやり方としては、validation lossを監視し,train lossは減少し続けるのに対して,validation lossが改善されなくなった場合に学習を打ち切ることでオーバーフィッティングを抑制し、汎化性能が高いモデルを構築する事ができます。

スクリーンショット 2021-01-30 12.55.58.png

(参考文献:Kaggleで勝つデータ分析の技術より引用)

具体的な実装方法

今回、ライブラリーにEarlyStoppingのパラメータがあるLightGBMで実装してみます。

# LightGBMライブラリ
import lightgbm as lgb

import pandas as pd
import numpy as np
from sklearn.datasets import load_boston
from sklearn.model_selection import train_test_split

# スコア計算のためのライブラリ
from sklearn.metrics import r2_score
from sklearn.metrics import mean_squared_error

# ボストン住宅の価格のデータセットを読み込む
boston = load_boston()
X, y = boston.data, boston.target

# 訓練データとテストデータに分割する
X_train, X_test, y_train, y_test = train_test_split(X, y)

# LightGBM用のデータセットに加工
lgb_train = lgb.Dataset(X_train, y_train)
lgb_eval = lgb.Dataset(X_test, y_test)

# ハイパーパラメータの設定
params = {'metric': 'rmse',
          'max_depth' : 9}

# モデルの学習
gbm = lgb.train(params,
                lgb_train,
                valid_sets=lgb_eval,
                num_boost_round=10000,
                early_stopping_rounds=20,
                verbose_eval=5)

スクリーンショット 2021-01-30 12.36.38.png

各パラメータについて説明します。
* params:ハイパーパラメータの設定
* lgb_train:訓練データ
* valid_sets:評価用データ
* num_boost_round:学習サイクル
* early_stopping_rounds:EarlyStoppingのパラメータ
* verbose_eval:学習過程を表示するサイクル

上記モデルでは、20サイクル分観察し、過学習しているような場合は学習を終わらせております。

一応、評価データの予測まで実施します。

pred = gbm.predict(X_test)

# 評価
r2 = r2_score(y_test, pred)
rmse = np.sqrt(mean_squared_error(y_test, pred))

print("R2 : %.3f" % r2)
print("RMSE : %.3f" % rmse)

#R2 : 0.878
#RMSE : 2.784

さいごに

最後まで読んで頂き、ありがとうございました。
訂正要望がありましたら、ご連絡頂けますと幸いです。

3
6
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
3
6