scikit-learnによる多クラスSVM

目的

scikit-learnのSVM(SVC)は,多クラス分類を行うとき,one-versus-oneで分類する.
しかし,one-versus-the-restの方が識別性能がいい場合がある(多い,という報告を見かける)ので,
sklearn.multiclassのOneVsRestClassifierを使った
one-versus-the-restでの多クラスSVM分類の仕方をメモしておく.
(注)ただし,LinearSVCはデフォルトでone-versus-the-restを採用している.

One-versus-the-restとOne-versus-one

$K$クラス分類問題を考える.

One-versus-the-rest

ある特定のクラスに入るか,他の$K-1$個のクラスのどれかに入るかの2クラス分類問題を解く分類器を$K$個利用する.

One-versus-one

ある特定のクラスに入るか,また別の特定のクラスに入るかの2クラス分類問題を解く分類器を$K(K-1)/2$個利用する.

多クラスSVM

digitsデータセットを使い,手書き文字の10クラス分類をRBFカーネルのSVMで行う.

パッケージのインポート

from sklearn.datasets import load_digits
from sklearn.multiclass import OneVsRestClassifier
from sklearn.svm import SVC
from sklearn.cross_validation import train_test_split
from sklearn.metrics import accuracy_score

データの読み込み

digits = load_digits()
train_x, test_x, train_y, test_y  = train_test_split(digits.data, digits.target)

ハイパーパラメータの設定

C = 1.
kernel = 'rbf'
gamma  = 0.01

One-versus-the-restによる識別

estimator = SVC(C=C, kernel=kernel, gamma=gamma)
classifier = OneVsRestClassifier(estimator)
classifier.fit(train_x, train_y)
pred_y = classifier.predict(test_x)

one-versus-the-oneによる識別(デフォルト)

classifier2 = SVC(C=C, kernel=kernel, gamma=gamma)
classifier2.fit(train_x, train_y)
pred_y2 = classifier2.predict(test_x)

識別結果

print 'One-versus-the-rest: {:.5f}'.format(accuracy_score(test_y, pred_y))
print 'One-versus-one: {:.5f}'.format(accuracy_score(test_y, pred_y2))

One-versus-the-rest: 0.95333
One-versus-one: 0.79111

One-versus-the-restの方が高い識別性能を示している.

関連リンク

pylearn2.models.svm(sklearnのwraper)
sklearn.multiclass.OneVsRestClassifier
Ex. sklearn.multiclass.OneVsRestClassifier
sklearn.svm

Sign up for free and join this conversation.
Sign Up
If you already have a Qiita account log in.