LoginSignup
7
12

More than 5 years have passed since last update.

scikit-learnでグリッドサーチ結果をもとに交差検定(クロスバリデーション)させてみた

Posted at

はじめに

グリッドサーチや交差検定については色々なところで紹介されていますが、グリッドサーチした結果をもとに交差検定する方法を紹介しているところがなかったのでここで紹介します。

環境

  • python: 2.7.6
  • scikit-learn: 0.17.1

内容

グリッドサーチした結果をもとに交差検定する方法を紹介します。

実装

データセットを取得

まずは機械学習するデータを取得します。scikit-learnはあらかじめデータセットが用意されているので、とりあえず機械学習を始めてみるのに最適です。
データセットの詳細はこちらのサイトで分かりやすくまとめられています。

データセット
# データセットを取得
iris = datasets.load_iris()

グリッドサーチ

次にグリッドサーチします。グリッドサーチするパラメータを設定して実施します。簡単にグリッドサーチできることもscikit-learnの魅力の一つです。
パラメータをハイパーパラメータと呼ぶとカッコイイ♪

グリッドサーチ
# グリッドサーチするパラメータを設定
parameters = {
    'C':[1, 3, 5],
    'loss':('hinge', 'squared_hinge')
}

# グリッドサーチを実行
clf = grid_search.GridSearchCV(svm.LinearSVC(), parameters)
clf.fit(iris.data, iris.target)

# グリッドサーチ結果(最適パラメータ)を取得
GS_loss, GS_C = clf.best_params_.values()
print "最適パラメータ:{}".format(clf.best_params_)

最適パラメータをそれぞれ'GS_loss'と'GS_C'に代入しています。
最適パラメータを取得する前に一度表示させてパラメータの順番を確認するのがよいと思います。パラメータの順番は公式サイト(sklearn.svm.LinearSVC)のParametersの順番ではなさそうなので…

交差検定(クロスバリデーション)

最後にグリッドサーチした結果をもとに交差検定します。

交差検定
# 交差検定(クロスバリデーション)を実行
clf = svm.LinearSVC(loss=GS_loss, C=GS_C)
score = cross_validation.cross_val_score(clf, iris.data, iris.target, cv=5)

# 交差検定の結果を表示
print "正解率(平均):{}".format(score.mean())
print "正解率(最小):{}".format(score.min())
print "正解率(最大):{}".format(score.max())
print "正解率(標準偏差):{}".format(score.std())
print "正解率(全て):{}".format(score)

コード全体

全体
# -*- coding: utf-8 -*-
from sklearn import datasets
from sklearn import svm
from sklearn import grid_search
from sklearn import cross_validation

# main
if __name__ == "__main__":
    # データセットを取得
    iris = datasets.load_iris()

    # グリッドサーチするパラメータを設定
    parameters = {
        'C':[1, 3, 5],
        'loss':('hinge', 'squared_hinge')
    }

    # グリッドサーチを実行
    clf = grid_search.GridSearchCV(svm.LinearSVC(), parameters)
    clf.fit(iris.data, iris.target)

    # グリッドサーチ結果(最適パラメータ)を取得
    GS_loss, GS_C = clf.best_params_.values()
    print "最適パラメータ:{}".format(clf.best_params_)

    # 交差検定(クロスバリデーション)を実行
    clf = svm.LinearSVC(loss=GS_loss, C=GS_C)
    score = cross_validation.cross_val_score(clf, iris.data, iris.target, cv=5)

    # 交差検定の結果を表示
    print "正解率(平均):{}".format(score.mean())
    print "正解率(最小):{}".format(score.min())
    print "正解率(最大):{}".format(score.max())
    print "正解率(標準偏差):{}".format(score.std())
    print "正解率(全て):{}".format(score)

実行結果

実行結果
最適パラメータ{'loss': 'squared_hinge', 'C': 1}
正解率(平均)0.966666666667
正解率(最小)0.9
正解率(最大)1.0
正解率(標準偏差)0.0421637021356
正解率(全て)[ 1.          1.          0.93333333  0.9         1.        ]

まとめ

グリッドサーチした結果がLinearSVC()のデフォルトと一緒だったのは少し残念でしたが、ひとまずグリッドサーチの結果を使用して交差検定することができました。
英語アレルギーなので公式サイトを見ながら学習するのは一苦労でした。

参考

scikit-learn に付属しているデータセット

公式サイト(sklearn.svm.LinearSVC)

pythonの機械学習ライブラリscikit-learnの紹介

7
12
2

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
7
12