
Linear classifiers - Fisher's linear discriminant, simple perceptron, IRLS

Posted at 2016-07-21

This is for anyone who wants a quick look at the kind of separating hyperplane drawn by the well-known linear classifiers: Fisher's linear discriminant, the simple perceptron, and IRLS (iteratively reweighted least squares).
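
As a quick reference, these are the standard forms the three methods take (a summary consistent with the implementation below, which uses a batch-averaged version of the perceptron update; not a derivation):

```math
w_{\mathrm{Fisher}} \propto S_W^{-1}(m_2 - m_1), \qquad
w \leftarrow w + \eta\,(t_n - \hat{t}_n)\,x_n \;\;\text{(perceptron)}, \qquad
w \leftarrow w - (\Phi^T R \Phi)^{-1}\Phi^T(y - t) \;\;\text{(IRLS)}
```

Here $S_W$ is the within-class scatter matrix, $t_n \in \{-1, +1\}$ are the perceptron targets, $y = \sigma(\Phi w)$, and $R = \mathrm{diag}(y_n(1 - y_n))$.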

The following two classes are generated from Gaussian distributions, and the boundary between them is computed.

  • Class 1: 80 points drawn from $N(x | \mu_{1}, \Sigma_{1})$ and 20 points drawn from $N(x | \mu_{1}', \Sigma_{1}')$
  • Class 2: 100 points drawn from $N(x | \mu_{2}, \Sigma_{2})$
```math
\Sigma_{1} = \Sigma_{1}' = \Sigma_{2} = \begin{pmatrix} 30 & 10 \\ 10 & 15 \end{pmatrix}^{-1},\qquad
\mu_{1} = (0, 0)^{T},\quad
\mu_{1}' = (-2, -2)^{T},\quad
\mu_{2} = (1, 1)^{T}
```
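
For reference, inverting that matrix gives the covariance actually used ($\det = 30 \cdot 15 - 10 \cdot 10 = 350$):

```math
\Sigma = \frac{1}{350}\begin{pmatrix} 15 & -10 \\ -10 & 30 \end{pmatrix} \approx \begin{pmatrix} 0.043 & -0.029 \\ -0.029 & 0.086 \end{pmatrix}
```

so both clusters are tight relative to the distance between $\mu_{1}$ and $\mu_{2}$, with a negative correlation between the two coordinates.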

(figure: data.png, the generated samples)

Implementation

```python
import numpy as np
import matplotlib.pyplot as plt

def generate_data(mu, cov, num_data):
    # Class 1: two-component Gaussian mixture (main cluster + outlier cluster); class 2: single Gaussian.
    cls1 = np.random.multivariate_normal(mu[0], cov, num_data[0])
    cls1_ = np.random.multivariate_normal(mu[1], cov, num_data[1])
    cls2 = np.random.multivariate_normal(mu[2], cov, num_data[2])

    return np.r_[cls1, cls1_], cls2

def plot(filename, cls1, cls2, spr=None):
    x1, x2 = cls1.T
    plt.plot(x1, x2, "bo")
    x1, x2 = cls2.T
    plt.plot(x1, x2, "ro")
    
    if spr is not None:
        plt.plot(spr[0], spr[1], "g-")

    plt.xlim(-3, 3)
    plt.ylim(-3, 3)
    plt.savefig(filename)
    plt.clf()

def step(out):
    # Perceptron output: map activations >= 0 to +1 and everything else to -1.
    out = out >= 0.
    out = out.astype(float)
    for i in range(len(out[0])):
        if out[0][i] == 0.:
            out[0][i] = -1.

    return out

def sigmoid(x):
    return 1./(1.+np.exp(-x))

def fisher(cls1, cls2):
    # Fisher's linear discriminant: w = Sw^{-1} (m2 - m1), threshold at the midpoint of the class means.
    m1 = np.mean(cls1, axis=0)
    m2 = np.mean(cls2, axis=0)

    dim = len(m1)

    # within-class scatter matrix Sw
    Sw = np.zeros((dim, dim))
    for i in range(len(cls1)):
        xi = np.array(cls1[i]).reshape(dim, 1)
        m1 = np.array(m1).reshape(dim, 1)
        Sw += np.dot((xi - m1), (xi - m1).T)
    for i in range(len(cls2)):
        xi = np.array(cls2[i]).reshape(dim, 1)
        m2 = np.array(m2).reshape(dim, 1)
        Sw += np.dot((xi - m2), (xi - m2).T)
    Sw_inv = np.linalg.inv(Sw)
    w = np.dot(Sw_inv, (m2 - m1))

    # place the decision threshold at the midpoint of the two class means
    m = (m1 + m2) / 2.
    b = -sum(w*m)
    
    x = np.linspace(-3, 3, 1000)
    y = [(w[0][0]*xs+b)/(-w[1][0]) for xs in x]
    plot("fisher.png", cls1, cls2, (x, y))

def perceptron(cls1, cls2, lr=0.5, loop=1000):
    # Append a constant bias feature (1) and the target label: +1 for class 1, -1 for class 2.
    cls1_ = np.c_[cls1, np.ones((len(cls1))), np.ones((len(cls1)))]
    cls2_ = np.c_[cls2, np.ones((len(cls2))), -1*np.ones((len(cls2)))]

    data = np.r_[cls1_, cls2_]
    np.random.shuffle(data)
    data, label = np.hsplit(data, [len(data[0])-1])
    w = np.random.uniform(-1., 1., size=(1, len(data[0])))

    for i in range(loop):
        out = np.dot(w, data.T)
        out = step(out)
        # batch update: correctly classified points contribute nothing, since label - out = 0 for them
        dw = lr * (label - out.T) * data
        w += np.mean(dw, axis=0)
    
    x = np.linspace(-3, 3, 1000)
    y = [(w[0][0]*xs+w[0][2])/(-w[0][1]) for xs in x]
    plot("perceptron.png", cls1, cls2, (x, y))

def IRLS(cls1, cls2, tol=1e-5, maxits=100):
    # Logistic regression fitted with Newton-Raphson (IRLS); targets are 1 for class 1, 0 for class 2.
    cls1_ = np.c_[cls1, np.ones((len(cls1))), np.ones((len(cls1)))]
    cls2_ = np.c_[cls2, np.ones((len(cls2))), np.zeros((len(cls2)))]

    data = np.r_[cls1_, cls2_]
    np.random.shuffle(data)
    data, label = np.hsplit(data, [len(data[0])-1])
    w = np.zeros((1, len(data[0])))
    
    itr=0
    while(itr < maxits):
        y = sigmoid(np.dot(w, data.T)).T                 # predicted probabilities
        g = np.dot(data.T, (y - label))                  # gradient of the cross-entropy error
        rn = y.T*(1-y.T)
        r = np.diag(rn[0])                               # weighting matrix R = diag(y_n(1 - y_n))
        hesse = np.dot(np.dot(data.T, r), data)          # Hessian Phi^T R Phi
        diff = np.dot(np.dot(np.linalg.inv(hesse), data.T), (y - label))
        w -= diff.T                                      # Newton step: w <- w - H^{-1} g
        if np.sum(g**2) <= tol:
            print(itr)
            break
        itr += 1
    
    x = np.linspace(-3, 3, 1000)
    y = [(w[0][0]*xs+w[0][2])/(-w[0][1]) for xs in x]
    plot("IRLS.png", cls1, cls2, (x, y))

if __name__ == "__main__":
    mu = [[0., 0.], [-2., -2.], [1. ,1.]]
    cov = np.linalg.inv([[30., 10.], [10., 15.]])
    num_data = [80, 20, 100]

    cls1, cls2 = generate_data(mu, cov, num_data)
    plot("data.png", cls1, cls2)
    fisher(cls1, cls2)
    perceptron(cls1, cls2)
    IRLS(cls1, cls2)
```
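
As a small add-on (not part of the original script): if perceptron() and IRLS() were modified to also return their weight vector w (shape (1, 3), laid out as [w1, w2, bias] like above), the fit on the training data could be checked numerically rather than only visually. A minimal sketch under that assumption:

```python
import numpy as np

def training_accuracy(w, cls1, cls2):
    """Fraction of training points on the expected side of w1*x1 + w2*x2 + b = 0.

    Assumes the label convention used above: class 1 is the positive class
    (+1 for the perceptron, 1 for IRLS), so class-1 points should satisfy
    w1*x1 + w2*x2 + b >= 0 and class-2 points the opposite.
    """
    w1, w2, b = np.ravel(w)
    s1 = w1 * cls1[:, 0] + w2 * cls1[:, 1] + b   # class 1: expect >= 0
    s2 = w1 * cls2[:, 0] + w2 * cls2[:, 1] + b   # class 2: expect <  0
    correct = np.sum(s1 >= 0.) + np.sum(s2 < 0.)
    return correct / float(len(cls1) + len(cls2))
```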

Results

First, Fisher's linear discriminant:
(figure: fisher.png)
You can see that the boundary is pulled strongly toward the outliers.
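
A rough way to see why (a back-of-the-envelope estimate, not from the original article): in expectation the class-1 sample mean is

```math
m_1 \approx 0.8 \cdot (0, 0)^{T} + 0.2 \cdot (-2, -2)^{T} = (-0.4, -0.4)^{T}
```

so the 20 outlier points drag $m_1$, and with it $S_W$ and the direction $S_W^{-1}(m_2 - m_1)$, away from the main cluster.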

Next, the simple perceptron:
(figure: perceptron.png)
Whether or not it generalizes well, the resulting hyperplane does separate the training data, and it is not pulled toward the outliers.
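
One way to see why the outliers stop mattering (the standard perceptron criterion, which matches the update in the code, where label - out is zero for correctly classified points): the error is summed only over the misclassified set $\mathcal{M}$,

```math
E_P(w) = -\sum_{n \in \mathcal{M}} w^{T} x_n t_n
```

so once the outlier cluster lands on the correct side of the hyperplane it no longer contributes to the update.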

Finally, IRLS:
(figure: IRLS.png)
It is not affected by the outliers, and the generalization looks good as well: the outlier points end up deep on the class-1 side of the boundary, so $y_n \approx t_n$ for them and they contribute almost nothing to the gradient $\Phi^{T}(y - t)$.
