More than 3 years have passed since last update.

nested k-fold cross validation

To study nested k-fold cross validation, I used the code from
https://axa.biopapyrus.jp/machine-learning/model-evaluation/nested-k-fold-cross-validation.html
This post summarizes what I learned.
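
In the code below, `cross_val_score` runs the outer loop and `GridSearchCV` runs the inner one. To make that structure visible, here is a minimal sketch that writes both loops out by hand. It is not the article's code: the 5 outer / 3 inner shuffled `StratifiedKFold` splits, `random_state=0`, and the small 3 x 3 `C`/`gamma` grid are assumptions chosen so it runs quickly.

```python
import numpy as np
from sklearn.base import clone
from sklearn.datasets import load_breast_cancer
from sklearn.decomposition import PCA
from sklearn.metrics import f1_score
from sklearn.model_selection import ParameterGrid, StratifiedKFold
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVC

X, y = load_breast_cancer(return_X_y=True)

pipe = Pipeline([("scale", StandardScaler()), ("pca", PCA(0.80)), ("clf", SVC())])
# Hypothetical small grid chosen for speed (the article searches 20 x 20 values)
param_grid = {"clf__C": [0.1, 1, 10], "clf__gamma": [0.001, 0.01, 0.1]}

outer_cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=0)
inner_cv = StratifiedKFold(n_splits=3, shuffle=True, random_state=0)

outer_scores, chosen_params = [], []
for train_idx, test_idx in outer_cv.split(X, y):
    X_tr, y_tr = X[train_idx], y[train_idx]

    # Inner loop: pick hyperparameters using only the outer training fold
    best_score, best_params = -np.inf, None
    for params in ParameterGrid(param_grid):
        inner_scores = []
        for in_tr, in_val in inner_cv.split(X_tr, y_tr):
            model = clone(pipe).set_params(**params)
            model.fit(X_tr[in_tr], y_tr[in_tr])
            inner_scores.append(f1_score(y_tr[in_val], model.predict(X_tr[in_val])))
        if np.mean(inner_scores) > best_score:
            best_score, best_params = np.mean(inner_scores), params

    # Outer loop: refit with the chosen hyperparameters on the whole outer training fold,
    # then score once on the outer test fold, which the inner search never saw
    final = clone(pipe).set_params(**best_params).fit(X_tr, y_tr)
    outer_scores.append(f1_score(y[test_idx], final.predict(X[test_idx])))
    chosen_params.append(best_params)

print("nested CV F1: %.2f±%.2f" % (np.mean(outer_scores), np.std(outer_scores)))
print(chosen_params)
```

The outer scores estimate the performance of the whole procedure "tune with grid search, then fit", not of one fixed hyperparameter setting, which is why different outer folds may choose different values.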

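When I tried to run it on Google Colaboratory, line 9,
from sklearn.grid_search import GridSearchCV
raised an error (the sklearn.grid_search module was deprecated in scikit-learn 0.18 and removed in 0.20), so I changed it to
from sklearn.model_selection import GridSearchCV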
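
If the same script must also run on an old scikit-learn installation, one option is an import fallback. This is a minimal sketch; on any current Colab runtime only the first import is ever used.

```python
try:
    # Current location (scikit-learn 0.18 and later)
    from sklearn.model_selection import GridSearchCV
except ImportError:
    # Old location, removed in scikit-learn 0.20
    from sklearn.grid_search import GridSearchCV

print(GridSearchCV.__module__)
```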

I also corrected a small typo near the end, which gives the code below.

import numpy as np
from sklearn import datasets
from sklearn.pipeline import Pipeline
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA
from sklearn.svm import SVC
from sklearn.ensemble import RandomForestClassifier
#from sklearn.grid_search import GridSearchCV  # old location, removed in scikit-learn 0.20
from sklearn.model_selection import GridSearchCV
from sklearn.model_selection import cross_val_score


# load data
cancer = datasets.load_breast_cancer()
x = cancer.data
y = cancer.target

print(x.shape)
## (569, 30)

x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.2)  # x_test/y_test are held out and not used below
print(x_train.shape)
## (455, 30)


# SVM model
ppln_svm = Pipeline([
               ('scale', StandardScaler()),
               ('pca', PCA(0.80)),  # keep enough components to explain 80% of the variance
               ('clf', SVC())
           ])

# SVM model hyperparameters
param_grid_svm = [
    {
        'clf__kernel': ['rbf'],
        'clf__C': 10 ** np.linspace(-5, 5, 20),
        'clf__gamma': 10 ** np.linspace(-5, 5, 20)
    }
]

# Random Forest model
ppln_rf = Pipeline([
               ('scale', StandardScaler()),
               ('pca', PCA(0.80)),  # keep enough components to explain 80% of the variance
               ('clf', RandomForestClassifier())
           ])

# Random Forest model hyperparameters
param_grid_rf = [
    {'clf__max_depth': [2, 3, 4, 5, 6, 7, 8]}
]


# inner loop: grid search with 2-fold CV, run on each outer training fold
gs_svm = GridSearchCV(estimator=ppln_svm, param_grid=param_grid_svm, scoring='f1', cv=2, n_jobs=1)
gs_rf = GridSearchCV(estimator=ppln_rf, param_grid=param_grid_rf, scoring='f1', cv=2, n_jobs=1)


# outer loop: 10-fold CV; every fold reruns the inner grid search from scratch
scores_svm = cross_val_score(gs_svm, x_train, y_train, scoring='f1', cv=10)
scores_rf = cross_val_score(gs_rf, x_train, y_train, scoring='f1', cv=10)

print('SVM: %.2f±%.2f' % (np.mean(scores_svm), np.std(scores_svm)))
## SVM: 0.97±0.02

print('RF: %.2f±%.2f' % (np.mean(scores_rf), np.std(scores_rf)))
## RF: 0.95±0.03

## Notes
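
The code above reports only the mean ± std of the outer scores and never uses `x_test`. A minimal sketch of two follow-up steps, assuming scikit-learn 0.20 or later (for the `return_estimator` option of `cross_validate`): inspect which `max_depth` each outer fold selected, then rerun the inner search on the full training set and score it once on the held-out test set. The 5 outer folds (instead of 10) and `random_state=0` are assumptions added for speed and repeatability.

```python
from sklearn.datasets import load_breast_cancer
from sklearn.decomposition import PCA
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import GridSearchCV, cross_validate, train_test_split
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler

X, y = load_breast_cancer(return_X_y=True)
x_train, x_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=0)

ppln_rf = Pipeline([
    ("scale", StandardScaler()),
    ("pca", PCA(0.80)),
    ("clf", RandomForestClassifier(random_state=0)),
])
gs_rf = GridSearchCV(ppln_rf, {"clf__max_depth": [2, 3, 4, 5, 6, 7, 8]}, scoring="f1", cv=2)

# Outer loop; return_estimator=True keeps the fitted GridSearchCV from each outer fold
res = cross_validate(gs_rf, x_train, y_train, scoring="f1", cv=5, return_estimator=True)
for i, (est, score) in enumerate(zip(res["estimator"], res["test_score"])):
    print(f"fold {i}: best {est.best_params_}, outer F1 = {score:.2f}")

# Final model: rerun the inner search on the whole training set,
# then score once on the held-out test set (GridSearchCV.score uses the f1 scorer here)
gs_rf.fit(x_train, y_train)
test_f1 = gs_rf.score(x_test, y_test)
print("chosen:", gs_rf.best_params_, "held-out F1: %.2f" % test_f1)
```

If the selected `max_depth` varies a lot between outer folds, the search is not very stable, and the outer mean ± std is the more honest summary than any single fold's result.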
