Edited at
Kazoo04Day 11

高速、高精度、省メモリな線形分類器、SCW

More than 3 years have passed since last update.

今回は機械学習アルゴリズム、SCW(Exact Soft Confidence-Weighted Learning)の紹介です。

まずはどれだけすごいか見てみてください。

SCW    time:0.003194    accuracy:1.000

SVC time:0.010297 accuracy:0.903

使用しているデータセットはscikit-learnの手書き文字認識用のものです。

上がSCW、下がscikit-learnのSVCで学習、分類した結果です。timeは学習にかかった時間、accuracyは精度を表しています。

結果を見ればわかるように、SCWは非常に高速に学習することができます。

また、SCWは逐次学習が可能です。すなわち、データをひとつずつ入力しても学習することができます。つまり、データを全てメモリ上に展開して学習させなくてもよいのです。

精度はデータセットに依存します。というのも、SCWは線形分類器だからです。

線形分離不可能なデータに対してはSCWでは精度が落ちてしまいますが、線形分離可能、もしくはそれに近いかたちで分布しているデータに対しては高い精度を得ることができます。

scikit-learnの手書き文字認識データセットは線形分離可能だったようで、精度100%という結果が得られました。

アルゴリズムについてですが、ここでは解説は行いません。

すでにかずー氏がアルゴリズムの解説をブログに上げているのでそちらをご覧ください。

また、もとの論文はこちらにあります。

SCWのコードは私のGitHubに置いてあります。

以下がテストに用いたコードです。SCWのリポジトリをcloneした後、scw.pyが置いてるディレクトリで以下のコードを実行してください。

テストコードはGistにも上げてあるのでそちらからもどうぞ。


digits_recognition.py

from __future__ import division

import time

import numpy as np
from sklearn.datasets import load_digits, make_classification
from sklearn.svm import SVC
from matplotlib import pyplot

from scw import SCW1

def generate_dataset():
digits = load_digits(2)

classes = np.unique(digits.target)
y = []
for target in digits.target:
if(target == classes[0]):
y.append(-1)
if(target == classes[1]):
y.append(1)
y = np.array(y)

return digits.data, y

def calc_accuracy(resutls, answers):
n_correct_answers = 0
for result, answer in zip(results, answers):
if(result == answer):
n_correct_answers += 1
accuracy = n_correct_answers/len(results)
return accuracy

X, y = generate_dataset()

N = int(len(X)*0.8)
training, test = X[:N], X[N:]
labels, answers = y[:N], y[N:]

scw = SCW1(len(X[0]), C=1.0, ETA=1.0)
t1 = time.time()
scw.fit(training, labels)
t2 = time.time()
results = scw.predict(test)
accuracy = calc_accuracy(results, answers)
print("SCW time:{:3.6f} accuracy:{:1.3f}".format(t2-t1, accuracy))

svc = SVC(C=10.0)
t1 = time.time()
svc.fit(training, labels)
t2 = time.time()
results = svc.predict(test)
accuracy = calc_accuracy(results, answers)
print("SVC time:{:3.6f} accuracy:{:1.3f}".format(t2-t1, accuracy))