121
121

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

Scikit-learnでハイパーパラメータのグリッドサーチ

Last updated at Posted at 2014-01-30

使い方忘れるのでメモ.
Scikit-learnのドキュメントのサンプルを少し改変したものとその実行結果.

ソースコード:

grid_search.py
# -*- coding: utf-8 -*-

from sklearn import datasets
from sklearn.cross_validation import train_test_split
from sklearn.grid_search import GridSearchCV
from sklearn.metrics import classification_report, confusion_matrix
from sklearn.svm import SVC

## データの読み込み
digits = datasets.load_digits()
X = digits.data
y = digits.target

## トレーニングデータとテストデータに分割.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.5, random_state=0)

## チューニングパラメータ
tuned_parameters = [{'kernel': ['rbf'], 'gamma': [1e-3, 1e-4],
                     'C': [1, 10, 100, 1000]},
                    {'kernel': ['linear'], 'C': [1, 10, 100, 1000]}]

scores = ['accuracy', 'precision', 'recall']

for score in scores:
    print '\n' + '='*50
    print score
    print '='*50

    clf = GridSearchCV(SVC(C=1), tuned_parameters, cv=5, scoring=score, n_jobs=-1)
    clf.fit(X_train, y_train)

    print "\n+ ベストパラメータ:\n"
    print clf.best_estimator_

    print"\n+ トレーニングデータでCVした時の平均スコア:\n"
    for params, mean_score, all_scores in clf.grid_scores_:
        print "{:.3f} (+/- {:.3f}) for {}".format(mean_score, all_scores.std() / 2, params)

    print "\n+ テストデータでの識別結果:\n"
    y_true, y_pred = y_test, clf.predict(X_test)
    print classification_report(y_true, y_pred)

結果:

==================================================
accuracy
==================================================

+ ベストパラメータ:

SVC(C=10, cache_size=200, class_weight=None, coef0=0.0, degree=3, gamma=0.001,
  kernel=rbf, max_iter=-1, probability=False, random_state=None,
  shrinking=True, tol=0.001, verbose=False)

+ トレーニングデータでCVした時の平均スコア:

0.982 (+/- 0.002) for {'kernel': 'rbf', 'C': 1, 'gamma': 0.001}
0.954 (+/- 0.006) for {'kernel': 'rbf', 'C': 1, 'gamma': 0.0001}
0.986 (+/- 0.002) for {'kernel': 'rbf', 'C': 10, 'gamma': 0.001}
0.981 (+/- 0.003) for {'kernel': 'rbf', 'C': 10, 'gamma': 0.0001}
0.986 (+/- 0.002) for {'kernel': 'rbf', 'C': 100, 'gamma': 0.001}
0.983 (+/- 0.005) for {'kernel': 'rbf', 'C': 100, 'gamma': 0.0001}
0.986 (+/- 0.002) for {'kernel': 'rbf', 'C': 1000, 'gamma': 0.001}
0.983 (+/- 0.005) for {'kernel': 'rbf', 'C': 1000, 'gamma': 0.0001}
0.971 (+/- 0.006) for {'kernel': 'linear', 'C': 1}
0.971 (+/- 0.006) for {'kernel': 'linear', 'C': 10}
0.971 (+/- 0.006) for {'kernel': 'linear', 'C': 100}
0.971 (+/- 0.006) for {'kernel': 'linear', 'C': 1000}

+ テストデータでの識別結果:

             precision    recall  f1-score   support

          0       1.00      1.00      1.00        89
          1       0.97      1.00      0.98        90
          2       0.99      0.98      0.98        92
          3       1.00      0.99      0.99        93
          4       1.00      1.00      1.00        76
          5       0.99      0.98      0.99       108
          6       0.99      1.00      0.99        89
          7       0.99      1.00      0.99        78
          8       1.00      0.98      0.99        92
          9       0.99      0.99      0.99        92

avg / total       0.99      0.99      0.99       899


==================================================
precision
==================================================

+ ベストパラメータ:

SVC(C=10, cache_size=200, class_weight=None, coef0=0.0, degree=3, gamma=0.001,
  kernel=rbf, max_iter=-1, probability=False, random_state=None,
  shrinking=True, tol=0.001, verbose=False)

+ トレーニングデータでCVした時の平均スコア:

0.983 (+/- 0.002) for {'kernel': 'rbf', 'C': 1, 'gamma': 0.001}
0.959 (+/- 0.006) for {'kernel': 'rbf', 'C': 1, 'gamma': 0.0001}
0.986 (+/- 0.002) for {'kernel': 'rbf', 'C': 10, 'gamma': 0.001}
0.982 (+/- 0.003) for {'kernel': 'rbf', 'C': 10, 'gamma': 0.0001}
0.986 (+/- 0.002) for {'kernel': 'rbf', 'C': 100, 'gamma': 0.001}
0.985 (+/- 0.004) for {'kernel': 'rbf', 'C': 100, 'gamma': 0.0001}
0.986 (+/- 0.002) for {'kernel': 'rbf', 'C': 1000, 'gamma': 0.001}
0.985 (+/- 0.004) for {'kernel': 'rbf', 'C': 1000, 'gamma': 0.0001}
0.973 (+/- 0.005) for {'kernel': 'linear', 'C': 1}
0.973 (+/- 0.005) for {'kernel': 'linear', 'C': 10}
0.973 (+/- 0.005) for {'kernel': 'linear', 'C': 100}
0.973 (+/- 0.005) for {'kernel': 'linear', 'C': 1000}

+ テストデータでの識別結果:

             precision    recall  f1-score   support

          0       1.00      1.00      1.00        89
          1       0.97      1.00      0.98        90
          2       0.99      0.98      0.98        92
          3       1.00      0.99      0.99        93
          4       1.00      1.00      1.00        76
          5       0.99      0.98      0.99       108
          6       0.99      1.00      0.99        89
          7       0.99      1.00      0.99        78
          8       1.00      0.98      0.99        92
          9       0.99      0.99      0.99        92

avg / total       0.99      0.99      0.99       899


==================================================
recall
==================================================

+ ベストパラメータ:

SVC(C=10, cache_size=200, class_weight=None, coef0=0.0, degree=3, gamma=0.001,
  kernel=rbf, max_iter=-1, probability=False, random_state=None,
  shrinking=True, tol=0.001, verbose=False)

+ トレーニングデータでCVした時の平均スコア:

0.982 (+/- 0.002) for {'kernel': 'rbf', 'C': 1, 'gamma': 0.001}
0.954 (+/- 0.006) for {'kernel': 'rbf', 'C': 1, 'gamma': 0.0001}
0.986 (+/- 0.002) for {'kernel': 'rbf', 'C': 10, 'gamma': 0.001}
0.981 (+/- 0.003) for {'kernel': 'rbf', 'C': 10, 'gamma': 0.0001}
0.986 (+/- 0.002) for {'kernel': 'rbf', 'C': 100, 'gamma': 0.001}
0.983 (+/- 0.005) for {'kernel': 'rbf', 'C': 100, 'gamma': 0.0001}
0.986 (+/- 0.002) for {'kernel': 'rbf', 'C': 1000, 'gamma': 0.001}
0.983 (+/- 0.005) for {'kernel': 'rbf', 'C': 1000, 'gamma': 0.0001}
0.971 (+/- 0.006) for {'kernel': 'linear', 'C': 1}
0.971 (+/- 0.006) for {'kernel': 'linear', 'C': 10}
0.971 (+/- 0.006) for {'kernel': 'linear', 'C': 100}
0.971 (+/- 0.006) for {'kernel': 'linear', 'C': 1000}

+ テストデータでの識別結果:

             precision    recall  f1-score   support

          0       1.00      1.00      1.00        89
          1       0.97      1.00      0.98        90
          2       0.99      0.98      0.98        92
          3       1.00      0.99      0.99        93
          4       1.00      1.00      1.00        76
          5       0.99      0.98      0.99       108
          6       0.99      1.00      0.99        89
          7       0.99      1.00      0.99        78
          8       1.00      0.98      0.99        92
          9       0.99      0.99      0.99        92

avg / total       0.99      0.99      0.99       899

GridSearchCVのパラメータの説明
cv
fold数

scoring
グリードサーチで最適化する値を決められる.
デフォルトでは,
classificationで’accuracy’sklearn.metrics.accuracy_score
regressionで’r2’sklearn.metrics.r2_scoreが指定されている.
他にも例えばclassificationでは’precision’や’recall’等を指定できる.

詳しくはここ
precision, recall等については朱鷺の杜Wiki

n_jobs
ここに整数値を入れるだけで並列に計算してくれる.
-1でコア数を自動で入れてくれる.

121
121
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
121
121

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?