Help us understand the problem. What is going on with this article?

Scikit learnより グリッドサーチによるパラメータ最適化

Grid search とは

scikit learnにはグリッドサーチなる機能がある。機械学習モデルのハイパーパラメータを自動的に最適化してくれるというありがたい機能。例えば、SVMならCや、kernelやgammaとか。Scikit-learnのユーザーガイドより、今回参考にしたのはこちら。

やったこと

  • 手書き数字(0~9)のデータセットdigitsをSVMで分類
  • GridSearchCVを使って、交差検定でハイパーパラメータを最適化
  • 最適化時のモデルの評価関数にはf1を使用

データの準備

手書き数字のdigitsをインポート。

from sklearn import datasets
from sklearn.cross_validation import train_test_split
from sklearn.model_selection import GridSearchCV
from sklearn.metrics import classification_report
from sklearn.svm import SVC

digits = datasets.load_digits()
n_samples = len(digits.images) # 標本数 1797個
X = digits.images.reshape((n_samples, -1)) # 8x8の配列から64次元のベクトルに変換
y = digits.target # 正解ラベル

train_test_splitで、データセットをトレーニング用とテスト用に2分割。

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.5, random_state=0)
print(X.shape)
>>> (1797, 64)
print(X_train.shape)
>>> (898, 64)
print(X_test.shape)
>>> (899, 64)

最適化したいパラメータの設定

次に、最適化したいパラメータをリストで定義。例題のものに加えて、polyとsigmoidのkernelを追加。

tuned_parameters = [
    {'C': [1, 10, 100, 1000], 'kernel': ['linear']},
    {'C': [1, 10, 100, 1000], 'kernel': ['rbf'], 'gamma': [0.001, 0.0001]},
    {'C': [1, 10, 100, 1000], 'kernel': ['poly'], 'degree': [2, 3, 4], 'gamma': [0.001, 0.0001]},
    {'C': [1, 10, 100, 1000], 'kernel': ['sigmoid'], 'gamma': [0.001, 0.0001]}
    ]

最適化の実行

GridSearchCVを使って、上で定義したパラメータを最適化。指定した変数は、使用するモデル、最適化したいパラメータセット、交差検定の回数、モデルの評価値の4つ。評価値はf1とした。precisionrecallでもOK。詳しくはこちら

score = 'f1'
clf = GridSearchCV(
    SVC(), # 識別器
    tuned_parameters, # 最適化したいパラメータセット 
    cv=5, # 交差検定の回数
    scoring='%s_weighted' % score ) # モデルの評価関数の指定

トレーニング用データセットのみを使い、最適化を実行。パラメータセットなどが表示される。

clf.fit(X_train, y_train) 
>>> GridSearchCV(cv=5, error_score='raise',
>>>       estimator=SVC(C=1.0, cache_size=200, class_weight=None, coef0=0.0,
>>>  decision_function_shape=None, degree=3, gamma='auto', kernel='rbf',
>>>  max_iter=-1, probability=False, random_state=None, shrinking=True,
>>>  tol=0.001, verbose=False),
>>>       fit_params={}, iid=True, n_jobs=1,
>>>       param_grid=[{'kernel': ['linear'], 'C': [1, 10, 100, 1000]}, {'kernel': ['rbf'], 'C': [1, 10, 100, 1000], 'gamma': [0.001, 0.0001]}, {'kernel': ['poly'], 'C': [1, 10, 100, 1000], 'gamma': [0.001, 0.0001], 'degree': [2, 3, 4]}, {'kernel': ['sigmoid'], 'C': [1, 10, 100, 1000], 'gamma': [0.001, 0.0001]}],
       pre_dispatch='2*n_jobs', refit=True, scoring='f1_weighted',
       verbose=0)

結果の表示

clf.grid_scores_で各試行でのスコアを確認できる。

clf.grid_scores_
>>> [mean: 0.97311, std: 0.00741, params: {'kernel': 'linear', 'C': 1},
>>> mean: 0.97311, std: 0.00741, params: {'kernel': 'linear', 'C': 10},
>>> mean: 0.97311, std: 0.00741, params: {'kernel': 'linear', 'C': 100}, 
>>> ...
>>>  mean: 0.96741, std: 0.00457, params: {'kernel': 'sigmoid', 'C': 1000, 'gamma': 0.0001}]

clf.best_params_で最適化したパラメータを確認できる。

clf.best_params_
{'C': 10, 'gamma': 0.001, 'kernel': 'rbf'}

各試行のスコアをテキストで、最適化したパラメータでの精度をテーブルで表示。

print("# Tuning hyper-parameters for %s" % score)
print()
print("Best parameters set found on development set: %s" % clf.best_params_)
print()

# それぞれのパラメータでの試行結果の表示
print("Grid scores on development set:")
print()
for params, mean_score, scores in clf.grid_scores_:
    print("%0.3f (+/-%0.03f) for %r"
          % (mean_score, scores.std() * 2, params))
print()

# テストデータセットでの分類精度を表示
print("The scores are computed on the full evaluation set.")
print()
y_true, y_pred = y_test, clf.predict(X_test)
print(classification_report(y_true, y_pred))

こんなふうに出力されます。

# Tuning hyper-parameters for f1

Best parameters set found on development set: {'kernel': 'rbf', 'C': 10, 'gamma': 0.001}

Grid scores on development set:

0.973 (+/-0.015) for {'kernel': 'linear', 'C': 1}
0.973 (+/-0.015) for {'kernel': 'linear', 'C': 10}
0.973 (+/-0.015) for {'kernel': 'linear', 'C': 100}

 [長いので中略]

0.967 (+/-0.009) for {'kernel': 'sigmoid', 'C': 1000, 'gamma': 0.0001}

The scores are computed on the full evaluation set.

             precision    recall  f1-score   support

          0       1.00      1.00      1.00        89
          1       0.97      1.00      0.98        90
          2       0.99      0.98      0.98        92
          3       1.00      0.99      0.99        93
          4       1.00      1.00      1.00        76
          5       0.99      0.98      0.99       108
          6       0.99      1.00      0.99        89
          7       0.99      1.00      0.99        78
          8       1.00      0.98      0.99        92
          9       0.99      0.99      0.99        92

avg / total       0.99      0.99      0.99       899
yhyhyhjp
Why not register and get more from Qiita?
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away