LoginSignup
119
114

More than 5 years have passed since last update.

scikit-learnによる多クラスSVM

Last updated at Posted at 2014-07-18

scikit-learnによる多クラスSVM

目的

scikit-learnのSVM(SVC)は,多クラス分類を行うとき,one-versus-oneで分類する.
しかし,one-versus-the-restの方が識別性能がいい場合がある(多い,という報告を見かける)ので,
sklearn.multiclassのOneVsRestClassifierを使った
one-versus-the-restでの多クラスSVM分類の仕方をメモしておく.
(注)ただし,LinearSVCはデフォルトでone-versus-the-restを採用している.

One-versus-the-restとOne-versus-one

$K$クラス分類問題を考える.

One-versus-the-rest

ある特定のクラスに入るか,他の$K-1$個のクラスのどれかに入るかの2クラス分類問題を解く分類器を$K$個利用する.

One-versus-one

ある特定のクラスに入るか,また別の特定のクラスに入るかの2クラス分類問題を解く分類器を$K(K-1)/2$個利用する.

多クラスSVM

digitsデータセットを使い,手書き文字の10クラス分類をRBFカーネルのSVMで行う.

パッケージのインポート

from sklearn.datasets import load_digits
from sklearn.multiclass import OneVsRestClassifier
from sklearn.svm import SVC
from sklearn.cross_validation import train_test_split
from sklearn.metrics import accuracy_score

データの読み込み

digits = load_digits()
train_x, test_x, train_y, test_y  = train_test_split(digits.data, digits.target)

ハイパーパラメータの設定

C = 1.
kernel = 'rbf'
gamma  = 0.01

One-versus-the-restによる識別

estimator = SVC(C=C, kernel=kernel, gamma=gamma)
classifier = OneVsRestClassifier(estimator)
classifier.fit(train_x, train_y)
pred_y = classifier.predict(test_x)

one-versus-the-oneによる識別(デフォルト)

classifier2 = SVC(C=C, kernel=kernel, gamma=gamma)
classifier2.fit(train_x, train_y)
pred_y2 = classifier2.predict(test_x)

識別結果

print 'One-versus-the-rest: {:.5f}'.format(accuracy_score(test_y, pred_y))
print 'One-versus-one: {:.5f}'.format(accuracy_score(test_y, pred_y2))

One-versus-the-rest: 0.95333
One-versus-one: 0.79111

One-versus-the-restの方が高い識別性能を示している.

関連リンク

pylearn2.models.svm(sklearnのwraper)
sklearn.multiclass.OneVsRestClassifier
Ex. sklearn.multiclass.OneVsRestClassifier
sklearn.svm

119
114
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
119
114