Qiita Teams that are logged in
You are not logged in to any team

Log in to Qiita Team
Community
OrganizationEventAdvent CalendarQiitadon (β)
Service
Qiita JobsQiita ZineQiita Blog
0
Help us understand the problem. What are the problem?

More than 1 year has passed since last update.

posted at

updated at

Organization

SVMを実装してみたはなし

この記事について

この記事は IQ1の2まいめっ Advent Calendar 2018の17日目の記事です.
最近AoV(伝説対決)とかいうのをやりすぎて昼夜逆転してるIQ1です(o_o).
そんなIQ1のぼくですが,どうしてもアヤメの花の分類をしてみたくなりました.しかし,ぼくはIQ1なのでできません.しかし,SVMはIQ5000兆だからできるのでは?ってことでSVMをつくってみました.
(Advent Calendar以外から来た人は茶番ですいません,内容はまともです)

SVM実装方針

学習

SVMにおける学習とは以下の目的関数を最小化する$a_i$を求めることになります.

L(a_1, ..., a_N) = \sum_i a_i - \frac{1}{2} \sum_i \sum_j a_i a_j y_i y_j k({\bf x}_i, {\bf x}_j) \\
0\leq a_i \leq C \\
\sum_i a_iy_i = 0

ここで ${\bf x}_i$ は $i$ 番目のサンプルの特徴ベクトル, $y_i$ は$i$ 番目のサンプルのラベル(1 or -1),$k(\cdot,\cdot)$ はカーネル関数,$C$はハイパーパラメータです.

この式をよくみていると,$a_i$に関する二次計画問題になっているなーという気持ちがわいてきます.そのため,pythonの二次計画ソルバーライブラリである cvxoptを使用してこの式を解いていきます.

cvxoptでは以下の形の式を最適化してくれます.

minimize \ \frac{1}{2} {\bf x}^TP{\bf x} + {\bf q}^T{\bf x} \\
G{\bf x} < {\bf h} \\
A{\bf x} = {\bf b}

要するに等式条件と不等式条件を考慮することができます.
この式の$P$, ${\bf q}$, $G$, ${\bf h}$, $A$, ${\bf b}$ を当てはめると,

P = KY \\
{\bf q} = -{\bf 1} \\
G = [I, -I] \\
{\bf h} = [C{\bf 1}, {\bf 0}] \\
A = Y \\
b = {\bf 0} 

ただし$[A,B]$は行列$A, B$をたてに連結する操作をあらわします.また,$C$は(スカラの)ハイパーパラメータで, $K$ は $k({\bf x}_i, {\bf x}_j ) $ をならべた行列です.

学習における特徴

詳しいはなしは別の記事に委ねますが,SVMの損失関数の導出過程は以下のような手順で行います.

  1. 境界面からのマージン最大化の定式化
  2. ラグランジュの未定定数法によるラグランジュ関数の導出
  3. ラグランジュ関数の双対問題による計算の単純化

このときに最適な$a_i$においてKKT条件というものが成立します.このKKT条件の中には相補性条件という条件が含まれており,目的関数中の$a_i$がスパースになることが保証されます.
この中で$a_i \neq 0$ であるようなサンプル$i$を サポートベクトル とよびます.実装上はこの $a_i \neq 0$ となるようなサンプルと対応する $a_i$ を保存しておけばいいことになります(なので実はSVMでは学習したすべてのサンプルを予測に使うわけではなく,いらないサンプルが存在することになります).

予測

では実際に新しいベクトル${\bf x}$をどのように予測するかというと,

y({\bf x}) = \sum_{i \in S} a_i y_i k({\bf x}, {\bf x}_i ) + b \\
b = \frac{1}{N_S} \sum_{i \in S} (y_i - \sum_{j \in S} a_i y_i k({\bf x}_i, {\bf x}_j ) )

と予測します.ここで$S$はサポートベクトルの集合を表します.

実装

以上実装すると以下のようになります.

from constants import *
import numpy as np
import cvxopt as cv

class MySVM:
    def __init__(self, C=1.0, kernel="rbf"):
        self.C = C
        self.kernel = kernel_dct[kernel] if kernel in kernel_dct else kernel
        self.w = None
        self.support_vector = None
        self.b = None

    def fit(self, X, y):
        X = np.array(X)
        y = np.array(y)
        gram_matrix = self.kernel(X)
        assert len(X) == len(y), "not same shape label and feature"
        N = len(X)

        T = np.array([[y[i] * y[j] for j in range(N)] for i in range(N)])
        P = gram_matrix * T
        P = cv.matrix(P)

        q = cv.matrix(-np.ones(N))
        G = cv.matrix(np.r_[np.identity(N), -np.identity(N)])
        h = cv.matrix(np.r_[self.C*np.ones(N).T, np.zeros(N).T])

        A = cv.matrix(np.array([y], dtype="double"))
        b = cv.matrix(0.0)

        sol = cv.solvers.qp(P, q, G=G, h=h, A=A, b=b)
        #print(len(list(filter(lambda x: x > eps, sol["x"]))))

        index_list = list(filter(lambda x: sol["x"][x] > eps, range(N)))
        self.w = np.array(sol["x"])[index_list].reshape(len(index_list)) * y[index_list]
        #print(np.array(sol["x"]).shape)
        print(self.w)
        self.support_vector = X[index_list]
        # calc b
        tmp_list = []
        for i in index_list:
            tmp = 0    
            for j in index_list:
                tmp += (sol["x"][j] * y[j] * gram_matrix[i][j])
            tmp_list.append(y[i]-tmp)
        self.b = np.mean(tmp_list)

    def predict(self, X):
        assert self.w is not None or self.support_vector is not None or self.b is not None, "not call fit method yet"
        #print(self.w.shape)
        #print(self.kernel(X, Y=self.support_vector).shape)
        y = np.dot(np.array([self.w]), self.kernel(X, Y=self.support_vector).T) + self.b
        y = y.reshape(len(X))
        return np.array([1 if pred > 0 else -1 for pred in y])

実装はぼくの設計図共有サイト(o_o)に公開します.
アヤメの花の分類結果は,

     pcost       dcost       gap    pres   dres
 0: -1.2160e+00 -9.3480e+01  3e+02  1e+00  3e-16
 1:  4.6055e-01 -3.3964e+01  3e+01  3e-16  4e-16
 2: -1.8046e+00 -6.3648e+00  5e+00  2e-16  5e-16
 3: -2.3472e+00 -3.6182e+00  1e+00  3e-16  3e-16
 4: -2.4902e+00 -2.9605e+00  5e-01  2e-16  2e-16
 5: -2.5731e+00 -2.7007e+00  1e-01  1e-16  2e-16
 6: -2.6078e+00 -2.6359e+00  3e-02  2e-16  2e-16
 7: -2.6172e+00 -2.6185e+00  1e-03  3e-16  2e-16
 8: -2.6177e+00 -2.6178e+00  1e-04  2e-16  2e-16
 9: -2.6178e+00 -2.6178e+00  4e-06  4e-16  3e-16
10: -2.6178e+00 -2.6178e+00  4e-08  5e-16  3e-16
Optimal solution found.
[ 0.74264555  0.03316614 -0.16098464  1.         -0.18675439 -1.
 -0.41870703 -0.03135181 -0.21518834 -0.1965622  -0.18404231  0.61776438]
Number of SV: 12
Accuray: 1.0

で,Accuracyが1.0だった!うれしい!(今回はデータが簡単だったので,それはそう).50サンプル学習させましたが,サポートベクトルの数は12個だったみたいですね.

参考文献

PRML下巻: https://www.amazon.co.jp/dp/4621061240
カーネル多変量解析: https://www.amazon.co.jp/dp/4000069713

Why not register and get more from Qiita?
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away
0
Help us understand the problem. What are the problem?