scikit-learnによる多クラスSVM
目的
scikit-learnのSVM(SVC)は,多クラス分類を行うとき,one-versus-oneで分類する.
しかし,one-versus-the-restの方が識別性能がいい場合がある(多い,という報告を見かける)ので,
sklearn.multiclassのOneVsRestClassifierを使った
one-versus-the-restでの多クラスSVM分類の仕方をメモしておく.
(注)ただし,LinearSVCはデフォルトでone-versus-the-restを採用している.
One-versus-the-restとOne-versus-one
$K$クラス分類問題を考える.
One-versus-the-rest
ある特定のクラスに入るか,他の$K-1$個のクラスのどれかに入るかの2クラス分類問題を解く分類器を$K$個利用する.
One-versus-one
ある特定のクラスに入るか,また別の特定のクラスに入るかの2クラス分類問題を解く分類器を$K(K-1)/2$個利用する.
多クラスSVM
digitsデータセットを使い,手書き文字の10クラス分類をRBFカーネルのSVMで行う.
パッケージのインポート
from sklearn.datasets import load_digits
from sklearn.multiclass import OneVsRestClassifier
from sklearn.svm import SVC
from sklearn.cross_validation import train_test_split
from sklearn.metrics import accuracy_score
データの読み込み
digits = load_digits()
train_x, test_x, train_y, test_y = train_test_split(digits.data, digits.target)
ハイパーパラメータの設定
C = 1.
kernel = 'rbf'
gamma = 0.01
One-versus-the-restによる識別
estimator = SVC(C=C, kernel=kernel, gamma=gamma)
classifier = OneVsRestClassifier(estimator)
classifier.fit(train_x, train_y)
pred_y = classifier.predict(test_x)
one-versus-the-oneによる識別(デフォルト)
classifier2 = SVC(C=C, kernel=kernel, gamma=gamma)
classifier2.fit(train_x, train_y)
pred_y2 = classifier2.predict(test_x)
識別結果
print 'One-versus-the-rest: {:.5f}'.format(accuracy_score(test_y, pred_y))
print 'One-versus-one: {:.5f}'.format(accuracy_score(test_y, pred_y2))
One-versus-the-rest: 0.95333
One-versus-one: 0.79111
One-versus-the-restの方が高い識別性能を示している.
関連リンク
pylearn2.models.svm(sklearnのwraper)
sklearn.multiclass.OneVsRestClassifier
Ex. sklearn.multiclass.OneVsRestClassifier
sklearn.svm