Help us understand the problem. What is going on with this article?

scikit-learnでCross Validation

More than 5 years have passed since last update.

scikit-learnCross ValidationとGrid Searchをやってみた。

Cross Validation

詳しいことはWikipediaに書いてある。
Cross Validationはモデルの妥当性を検証する方法のひとつ。一般的に開発用のデータは訓練データと検証データに分かれる。
しかし、このまま行ってしまうと折角の訓練データが減ってしまうことになる上に、訓練データの選び方によって汎化性能が下がってしまう可能性がある。
Wikipediaに書いてあるもののホールド・アウト検定がこれに当たる。一般にはこれはCross Validationにはあたらない。

ここに書いてあるK-分割交差検定がこれに当たる。K-分割交差検定では開発用のデータをK個に分割しK-1個を訓練用に、残りの一つを検証用に使いモデルの正当性を計算する。
これにより使える訓練データが増えると同時に、これらを訓練データを変えることにより、汎化性能を上げることができる。

scikit-learnで具体的にどのように行うのか書いてみた。訓練に使ったデータとしてはKaggleのData Science Londonで出されているものを用いた。

SVM

まずは単純にサポートベクターマシンでクラス分けをさせた時のコード

# -*- coding: utf-8 -*-

import os
import sys
from sklearn import svm
import numpy as np
import csv

if __name__ == "__main__":
    train_feature_file = np.genfromtxt(open("../data/train.csv", "rb"), delimiter=",", dtype=float)
    train_label_file = np.genfromtxt(open("../data/trainLabels.csv", "rb"), delimiter=",", dtype=float)

    train_features = []
    train_labels = []
    for train_feature, train_label in zip(train_feature_file, train_label_file):
        train_features.append(train_feature)
        train_labels.append(train_label)

    train_features = np.array(train_features)
    train_labels = np.array(train_labels)

    clf = svm.SVC(C=100, cache_size=200, class_weight=None, coef0=0.0, degree=3,gamma=0.001, kernel="rbf", max_iter=-1, probability=False,random_state=None, shrinking=True, tol=0.001, verbose=False)

    clf.fit(train_features, train_labels)

    test_feature_file = np.genfromtxt(open("../data/test.csv", "rb"), delimiter=",", dtype=float)

    test_features = []
    print "Id,Solution"
    i = 1
    for test_feature in test_feature_file:
        print str(i) + "," + str(int(clf.predict(test_feature)[0]))
        i += 1

このモデルをCross Validationで検証してみる。

def get_score(clf, train_features, train_labels):
    X_train, X_test, y_train, y_test = cross_validation.train_test_split(train_features, train_labels, test_size=0.4, random_state=0)

    clf.fit(X_train, y_train)
    print clf.score(X_test, y_test) 

cross_validation.train_test_splitは一定の割合が検証用データとなるように開発用データを分割する関数。この場合はtest_size=0.4を指定したので、40%のデータを検証用として使うことになる。
fitが60%の訓練データで行うもので、scoreが残された40%のデータで検証を行いその正答率を出してくれる。これがこのモデルの、このテストデータにおける正当性となる。もちろんこれが高ければ高いほどよいが
汎化性能が高いかどうかはここからでは読み取ることができない。そのためK分割を行うことでK回の検証を行うことができる。これらのスコアを平均することで汎化性能も含めたモデルの正当性を表すことができる。

def get_accuracy(clf, train_features, train_labels):
    scores = cross_validation.cross_val_score(clf, train_features, train_labels, cv=10)
    print("Accuracy: %0.2f (+/- %0.2f)" % (scores.mean(), scores.std() * 2))

cross_validation_cross_val_scoreでこれらの検証のすべてのscoreを得ることができる。cvはK分割の分割の個数を指定することができる。今回は開発用のデータを10個に分割し10回の検証を行う。
scoresには10個のscoreが入ったリストが返ってくる。これの平均をAccuracyとして出している。これで汎化性能も含めたモデルの正当性を得ることができるが、モデルパラメータのチューニングを手で行う必要がある。
手で調整して、Accuracyを計算するというのは非常に手間なのでGrid Searchというアルゴリズムでこのチューニングをある程度自動化することができる。

Grid Search

パラメータの範囲を指定することで経験的に最適なパラメータの組を探索する方法がGrid Search。Pythonで行うには以下のように書く。

def grid_search(train_features, train_labels):
    param_grid = [
        {'C': [1, 10, 100, 1000], 'kernel': ['linear']},
        {'C': [1, 10, 100, 1000], 'gamma': [0.001, 0.0001], 'kernel': ['rbf']},
    ]

    clf = GridSearchCV(svm.SVC(C=1), param_grid, n_jobs=-1)
    clf.fit(train_features, train_labels)
    print clf.best_estimator_

param_gridに指定することでこの範囲を指定することができる。n_jobsに並列に計算を行うプロセス数を指定することができる。-1を指定するとコア数をデフォルト選ぶようになっている。与えられた訓練データに対してGrid Searchを行う。
時間は少しかかるが、この訓練データに対して最もスコアが高くなるようなモデルパラメータを選ぶことができる。この訓練データを実際のテストデータに使うことができる。

Why not register and get more from Qiita?
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away
Comments
No comments
Sign up for free and join this conversation.
If you already have a Qiita account
Why do not you register as a user and use Qiita more conveniently?
You need to log in to use this function. Qiita can be used more conveniently after logging in.
You seem to be reading articles frequently this month. Qiita can be used more conveniently after logging in.
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away
ユーザーは見つかりませんでした