LoginSignup
7

More than 5 years have passed since last update.

VotingClassifierを使いつつGridSearchCV/RandomizedSearchCVでパラメータチューニング

Last updated at Posted at 2017-03-13

概要

scikit-learnでは複数のモデルの予測値を特定のルールに則ってマージするVotingClassifierがある。入力とするモデルは、個別にパラメータを指定した状態で投げるか、パラメータチューニングとVotingClassifierを同時に動かす必要がある。

今回はVotingClassifierを使いつつ個々のモデルでもパラメータチューニングする方法をまとめてみる。基本的には、下記ドキュメントに書かれている方法と同じ。

1.11. Ensemble methods — scikit-learn 0.18.1 documentation

パラメータの指定方法

GridSearchCV/RandomizedSearchCVに指定するparam_grid/param_distributionsには、lr__Cのように{モデル名}__{パラメータ名}を指定する。モデル名はVotingClassifierでestimatorsに指定する名前を用いる。

# http://scikit-learn.org/stable/modules/ensemble.html#using-the-votingclassifier-with-gridsearch
params = {'lr__C': [1.0, 100.0], 'rf__n_estimators': [20, 200],}

パラメータチューニング時の注意

VotingClassifierは各モデルでパラメータチューニングするわけではなく、すべてのモデルのすべてのパラメータの組み合わせでチューニングを行う。つまりxgboostで組み合わせが100通りあってRandomForestで100通りあると、それぞれ100通り試して最後にチューニング結果を合わせてVotingしベストを決めるわけではなく、100×100=10,000通りの分類器をVotingとともに試すようになっている。そのため、試行するパラメータ数が多いとGridSearchCVで組み合わせ爆発が起こる可能性があるので、できればRandomizedSearchCVで試行回数(n_iter)を定めた上で最初は試したほうが現実的。

# 3種類のモデルでGridSearchCVを使ったために異常な組み合わせ数になった悪い例
Fitting 5 folds for each of 324000 candidates, totalling 1620000 fits
[CV] xg__colsample_bytree=0.5, rf__random_state=0, xg__learning_rate=0.5, rf__n_estimators=5, rf__n_jobs=1, xg__n_estimators=50, rf__max_depth=3, rf__min_samples_split=3, rf__max_features=3, xg__max_depth=3, lg__C=1.0 
[...]

またパラメータチューニングに用いる辞書は、VotingClassifierで利用するモデル以外の不要なモデル名のパラメータを入れているとエラーとなる。下記はlrというモデル名が無いにも関わらずパラメータに"lr__"を含めていた場合のエラー例。

ValueError: Invalid parameter lr for estimator VotingClassifier(estimators=[(
[...]

ソースコード

xgboostとRandomForest、LogisticRegressionの3つをVotingClassifierの入力とした例。各モデルのパラメータはVotingしない場合と書き方を統一するために、辞書同士をマージするときに名前を付けなおしている。

from sklearn.ensemble import VotingClassifier
from sklearn.model_selection import GridSearchCV, RandomizedSearchCV


xg = xgb.XGBClassifier()
rf = RandomForestClassifier()
lr = LogisticRegression()

xg_param = {
    "n_estimators": [50, 100, 150],
    "max_depth": [3, 6, 9],
    "colsample_bytree": [0.5, 0.9, 1.0],
    "learning_rate": [0.5, 0.9, 1.0]
}
rf_param = {
    "n_estimators": [5, 10, 50, 100, 300],
    "max_features": [3, 5, 10, 15, 20],
    "min_samples_split": [3, 5, 10, 20],
    "max_depth": [3, 5, 10, 20]
}
lr_param = {
    "C": list(np.logspace(0, 4, 10))
}

params = {}
params.update({"xg__" + k: v for k, v in xg_param.items()})
params.update({"rf__" + k: v for k, v in rf_param.items()})
params.update({"lr__" + k: v for k, v in lr_param.items()})

eclf = VotingClassifier(estimators=[("xg", xg),
                                    ("rf", rf),
                                    ("lr", lr)],
                        voting="soft")

clf = RandomizedSearchCV(eclf,
                         param_distributions=params,
                         cv=5,
                         n_iter=100,
                         n_jobs=1,
                         verbose=2)
clf.fit(X_train, y_train)
predict = clf.predict(X_test)

参考

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
7