この記事について
この記事は IQ1の2まいめっ Advent Calendar 2018の17日目の記事です.
最近AoV(伝説対決)とかいうのをやりすぎて昼夜逆転してるIQ1です(o_o).
そんなIQ1のぼくですが,どうしてもアヤメの花の分類をしてみたくなりました.しかし,ぼくはIQ1なのでできません.しかし,SVMはIQ5000兆だからできるのでは?ってことでSVMをつくってみました.
(Advent Calendar以外から来た人は茶番ですいません,内容はまともです)
SVM実装方針
学習
SVMにおける学習とは以下の目的関数を最小化する$a_i$を求めることになります.
L(a_1, ..., a_N) = \sum_i a_i - \frac{1}{2} \sum_i \sum_j a_i a_j y_i y_j k({\bf x}_i, {\bf x}_j) \\
0\leq a_i \leq C \\
\sum_i a_iy_i = 0
ここで ${\bf x}_i$ は $i$ 番目のサンプルの特徴ベクトル, $y_i$ は$i$ 番目のサンプルのラベル(1 or -1),$k(\cdot,\cdot)$ はカーネル関数,$C$はハイパーパラメータです.
この式をよくみていると,$a_i$に関する二次計画問題になっているなーという気持ちがわいてきます.そのため,pythonの二次計画ソルバーライブラリである cvxoptを使用してこの式を解いていきます.
cvxoptでは以下の形の式を最適化してくれます.
minimize \ \frac{1}{2} {\bf x}^TP{\bf x} + {\bf q}^T{\bf x} \\
G{\bf x} < {\bf h} \\
A{\bf x} = {\bf b}
要するに等式条件と不等式条件を考慮することができます.
この式の$P$, ${\bf q}$, $G$, ${\bf h}$, $A$, ${\bf b}$ を当てはめると,
P = KY \\
{\bf q} = -{\bf 1} \\
G = [I, -I] \\
{\bf h} = [C{\bf 1}, {\bf 0}] \\
A = Y \\
b = {\bf 0}
ただし$[A,B]$は行列$A, B$をたてに連結する操作をあらわします.また,$C$は(スカラの)ハイパーパラメータで, $K$ は $k({\bf x}_i, {\bf x}_j ) $ をならべた行列です.
学習における特徴
詳しいはなしは別の記事に委ねますが,SVMの損失関数の導出過程は以下のような手順で行います.
- 境界面からのマージン最大化の定式化
- ラグランジュの未定定数法によるラグランジュ関数の導出
- ラグランジュ関数の双対問題による計算の単純化
このときに最適な$a_i$においてKKT条件というものが成立します.このKKT条件の中には相補性条件という条件が含まれており,目的関数中の$a_i$がスパースになることが保証されます.
この中で$a_i \neq 0$ であるようなサンプル$i$を サポートベクトル とよびます.実装上はこの $a_i \neq 0$ となるようなサンプルと対応する $a_i$ を保存しておけばいいことになります(なので実はSVMでは学習したすべてのサンプルを予測に使うわけではなく,いらないサンプルが存在することになります).
予測
では実際に新しいベクトル${\bf x}$をどのように予測するかというと,
y({\bf x}) = \sum_{i \in S} a_i y_i k({\bf x}, {\bf x}_i ) + b \\
b = \frac{1}{N_S} \sum_{i \in S} (y_i - \sum_{j \in S} a_i y_i k({\bf x}_i, {\bf x}_j ) )
と予測します.ここで$S$はサポートベクトルの集合を表します.
実装
以上実装すると以下のようになります.
from constants import *
import numpy as np
import cvxopt as cv
class MySVM:
def __init__(self, C=1.0, kernel="rbf"):
self.C = C
self.kernel = kernel_dct[kernel] if kernel in kernel_dct else kernel
self.w = None
self.support_vector = None
self.b = None
def fit(self, X, y):
X = np.array(X)
y = np.array(y)
gram_matrix = self.kernel(X)
assert len(X) == len(y), "not same shape label and feature"
N = len(X)
T = np.array([[y[i] * y[j] for j in range(N)] for i in range(N)])
P = gram_matrix * T
P = cv.matrix(P)
q = cv.matrix(-np.ones(N))
G = cv.matrix(np.r_[np.identity(N), -np.identity(N)])
h = cv.matrix(np.r_[self.C*np.ones(N).T, np.zeros(N).T])
A = cv.matrix(np.array([y], dtype="double"))
b = cv.matrix(0.0)
sol = cv.solvers.qp(P, q, G=G, h=h, A=A, b=b)
#print(len(list(filter(lambda x: x > eps, sol["x"]))))
index_list = list(filter(lambda x: sol["x"][x] > eps, range(N)))
self.w = np.array(sol["x"])[index_list].reshape(len(index_list)) * y[index_list]
#print(np.array(sol["x"]).shape)
print(self.w)
self.support_vector = X[index_list]
# calc b
tmp_list = []
for i in index_list:
tmp = 0
for j in index_list:
tmp += (sol["x"][j] * y[j] * gram_matrix[i][j])
tmp_list.append(y[i]-tmp)
self.b = np.mean(tmp_list)
def predict(self, X):
assert self.w is not None or self.support_vector is not None or self.b is not None, "not call fit method yet"
#print(self.w.shape)
#print(self.kernel(X, Y=self.support_vector).shape)
y = np.dot(np.array([self.w]), self.kernel(X, Y=self.support_vector).T) + self.b
y = y.reshape(len(X))
return np.array([1 if pred > 0 else -1 for pred in y])
実装はぼくの設計図共有サイト(o_o)に公開します.
アヤメの花の分類結果は,
pcost dcost gap pres dres
0: -1.2160e+00 -9.3480e+01 3e+02 1e+00 3e-16
1: 4.6055e-01 -3.3964e+01 3e+01 3e-16 4e-16
2: -1.8046e+00 -6.3648e+00 5e+00 2e-16 5e-16
3: -2.3472e+00 -3.6182e+00 1e+00 3e-16 3e-16
4: -2.4902e+00 -2.9605e+00 5e-01 2e-16 2e-16
5: -2.5731e+00 -2.7007e+00 1e-01 1e-16 2e-16
6: -2.6078e+00 -2.6359e+00 3e-02 2e-16 2e-16
7: -2.6172e+00 -2.6185e+00 1e-03 3e-16 2e-16
8: -2.6177e+00 -2.6178e+00 1e-04 2e-16 2e-16
9: -2.6178e+00 -2.6178e+00 4e-06 4e-16 3e-16
10: -2.6178e+00 -2.6178e+00 4e-08 5e-16 3e-16
Optimal solution found.
[ 0.74264555 0.03316614 -0.16098464 1. -0.18675439 -1.
-0.41870703 -0.03135181 -0.21518834 -0.1965622 -0.18404231 0.61776438]
Number of SV: 12
Accuray: 1.0
で,Accuracyが1.0だった!うれしい!(今回はデータが簡単だったので,それはそう).50サンプル学習させましたが,サポートベクトルの数は12個だったみたいですね.
参考文献
PRML下巻: https://www.amazon.co.jp/dp/4621061240
カーネル多変量解析: https://www.amazon.co.jp/dp/4000069713