# Hyperparameter grid search with Scikit-learn

A slightly modified version of a sample from the Scikit-learn documentation, together with its output.

Source code:

grid_search.py
``````
# -*- coding: utf-8 -*-
# Targets the old scikit-learn API: sklearn.cross_validation,
# sklearn.grid_search and grid_scores_ were removed in release 0.20.

from sklearn import datasets
from sklearn.cross_validation import train_test_split
from sklearn.grid_search import GridSearchCV
from sklearn.metrics import classification_report, confusion_matrix
from sklearn.svm import SVC

## Load the data
digits = datasets.load_digits()
X = digits.data
y = digits.target

## Split into training and test data.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.5, random_state=0)

## Parameters to tune
tuned_parameters = [{'kernel': ['rbf'], 'gamma': [1e-3, 1e-4],
                     'C': [1, 10, 100, 1000]},
                    {'kernel': ['linear'], 'C': [1, 10, 100, 1000]}]

scores = ['accuracy', 'precision', 'recall']

for score in scores:
    print('\n' + '='*50)
    print(score)
    print('='*50)

    # Run one grid search per metric, using every available core
    clf = GridSearchCV(SVC(C=1), tuned_parameters, cv=5, scoring=score, n_jobs=-1)
    clf.fit(X_train, y_train)

    print("\n+ Best parameters:\n")
    print(clf.best_estimator_)

    print("\n+ Mean CV scores on the training data:\n")
    for params, mean_score, all_scores in clf.grid_scores_:
        print("{:.3f} (+/- {:.3f}) for {}".format(mean_score, all_scores.std() / 2, params))

    print("\n+ Classification results on the test data:\n")
    y_true, y_pred = y_test, clf.predict(X_test)
    print(classification_report(y_true, y_pred))
``````

``````
==================================================
accuracy
==================================================

+ Best parameters:

SVC(C=10, cache_size=200, class_weight=None, coef0=0.0, degree=3, gamma=0.001,
kernel=rbf, max_iter=-1, probability=False, random_state=None,
shrinking=True, tol=0.001, verbose=False)

+ Mean CV scores on the training data:

0.982 (+/- 0.002) for {'kernel': 'rbf', 'C': 1, 'gamma': 0.001}
0.954 (+/- 0.006) for {'kernel': 'rbf', 'C': 1, 'gamma': 0.0001}
0.986 (+/- 0.002) for {'kernel': 'rbf', 'C': 10, 'gamma': 0.001}
0.981 (+/- 0.003) for {'kernel': 'rbf', 'C': 10, 'gamma': 0.0001}
0.986 (+/- 0.002) for {'kernel': 'rbf', 'C': 100, 'gamma': 0.001}
0.983 (+/- 0.005) for {'kernel': 'rbf', 'C': 100, 'gamma': 0.0001}
0.986 (+/- 0.002) for {'kernel': 'rbf', 'C': 1000, 'gamma': 0.001}
0.983 (+/- 0.005) for {'kernel': 'rbf', 'C': 1000, 'gamma': 0.0001}
0.971 (+/- 0.006) for {'kernel': 'linear', 'C': 1}
0.971 (+/- 0.006) for {'kernel': 'linear', 'C': 10}
0.971 (+/- 0.006) for {'kernel': 'linear', 'C': 100}
0.971 (+/- 0.006) for {'kernel': 'linear', 'C': 1000}

+ Classification results on the test data:

             precision    recall  f1-score   support

          0       1.00      1.00      1.00        89
          1       0.97      1.00      0.98        90
          2       0.99      0.98      0.98        92
          3       1.00      0.99      0.99        93
          4       1.00      1.00      1.00        76
          5       0.99      0.98      0.99       108
          6       0.99      1.00      0.99        89
          7       0.99      1.00      0.99        78
          8       1.00      0.98      0.99        92
          9       0.99      0.99      0.99        92

avg / total       0.99      0.99      0.99       899

==================================================
precision
==================================================

+ Best parameters:

SVC(C=10, cache_size=200, class_weight=None, coef0=0.0, degree=3, gamma=0.001,
kernel=rbf, max_iter=-1, probability=False, random_state=None,
shrinking=True, tol=0.001, verbose=False)

+ Mean CV scores on the training data:

0.983 (+/- 0.002) for {'kernel': 'rbf', 'C': 1, 'gamma': 0.001}
0.959 (+/- 0.006) for {'kernel': 'rbf', 'C': 1, 'gamma': 0.0001}
0.986 (+/- 0.002) for {'kernel': 'rbf', 'C': 10, 'gamma': 0.001}
0.982 (+/- 0.003) for {'kernel': 'rbf', 'C': 10, 'gamma': 0.0001}
0.986 (+/- 0.002) for {'kernel': 'rbf', 'C': 100, 'gamma': 0.001}
0.985 (+/- 0.004) for {'kernel': 'rbf', 'C': 100, 'gamma': 0.0001}
0.986 (+/- 0.002) for {'kernel': 'rbf', 'C': 1000, 'gamma': 0.001}
0.985 (+/- 0.004) for {'kernel': 'rbf', 'C': 1000, 'gamma': 0.0001}
0.973 (+/- 0.005) for {'kernel': 'linear', 'C': 1}
0.973 (+/- 0.005) for {'kernel': 'linear', 'C': 10}
0.973 (+/- 0.005) for {'kernel': 'linear', 'C': 100}
0.973 (+/- 0.005) for {'kernel': 'linear', 'C': 1000}

+ Classification results on the test data:

             precision    recall  f1-score   support

          0       1.00      1.00      1.00        89
          1       0.97      1.00      0.98        90
          2       0.99      0.98      0.98        92
          3       1.00      0.99      0.99        93
          4       1.00      1.00      1.00        76
          5       0.99      0.98      0.99       108
          6       0.99      1.00      0.99        89
          7       0.99      1.00      0.99        78
          8       1.00      0.98      0.99        92
          9       0.99      0.99      0.99        92

avg / total       0.99      0.99      0.99       899

==================================================
recall
==================================================

+ Best parameters:

SVC(C=10, cache_size=200, class_weight=None, coef0=0.0, degree=3, gamma=0.001,
kernel=rbf, max_iter=-1, probability=False, random_state=None,
shrinking=True, tol=0.001, verbose=False)

+ Mean CV scores on the training data:

0.982 (+/- 0.002) for {'kernel': 'rbf', 'C': 1, 'gamma': 0.001}
0.954 (+/- 0.006) for {'kernel': 'rbf', 'C': 1, 'gamma': 0.0001}
0.986 (+/- 0.002) for {'kernel': 'rbf', 'C': 10, 'gamma': 0.001}
0.981 (+/- 0.003) for {'kernel': 'rbf', 'C': 10, 'gamma': 0.0001}
0.986 (+/- 0.002) for {'kernel': 'rbf', 'C': 100, 'gamma': 0.001}
0.983 (+/- 0.005) for {'kernel': 'rbf', 'C': 100, 'gamma': 0.0001}
0.986 (+/- 0.002) for {'kernel': 'rbf', 'C': 1000, 'gamma': 0.001}
0.983 (+/- 0.005) for {'kernel': 'rbf', 'C': 1000, 'gamma': 0.0001}
0.971 (+/- 0.006) for {'kernel': 'linear', 'C': 1}
0.971 (+/- 0.006) for {'kernel': 'linear', 'C': 10}
0.971 (+/- 0.006) for {'kernel': 'linear', 'C': 100}
0.971 (+/- 0.006) for {'kernel': 'linear', 'C': 1000}

+ Classification results on the test data:

             precision    recall  f1-score   support

          0       1.00      1.00      1.00        89
          1       0.97      1.00      0.98        90
          2       0.99      0.98      0.98        92
          3       1.00      0.99      0.99        93
          4       1.00      1.00      1.00        76
          5       0.99      0.98      0.99       108
          6       0.99      1.00      0.99        89
          7       0.99      1.00      0.99        78
          8       1.00      0.98      0.99        92
          9       0.99      0.99      0.99        92

avg / total       0.99      0.99      0.99       899
``````
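
The script above relies on module paths that newer scikit-learn releases no longer ship. As a rough sketch only, assuming Python 3 and scikit-learn 0.20 or later (not the versions the output above was produced with), the accuracy case might be ported like this. Here `cv_results_` stands in for `grid_scores_`, and the printed numbers may differ from the ones shown above.

```python
from sklearn import datasets
from sklearn.metrics import classification_report
from sklearn.model_selection import GridSearchCV, train_test_split
from sklearn.svm import SVC

# Same data and 50/50 split as the original script
digits = datasets.load_digits()
X_train, X_test, y_train, y_test = train_test_split(
    digits.data, digits.target, test_size=0.5, random_state=0)

# Same grid: 8 rbf candidates plus 4 linear candidates
tuned_parameters = [
    {'kernel': ['rbf'], 'gamma': [1e-3, 1e-4], 'C': [1, 10, 100, 1000]},
    {'kernel': ['linear'], 'C': [1, 10, 100, 1000]},
]

clf = GridSearchCV(SVC(), tuned_parameters, cv=5, scoring='accuracy', n_jobs=-1)
clf.fit(X_train, y_train)

print(clf.best_params_)

# cv_results_ is a dict of parallel arrays, one entry per candidate
results = clf.cv_results_
for mean, std, params in zip(results['mean_test_score'],
                             results['std_test_score'],
                             results['params']):
    print("{:.3f} (+/- {:.3f}) for {}".format(mean, std / 2, params))

# The refitted best estimator is used for prediction
print(classification_report(y_test, clf.predict(X_test)))
```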

Explanation of the `GridSearchCV` parameters
cv
Number of folds.
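
As a small illustration of `cv`, assuming a recent scikit-learn where `GridSearchCV` lives in `sklearn.model_selection`: an integer gives the number of folds, and a splitter object such as `StratifiedKFold` can be passed instead when shuffling or a fixed seed is wanted. The grid below is a reduced, hypothetical one chosen only to keep the run short.

```python
from sklearn import datasets
from sklearn.model_selection import GridSearchCV, StratifiedKFold
from sklearn.svm import SVC

digits = datasets.load_digits()
X, y = digits.data, digits.target

# Reduced grid, only for illustration
param_grid = {'C': [1, 10]}

# cv as an integer: 5 folds
search_int = GridSearchCV(SVC(kernel='linear'), param_grid, cv=5)
search_int.fit(X, y)

# cv as a splitter object: 5 shuffled, stratified folds with a fixed seed
splitter = StratifiedKFold(n_splits=5, shuffle=True, random_state=0)
search_obj = GridSearchCV(SVC(kernel='linear'), param_grid, cv=splitter)
search_obj.fit(X, y)

# n_splits_ reports how many folds were actually used
print(search_int.n_splits_, search_obj.n_splits_)
```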

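scoring
Selects the metric that the grid search optimizes.
If it is not specified, the estimator's own `score` method is used, which means
'accuracy' (`sklearn.metrics.accuracy_score`) for classification and
'r2' (`sklearn.metrics.r2_score`) for regression.

For precision, recall, and similar metrics, see the 朱鷺の杜Wiki.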
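
A brief sketch of switching `scoring`, assuming a recent scikit-learn. In recent releases the plain 'precision' and 'recall' scorers expect a binary target, so for the 10-class digits data this sketch uses the macro-averaged variants 'precision_macro' and 'recall_macro' instead. The grid is a reduced, hypothetical one.

```python
from sklearn import datasets
from sklearn.model_selection import GridSearchCV
from sklearn.svm import SVC

digits = datasets.load_digits()
X, y = digits.data, digits.target

# Reduced grid, only for illustration
param_grid = {'kernel': ['rbf'], 'gamma': [1e-3], 'C': [1, 10]}

# Run the same search once per metric and keep the winner for each
best_by_metric = {}
for scoring in ['accuracy', 'precision_macro', 'recall_macro']:
    search = GridSearchCV(SVC(), param_grid, cv=3, scoring=scoring)
    search.fit(X, y)
    best_by_metric[scoring] = (search.best_params_, search.best_score_)
    print(scoring, search.best_params_, round(search.best_score_, 3))
```

Different metrics can pick different candidates, although on this dataset they often agree.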

n_jobs
Setting an integer here is all it takes to run the search in parallel.
-1 uses all available cores.
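
A quick check, again assuming a recent scikit-learn, that `n_jobs` only changes how the work is scheduled and not the outcome: the serial and parallel searches below should select the same parameters with the same mean scores. The grid is a reduced, hypothetical one.

```python
import numpy as np
from sklearn import datasets
from sklearn.model_selection import GridSearchCV
from sklearn.svm import SVC

digits = datasets.load_digits()
X, y = digits.data, digits.target

# Reduced grid, only for illustration
param_grid = {'kernel': ['rbf'], 'gamma': [1e-3, 1e-4], 'C': [1, 10]}

# Same search, run on one core and on all cores
serial = GridSearchCV(SVC(), param_grid, cv=5, n_jobs=1).fit(X, y)
parallel = GridSearchCV(SVC(), param_grid, cv=5, n_jobs=-1).fit(X, y)

same_scores = np.allclose(serial.cv_results_['mean_test_score'],
                          parallel.cv_results_['mean_test_score'])
print(serial.best_params_ == parallel.best_params_, same_scores)
```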
