- 製造業出身のデータサイエンティストがお送りする記事
- 今回はハイパーパラメータの探索で、無駄な事を実施していたことが分かったのでメモとして記事に残しておきます。
##はじめに
普段、機械学習モデルを構築する際、ハイパーパラメータの探索をしているかと思います。ちなみに、普段、私はグリッドサーチやベイズ最適化を使用しております。
その際、回帰問題や分類問題で勾配ブースティング木のモデルを使用する事が多いのですが、n_estimator
(構築する決定木の数)のパラメータを対象に探索していて無駄をしていることが分かったので自分のメモとして残しておきます。
##なぜ、n_estimatorのパラメータ探索は無駄か
勾配ブースティング木において、n_estimator
は構築する決定木の数を表しております。イメージは下記です。
勾配ブースティング木は、弱分類器の予測値の誤差を新しく作った弱学習器がどんどん引き継いでいきながら誤差を小さくしていく方法です。
つまり、n_estimator
のパラメータを100、500、1000とグリッドサーチで探索する際、決定木は合計で「100+500+1000=1600」構築することになります。
しかし、勾配ブースティング木の場合は、「n_estimator=1000」の時に「n_estimator=100」や「n_estimator=500」のスコアも算出されているのでわざわざパラメータ探索をする必要がないということです。
##n_estimatorのパラメータ設定方法はどうするべきか
結論から言うと、n_estimator
のパラメータを大きめに設定し、EarlyStoppingを活用して最適な決定木の数を決める事が良いのではないかと思います。
##EarlyStoppingとは
モデルの学習の収束を判定するための方法です。具体的なやり方としては、validation lossを監視し,train lossは減少し続けるのに対して,validation lossが改善されなくなった場合に学習を打ち切ることでオーバーフィッティングを抑制し、汎化性能が高いモデルを構築する事ができます。
(参考文献:Kaggleで勝つデータ分析の技術より引用)
##具体的な実装方法
今回、ライブラリーにEarlyStoppingのパラメータがあるLightGBMで実装してみます。
# LightGBMライブラリ
import lightgbm as lgb
import pandas as pd
import numpy as np
from sklearn.datasets import load_boston
from sklearn.model_selection import train_test_split
# スコア計算のためのライブラリ
from sklearn.metrics import r2_score
from sklearn.metrics import mean_squared_error
# ボストン住宅の価格のデータセットを読み込む
boston = load_boston()
X, y = boston.data, boston.target
# 訓練データとテストデータに分割する
X_train, X_test, y_train, y_test = train_test_split(X, y)
# LightGBM用のデータセットに加工
lgb_train = lgb.Dataset(X_train, y_train)
lgb_eval = lgb.Dataset(X_test, y_test)
# ハイパーパラメータの設定
params = {'metric': 'rmse',
'max_depth' : 9}
# モデルの学習
gbm = lgb.train(params,
lgb_train,
valid_sets=lgb_eval,
num_boost_round=10000,
early_stopping_rounds=20,
verbose_eval=5)
各パラメータについて説明します。
- params:ハイパーパラメータの設定
- lgb_train:訓練データ
- valid_sets:評価用データ
- num_boost_round:学習サイクル
- early_stopping_rounds:EarlyStoppingのパラメータ
- verbose_eval:学習過程を表示するサイクル
上記モデルでは、20サイクル分観察し、過学習しているような場合は学習を終わらせております。
一応、評価データの予測まで実施します。
pred = gbm.predict(X_test)
# 評価
r2 = r2_score(y_test, pred)
rmse = np.sqrt(mean_squared_error(y_test, pred))
print("R2 : %.3f" % r2)
print("RMSE : %.3f" % rmse)
#R2 : 0.878
#RMSE : 2.784
##さいごに
最後まで読んで頂き、ありがとうございました。
訂正要望がありましたら、ご連絡頂けますと幸いです。