search
LoginSignup
1

More than 3 years have passed since last update.

posted at

updated at

Organization

SVMを実装してみたはなし

この記事について

この記事は IQ1の2まいめっ Advent Calendar 2018の17日目の記事です.
最近AoV(伝説対決)とかいうのをやりすぎて昼夜逆転してるIQ1です(o_o).
そんなIQ1のぼくですが,どうしてもアヤメの花の分類をしてみたくなりました.しかし,ぼくはIQ1なのでできません.しかし,SVMはIQ5000兆だからできるのでは?ってことでSVMをつくってみました.
(Advent Calendar以外から来た人は茶番ですいません,内容はまともです)

SVM実装方針

学習

SVMにおける学習とは以下の目的関数を最小化する$a_i$を求めることになります.

L(a_1, ..., a_N) = \sum_i a_i - \frac{1}{2} \sum_i \sum_j a_i a_j y_i y_j k({\bf x}_i, {\bf x}_j) \\
0\leq a_i \leq C \\
\sum_i a_iy_i = 0

ここで ${\bf x}_i$ は $i$ 番目のサンプルの特徴ベクトル, $y_i$ は$i$ 番目のサンプルのラベル(1 or -1),$k(\cdot,\cdot)$ はカーネル関数,$C$はハイパーパラメータです.

この式をよくみていると,$a_i$に関する二次計画問題になっているなーという気持ちがわいてきます.そのため,pythonの二次計画ソルバーライブラリである cvxoptを使用してこの式を解いていきます.

cvxoptでは以下の形の式を最適化してくれます.

minimize \ \frac{1}{2} {\bf x}^TP{\bf x} + {\bf q}^T{\bf x} \\
G{\bf x} < {\bf h} \\
A{\bf x} = {\bf b}

要するに等式条件と不等式条件を考慮することができます.
この式の$P$, ${\bf q}$, $G$, ${\bf h}$, $A$, ${\bf b}$ を当てはめると,

P = KY \\
{\bf q} = -{\bf 1} \\
G = [I, -I] \\
{\bf h} = [C{\bf 1}, {\bf 0}] \\
A = Y \\
b = {\bf 0} 

ただし$[A,B]$は行列$A, B$をたてに連結する操作をあらわします.また,$C$は(スカラの)ハイパーパラメータで, $K$ は $k({\bf x}_i, {\bf x}_j ) $ をならべた行列です.

学習における特徴

詳しいはなしは別の記事に委ねますが,SVMの損失関数の導出過程は以下のような手順で行います.

  1. 境界面からのマージン最大化の定式化
  2. ラグランジュの未定定数法によるラグランジュ関数の導出
  3. ラグランジュ関数の双対問題による計算の単純化

このときに最適な$a_i$においてKKT条件というものが成立します.このKKT条件の中には相補性条件という条件が含まれており,目的関数中の$a_i$がスパースになることが保証されます.
この中で$a_i \neq 0$ であるようなサンプル$i$を サポートベクトル とよびます.実装上はこの $a_i \neq 0$ となるようなサンプルと対応する $a_i$ を保存しておけばいいことになります(なので実はSVMでは学習したすべてのサンプルを予測に使うわけではなく,いらないサンプルが存在することになります).

予測

では実際に新しいベクトル${\bf x}$をどのように予測するかというと,

y({\bf x}) = \sum_{i \in S} a_i y_i k({\bf x}, {\bf x}_i ) + b \\
b = \frac{1}{N_S} \sum_{i \in S} (y_i - \sum_{j \in S} a_i y_i k({\bf x}_i, {\bf x}_j ) )

と予測します.ここで$S$はサポートベクトルの集合を表します.

実装

以上実装すると以下のようになります.

from constants import *
import numpy as np
import cvxopt as cv

class MySVM:
    def __init__(self, C=1.0, kernel="rbf"):
        self.C = C
        self.kernel = kernel_dct[kernel] if kernel in kernel_dct else kernel
        self.w = None
        self.support_vector = None
        self.b = None

    def fit(self, X, y):
        X = np.array(X)
        y = np.array(y)
        gram_matrix = self.kernel(X)
        assert len(X) == len(y), "not same shape label and feature"
        N = len(X)

        T = np.array([[y[i] * y[j] for j in range(N)] for i in range(N)])
        P = gram_matrix * T
        P = cv.matrix(P)

        q = cv.matrix(-np.ones(N))
        G = cv.matrix(np.r_[np.identity(N), -np.identity(N)])
        h = cv.matrix(np.r_[self.C*np.ones(N).T, np.zeros(N).T])

        A = cv.matrix(np.array([y], dtype="double"))
        b = cv.matrix(0.0)

        sol = cv.solvers.qp(P, q, G=G, h=h, A=A, b=b)
        #print(len(list(filter(lambda x: x > eps, sol["x"]))))

        index_list = list(filter(lambda x: sol["x"][x] > eps, range(N)))
        self.w = np.array(sol["x"])[index_list].reshape(len(index_list)) * y[index_list]
        #print(np.array(sol["x"]).shape)
        print(self.w)
        self.support_vector = X[index_list]
        # calc b
        tmp_list = []
        for i in index_list:
            tmp = 0    
            for j in index_list:
                tmp += (sol["x"][j] * y[j] * gram_matrix[i][j])
            tmp_list.append(y[i]-tmp)
        self.b = np.mean(tmp_list)

    def predict(self, X):
        assert self.w is not None or self.support_vector is not None or self.b is not None, "not call fit method yet"
        #print(self.w.shape)
        #print(self.kernel(X, Y=self.support_vector).shape)
        y = np.dot(np.array([self.w]), self.kernel(X, Y=self.support_vector).T) + self.b
        y = y.reshape(len(X))
        return np.array([1 if pred > 0 else -1 for pred in y])

実装はぼくの設計図共有サイト(o_o)に公開します.
アヤメの花の分類結果は,

     pcost       dcost       gap    pres   dres
 0: -1.2160e+00 -9.3480e+01  3e+02  1e+00  3e-16
 1:  4.6055e-01 -3.3964e+01  3e+01  3e-16  4e-16
 2: -1.8046e+00 -6.3648e+00  5e+00  2e-16  5e-16
 3: -2.3472e+00 -3.6182e+00  1e+00  3e-16  3e-16
 4: -2.4902e+00 -2.9605e+00  5e-01  2e-16  2e-16
 5: -2.5731e+00 -2.7007e+00  1e-01  1e-16  2e-16
 6: -2.6078e+00 -2.6359e+00  3e-02  2e-16  2e-16
 7: -2.6172e+00 -2.6185e+00  1e-03  3e-16  2e-16
 8: -2.6177e+00 -2.6178e+00  1e-04  2e-16  2e-16
 9: -2.6178e+00 -2.6178e+00  4e-06  4e-16  3e-16
10: -2.6178e+00 -2.6178e+00  4e-08  5e-16  3e-16
Optimal solution found.
[ 0.74264555  0.03316614 -0.16098464  1.         -0.18675439 -1.
 -0.41870703 -0.03135181 -0.21518834 -0.1965622  -0.18404231  0.61776438]
Number of SV: 12
Accuray: 1.0

で,Accuracyが1.0だった!うれしい!(今回はデータが簡単だったので,それはそう).50サンプル学習させましたが,サポートベクトルの数は12個だったみたいですね.

参考文献

PRML下巻: https://www.amazon.co.jp/dp/4621061240
カーネル多変量解析: https://www.amazon.co.jp/dp/4000069713

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
What you can do with signing up
1