Edited at

scikit-learnによる多クラスSVM

More than 1 year has passed since last update.


scikit-learnによる多クラスSVM


目的

scikit-learnのSVM(SVC)は,多クラス分類を行うとき,one-versus-oneで分類する.

しかし,one-versus-the-restの方が識別性能がいい場合がある(多い,という報告を見かける)ので,

sklearn.multiclassのOneVsRestClassifierを使った

one-versus-the-restでの多クラスSVM分類の仕方をメモしておく.

(注)ただし,LinearSVCはデフォルトでone-versus-the-restを採用している.


One-versus-the-restとOne-versus-one

$K$クラス分類問題を考える.


One-versus-the-rest

ある特定のクラスに入るか,他の$K-1$個のクラスのどれかに入るかの2クラス分類問題を解く分類器を$K$個利用する.


One-versus-one

ある特定のクラスに入るか,また別の特定のクラスに入るかの2クラス分類問題を解く分類器を$K(K-1)/2$個利用する.


多クラスSVM

digitsデータセットを使い,手書き文字の10クラス分類をRBFカーネルのSVMで行う.


パッケージのインポート

from sklearn.datasets import load_digits

from sklearn.multiclass import OneVsRestClassifier
from sklearn.svm import SVC
from sklearn.cross_validation import train_test_split
from sklearn.metrics import accuracy_score


データの読み込み

digits = load_digits()

train_x, test_x, train_y, test_y = train_test_split(digits.data, digits.target)


ハイパーパラメータの設定

C = 1.

kernel = 'rbf'
gamma = 0.01


One-versus-the-restによる識別

estimator = SVC(C=C, kernel=kernel, gamma=gamma)

classifier = OneVsRestClassifier(estimator)
classifier.fit(train_x, train_y)
pred_y = classifier.predict(test_x)


one-versus-the-oneによる識別(デフォルト)

classifier2 = SVC(C=C, kernel=kernel, gamma=gamma)

classifier2.fit(train_x, train_y)
pred_y2 = classifier2.predict(test_x)


識別結果

print 'One-versus-the-rest: {:.5f}'.format(accuracy_score(test_y, pred_y))

print 'One-versus-one: {:.5f}'.format(accuracy_score(test_y, pred_y2))

One-versus-the-rest: 0.95333

One-versus-one: 0.79111

One-versus-the-restの方が高い識別性能を示している.


関連リンク

pylearn2.models.svm(sklearnのwraper)

sklearn.multiclass.OneVsRestClassifier

Ex. sklearn.multiclass.OneVsRestClassifier

sklearn.svm