0
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

ダブルクロスバリデーション(sklearn)

Last updated at Posted at 2021-11-29

load

from sklearn.datasets import load_diabetes
from sklearn.model_selection import GridSearchCV, KFold, cross_val_score
from sklearn.pipeline import Pipeline
from sklearn.feature_selection import SelectFromModel
from sklearn.ensemble import RandomForestRegressor

X, y = load_diabetes(return_X_y=True)

estimatorのハイパーパラメータの探索

estimator = RandomForestRegressor()
param_grid = {'n_estimators':[10,50,100], 'max_depth':[3,5,10]}
tune_cv = KFold(3, shuffle=True)
evaluate_cv = KFold(5, shuffle=True)
# for cross-validation
gsv_cv = GridSearchCV(estimator, param_grid, cv=tune_cv)

# for double cross-validation
gsv_dcv = GridSearchCV(estimator, param_grid, cv=tune_cv)

クロスバリデーションによるハイパーパラメータ最適化&評価

gsv_cv.fit(X,y)
best_estimator = gsv_cv.best_estimator_
score = cross_val_score(best_estimator , X, y, cv=evaluate_cv)
# 上と下のコードは同じです。
score = cross_val_score(gsv_cv.best_estaimator_, X, y, cv=evaluate_cv)

このコードでは

  1. GridSearchCV で最適なハイパーパラメータを決定する
  2. そのハイパーパラメータを用いてcross_val_scoreを求める
     という手順になり、不適切な評価結果となる可能性があります。

詳細は以下のサイト等をご参照ください。

クロスバリデーションの結果がよくなるようにハイパーパラメータの値を決めているからです。つまり、クロスバリデーションの結果を実測値にフィッティング (適合) してハイパーパラメータの値を決定しているようなものです。

ダブルクロスバリデーション(モデルクロスバリデーション)でテストデータいらず~サンプルが少ないときのモデル検証~

scikit-learnのドキュメントにも過適合の危険性について記載されています。
https://scikit-learn.org/stable/auto_examples/model_selection/plot_nested_cross_validation_iris.html

ダブルクロスバリデーションによるハイパーパラメータ最適化&評価

情報のleakageを避け、正しくモデルの評価を行うためにはdouble cross-validation (nested cross-validationとも)が必要です。
この場合、「モデル」とは「GridSearchCVでハイパーパラメータを最適化する手順も含めた機械学習モデル(GridSearchCV(estimator, param_grid))」のことを指します。

scikit-learnは上手く設計されているため、cross_val_score()にGridSearchCVを渡すとそれだけでdouble cross validationの評価値を得ることが出来ます。

score = cross_val_score(gsv_dcv , X, y, cv=evaluate_cv)
# 上と下のコードは同じ内容です。
score = cross_val_score(GridSearchCV(RandomForestRegressor(), param_grid={'n_estimators':[10,50,100], 'max_depth':[3,5,10]})) , X, y, cv=KFold(3, shuffle=True))

なぜ cross_val_score(gsv_dcv , X, y, cv=evaluate_cv) の一行で
ダブルクロスバリデーションが評価できるのかは、詳細なscikit-learnの仕様説明が必要になるので割愛しますが、
ざっくり説明すると以下の仕様が効いてきます。

・cross_val_scoreでは各foldごとに独立したestimatorを用い、fit_predictを行うこと
・GridSearchCVはfitにより、best_estimator_を定め、predictではbest_estimator_が用いられること
・cloneは__init__時に定義済のパラメータのみコピーし、best_estimator_などはコピーされないこと
 などです。

詳しく知りたい方はドキュメントやコードを参照ください。

GridSearchCV.predict 仕様
GridSearchCV.predict ソース
GridSearchCV.fit
clone

Pipelineを用いる場合

selectorのハイパーパラメータの探索

selector = SelectFromModel(RandomForestRegressor())
selector_param_grid = {'selector__threshold':[0.1,0.2]}
pipe = Pipeline([('selector', selector), ('estimator', estimator)])
param_grid = {**selector_param_grid}    #python 3.5以上

tune_cv = KFold(3, shuffle=True)
evaluate_cv = KFold(5, shuffle=True)

# for cross-validation
gsv_cv = GridSearchCV(pipe , param_grid, cv=tune_cv)

# for double cross-validation
gsv_dcv = GridSearchCV(pipe , param_grid, cv=tune_cv)

クロスバリデーションによるハイパーパラメータ最適化&評価

gsv_cv.fit(X,y)
if True:
    [setattr(i[1], 'prefit',True) for i in gsv_cv.estimator.steps
        if hasattr(i[1],'prefit')]
best_pipe = gsv_cv.best_estimator_
score = cross_val_score(best_pipe , X, y, cv=evaluate_cv)

ダブルクロスバリデーションによるハイパーパラメータ最適化&評価

score = cross_val_score(gsv, X, y, cv=evaluate_cv)

pipelineを用いる場合も同じ様に一行でダブルクロスバリデーションを評価することが出来ます。

TIPS

GridSearchCVでcvにLeaveOneOutCV()を指定するとエラーとなります。
これはデフォルトではmetricsにr2_scoreが指定されており、n=1ではr2を定義できないためです。
代わりにmetricsにneg_mean_squared_errorを指定するとエラー回避できます。

他に指定可能なmetrics:https://qiita.com/shnchr/items/f5066021b7143566f950

0
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?