5
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

SVMを実装してみたはなし

Last updated at Posted at 2018-12-16

この記事について

この記事は IQ1の2まいめっ Advent Calendar 2018の17日目の記事です.
最近AoV(伝説対決)とかいうのをやりすぎて昼夜逆転してるIQ1です(o_o).
そんなIQ1のぼくですが,どうしてもアヤメの花の分類をしてみたくなりました.しかし,ぼくはIQ1なのでできません.しかし,SVMはIQ5000兆だからできるのでは?ってことでSVMをつくってみました.
(Advent Calendar以外から来た人は茶番ですいません,内容はまともです)

SVM実装方針

学習

SVMにおける学習とは以下の目的関数を最小化する$a_i$を求めることになります.

L(a_1, ..., a_N) = \sum_i a_i - \frac{1}{2} \sum_i \sum_j a_i a_j y_i y_j k({\bf x}_i, {\bf x}_j) \\
0\leq a_i \leq C \\
\sum_i a_iy_i = 0

ここで ${\bf x}_i$ は $i$ 番目のサンプルの特徴ベクトル, $y_i$ は$i$ 番目のサンプルのラベル(1 or -1),$k(\cdot,\cdot)$ はカーネル関数,$C$はハイパーパラメータです.

この式をよくみていると,$a_i$に関する二次計画問題になっているなーという気持ちがわいてきます.そのため,pythonの二次計画ソルバーライブラリである cvxoptを使用してこの式を解いていきます.

cvxoptでは以下の形の式を最適化してくれます.

minimize \ \frac{1}{2} {\bf x}^TP{\bf x} + {\bf q}^T{\bf x} \\
G{\bf x} < {\bf h} \\
A{\bf x} = {\bf b}

要するに等式条件と不等式条件を考慮することができます.
この式の$P$, ${\bf q}$, $G$, ${\bf h}$, $A$, ${\bf b}$ を当てはめると,

P = KY \\
{\bf q} = -{\bf 1} \\
G = [I, -I] \\
{\bf h} = [C{\bf 1}, {\bf 0}] \\
A = Y \\
b = {\bf 0} 

ただし$[A,B]$は行列$A, B$をたてに連結する操作をあらわします.また,$C$は(スカラの)ハイパーパラメータで, $K$ は $k({\bf x}_i, {\bf x}_j ) $ をならべた行列です.

学習における特徴

詳しいはなしは別の記事に委ねますが,SVMの損失関数の導出過程は以下のような手順で行います.

  1. 境界面からのマージン最大化の定式化
  2. ラグランジュの未定定数法によるラグランジュ関数の導出
  3. ラグランジュ関数の双対問題による計算の単純化

このときに最適な$a_i$においてKKT条件というものが成立します.このKKT条件の中には相補性条件という条件が含まれており,目的関数中の$a_i$がスパースになることが保証されます.
この中で$a_i \neq 0$ であるようなサンプル$i$を サポートベクトル とよびます.実装上はこの $a_i \neq 0$ となるようなサンプルと対応する $a_i$ を保存しておけばいいことになります(なので実はSVMでは学習したすべてのサンプルを予測に使うわけではなく,いらないサンプルが存在することになります).

予測

では実際に新しいベクトル${\bf x}$をどのように予測するかというと,

y({\bf x}) = \sum_{i \in S} a_i y_i k({\bf x}, {\bf x}_i ) + b \\
b = \frac{1}{N_S} \sum_{i \in S} (y_i - \sum_{j \in S} a_i y_i k({\bf x}_i, {\bf x}_j ) )

と予測します.ここで$S$はサポートベクトルの集合を表します.

実装

以上実装すると以下のようになります.

from constants import *
import numpy as np
import cvxopt as cv

class MySVM:
    def __init__(self, C=1.0, kernel="rbf"):
        self.C = C
        self.kernel = kernel_dct[kernel] if kernel in kernel_dct else kernel
        self.w = None
        self.support_vector = None
        self.b = None
        
    def fit(self, X, y):
        X = np.array(X)
        y = np.array(y)
        gram_matrix = self.kernel(X)
        assert len(X) == len(y), "not same shape label and feature"
        N = len(X)

        T = np.array([[y[i] * y[j] for j in range(N)] for i in range(N)])
        P = gram_matrix * T
        P = cv.matrix(P)

        q = cv.matrix(-np.ones(N))
        G = cv.matrix(np.r_[np.identity(N), -np.identity(N)])
        h = cv.matrix(np.r_[self.C*np.ones(N).T, np.zeros(N).T])
        
        A = cv.matrix(np.array([y], dtype="double"))
        b = cv.matrix(0.0)

        sol = cv.solvers.qp(P, q, G=G, h=h, A=A, b=b)
        #print(len(list(filter(lambda x: x > eps, sol["x"]))))
        
        index_list = list(filter(lambda x: sol["x"][x] > eps, range(N)))
        self.w = np.array(sol["x"])[index_list].reshape(len(index_list)) * y[index_list]
        #print(np.array(sol["x"]).shape)
        print(self.w)
        self.support_vector = X[index_list]
        # calc b
        tmp_list = []
        for i in index_list:
            tmp = 0    
            for j in index_list:
                tmp += (sol["x"][j] * y[j] * gram_matrix[i][j])
            tmp_list.append(y[i]-tmp)
        self.b = np.mean(tmp_list)

    def predict(self, X):
        assert self.w is not None or self.support_vector is not None or self.b is not None, "not call fit method yet"
        #print(self.w.shape)
        #print(self.kernel(X, Y=self.support_vector).shape)
        y = np.dot(np.array([self.w]), self.kernel(X, Y=self.support_vector).T) + self.b
        y = y.reshape(len(X))
        return np.array([1 if pred > 0 else -1 for pred in y])

実装はぼくの設計図共有サイト(o_o)に公開します.
アヤメの花の分類結果は,

     pcost       dcost       gap    pres   dres
 0: -1.2160e+00 -9.3480e+01  3e+02  1e+00  3e-16
 1:  4.6055e-01 -3.3964e+01  3e+01  3e-16  4e-16
 2: -1.8046e+00 -6.3648e+00  5e+00  2e-16  5e-16
 3: -2.3472e+00 -3.6182e+00  1e+00  3e-16  3e-16
 4: -2.4902e+00 -2.9605e+00  5e-01  2e-16  2e-16
 5: -2.5731e+00 -2.7007e+00  1e-01  1e-16  2e-16
 6: -2.6078e+00 -2.6359e+00  3e-02  2e-16  2e-16
 7: -2.6172e+00 -2.6185e+00  1e-03  3e-16  2e-16
 8: -2.6177e+00 -2.6178e+00  1e-04  2e-16  2e-16
 9: -2.6178e+00 -2.6178e+00  4e-06  4e-16  3e-16
10: -2.6178e+00 -2.6178e+00  4e-08  5e-16  3e-16
Optimal solution found.
[ 0.74264555  0.03316614 -0.16098464  1.         -0.18675439 -1.
 -0.41870703 -0.03135181 -0.21518834 -0.1965622  -0.18404231  0.61776438]
Number of SV: 12
Accuray: 1.0

で,Accuracyが1.0だった!うれしい!(今回はデータが簡単だったので,それはそう).50サンプル学習させましたが,サポートベクトルの数は12個だったみたいですね.

参考文献

PRML下巻: https://www.amazon.co.jp/dp/4621061240
カーネル多変量解析: https://www.amazon.co.jp/dp/4000069713

5
1
1

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
5
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?