4
4

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

ガウスカーネルモデルを用いた最小二乗誤差回帰の交差検証

Last updated at Posted at 2019-05-04

この記事ですること

ガウスカーネルモデルによる$l_2$-正則化を用いた最小二乗回帰の交差検証をpythonで実装する。

ガウスカーネルモデル

K(\boldsymbol{x},\boldsymbol{c})= \exp{ \Big(-\frac{||\boldsymbol{x} - \boldsymbol{c}||^2 }{2h^2}\Big)}

で定義される関数$K$をガウスカーネルといい、ガウス分布のような形をしている。
パラメータは$\boldsymbol{c}$と$h$であり、前者が関数の中心、後者が幅に対応している。
ガウスカーネルを基底としたパラメータ$\boldsymbol{\theta}$の線形モデル

f_\boldsymbol{\theta}(\boldsymbol{x}) = \sum^n_{j=1}\theta_jK(\boldsymbol{x},\boldsymbol{x}_j)

をガウスカーネルモデルといい、この関数でデータを近似する。
ここで、

\boldsymbol{x}_1,\boldsymbol{x}_2,...\boldsymbol{x}_n

は手元にあるデータである。また、パラメータ数が訓練標本数に依存するため、ノンパラメトリックモデルと呼ばれる。

正則化に関して

通常の最小二乗回帰で用いられるコスト関数は
モデルの出力 $f_\boldsymbol{\theta}(\boldsymbol{x}_i)$

と実際の値 $y_i$ の差の二乗の和

\frac{1}{2}\sum_{i=1}^n(f_\boldsymbol{\theta}(\boldsymbol{x}_i)-y_i)^2

を最小化するパラメータ$\boldsymbol{\theta}$を選ぶが、ナイーブにこれを実行するとパラメータの要素が非常に大きな値となり、
過学習を起こすことが知られている。これを防ぐために、コスト関数に$\lambda||\boldsymbol{\theta}||^2/2$を加えたものを代わりにコスト関数にする。これを$l_2$-正則化項という。

\frac{1}{2}\sum_{i=1}^n(f_\boldsymbol{\theta}(\boldsymbol{x}_i)-y_i)^2 + \frac{\lambda}{2}||\boldsymbol{\theta}||^2 

交差検証

学習を行う際、データの一部をパラメータを決定するための訓練データとして用い、
残りのデータをそのパラメータの「良さ」の評価のためのテストデータとして用いる。
例えばデータを$k$個に分割し、$k-1$個をパラメータの決定に利用し、残りの$1$個で二乗誤差を求める。
この操作をすべてに対して行い(合計で$k$回)、誤差の平均を算出し、パラメータの決定の指標とする。

実装

import numpy as np
import matplotlib
import random
import matplotlib.pyplot as plt

np.random.seed(0)

def generate_sample_data(xmin, xmax, sample_size):
    x = np.linspace(start=xmin, stop=xmax, num=sample_size)
    pix = np.pi * x
    target = np.cos(pix) + 0.1 * x
    noise = 0.05 * np.random.normal(loc=0., scale=1., size=sample_size)
    return x, target + noise

def calc_design_matrix(x, c, h):
    return np.exp(-(x[None] - c[:, None]) ** 2 / (2 * h ** 2))

def omit_data_i(data, lnt, i, idx_int):
    data_len = len(data)
    group_len = int(data_len/lnt)
    cropper = idx_int[i:i + group_len]
    data_omitted = data[cropper]
    data_i = np.delete(data, cropper)
    return data_i, data_omitted

def calc_squre_err(theta, k, y):
    s_err = np.linalg.norm(np.dot(k.T, theta).T-y)/10
    return s_err

def validate(x, y, lamb, h, lnt, i, idx_int):
    x_i, x_omitted = omit_data_i(x, lnt, i, idx_int)
    y_i, y_omitted = omit_data_i(y, lnt, i, idx_int)
    k = calc_design_matrix(x_i, x_i, h)

    theta = np.linalg.solve(
        k.T.dot(k) + lamb * np.identity(len(k)),
        k.T.dot(y_i[:, None]))

    K_valid = calc_design_matrix(x_omitted, x_i, h)
    err = calc_squre_err(theta, K_valid, y_omitted)

    xmin, xmax = -3, 3
    X = np.linspace(start=xmin, stop=xmax, num=5000)
    K = calc_design_matrix(x_i, X, h)
    prediction = K.dot(theta)

    # visualization
    plt.clf()
    plt.scatter(x_i, y_i, c='green', marker='o')
    plt.plot(X, prediction)
    return err

def main():
    sample_size = 50
    xmin, xmax = -3, 3
    x, y = generate_sample_data(xmin=xmin, xmax=xmax, sample_size=sample_size)
    lamb = np.logspace(-3,3,7) # [0.0001, 0.01, 0.1, 1, 10, 100, 1000]
    h = np.logspace(-3,3,7) # [0.0001, 0.01, 0.1, 1, 10, 100, 1000]
    lnt = len(h)
    idx_int = np.random.randint(0,50,50)

    for _lamb in lamb:
        for _h in h:
            print('*'*20)
            print('lamb = ', _lamb)
            print('h = ', _h)
            err_array = np.array([])
            for i in np.arange(0, 50, 10):
                err = validate(x, y, _lamb, _h, lnt, i, idx_int)
                err_array = np.append(err_array, err)
                plt.show()
            print(err_array)
            print('mean = ',np.mean(err_array))

if __name__ == '__main__':
    main()

4
4
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
4
4

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?