12
12

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

ガウス過程回帰 Numpy実装とGPy

Last updated at Posted at 2019-12-08

はじめに

こんな回帰直線引いちゃっていいのと疑問に思うようなグラフを見ることがあります。
例えば

import numpy as np
import matplotlib.pyplot as plt

plt.rcParams["font.family"] = "Times New Roman"      #全体のフォントを設定
plt.rcParams["xtick.direction"] = "in"               #x軸の目盛線を内向きへ
plt.rcParams["ytick.direction"] = "in"               #y軸の目盛線を内向きへ
plt.rcParams["xtick.minor.visible"] = True           #x軸補助目盛りの追加
plt.rcParams["ytick.minor.visible"] = True           #y軸補助目盛りの追加
plt.rcParams["xtick.major.width"] = 1.5              #x軸主目盛り線の線幅
plt.rcParams["ytick.major.width"] = 1.5              #y軸主目盛り線の線幅
plt.rcParams["xtick.minor.width"] = 1.0              #x軸補助目盛り線の線幅
plt.rcParams["ytick.minor.width"] = 1.0              #y軸補助目盛り線の線幅
plt.rcParams["xtick.major.size"] = 10                #x軸主目盛り線の長さ
plt.rcParams["ytick.major.size"] = 10                #y軸主目盛り線の長さ
plt.rcParams["xtick.minor.size"] = 5                 #x軸補助目盛り線の長さ
plt.rcParams["ytick.minor.size"] = 5                 #y軸補助目盛り線の長さ
plt.rcParams["font.size"] = 14                       #フォントの大きさ
plt.rcParams["axes.linewidth"] = 1.5                 #囲みの太さ

def func(x):
    return 1.5*x+2.0 + np.random.normal(0.0,0.1,len(x))

x = np.array([0.0,0.2,0.5,1.0,10.0])
y = func(x) #測定点
a,b = np.polyfit(x,y,1) #線形回帰

fig,axes = plt.subplots()
axes.scatter(x,y)
axes.plot(x,a*x+b)

output_3_1.png

線形回帰によって1本の直線を引くことができました。しかしながら、1から10の間に測定点はなく、この部分は直感的には不確かそうに思えます。
こうした不確かさを表せる良い回帰分析はないかと調べてみたら、ガウス過程回帰というものを見つけたので勉強のため実装して試したいと思います。

参考

数学に疎い私でも読むことができました。この投稿は2章と3章にあたります。
ガウス過程と機械学習

ガウス過程回帰ではハイパーパラメータが存在し、その推定にMCMC法を用いました。
emceeでマルコフ連鎖モンテカルロ法(MCMC)によるサンプリング

自分で実装しなくてもライブラリがあります。
Gpy vs scikit-learn: pythonでガウス過程回帰

ガウス過程回帰

1. ガウス過程

どんな自然数$N$についても、入力$x_1,x_2,...,x_N$に対応する出力のベクトル
$${\mathbf f}=(f(x_1),f(x_2),...,f(x_N))$$
が平均${\mathbf \mu}=(\mu(x_1),\mu(x_2),...,\mu(x_N))$, $K_{nn^{'}}=k(x_n,x_n^{'})$を要素とする行列$\rm K$を共分散行列とするガウス分布${\mathcal N}({\mathbf \mu},{\rm K})$に従うとき、$f$はガウス過程に従うといい、これを
$$f \sim {\rm GP}(\mu({\mathbf x}), k({\mathbf x},{\mathbf x^{'}}))$$
と書きます。

とは言えガウス過程の性質がよくわからないので、試しにカーネル関数として動径基底関数を用いてカーネル行列を計算して、そのカーネル行列に従うガウス分布からサンプリングしてみます。

def rbf(x,s=[1.0,1.0,1.0e-6]):
    """
    Radial Basis Function, RBF : 動径基底関数
    """
    s1,s2,s3 = s[0],s[1],s[2]
    
    X, X_ = np.meshgrid(x,x)
    K = s1**2 * np.exp(-((X-X_)/s2)**2) + np.identity(len(x)) * s3**2
    return K

N = 100
x = np.linspace(0.0,10.0,N) #入力
K = rbf(x) #カーネル行列

fig,axes = plt.subplots()
axes.imshow(K)

output_9_1.png

分散共分散行列が$\rm K$の多次元ガウス分布からサンプリングするためには、$\mathbf x$をランダムに生成し、$\mathbf{y} = \mathbf{Lx}$と変換します。$\mathbf L$は${\rm K}$をコレスキー分解することで得られます。

L = np.linalg.cholesky(K) #カーネル行列のコレスキー分解

fig,axes = plt.subplots()
axes.plot()
for i in range(10):
    _x = np.random.normal(0.0,1.0,N)
    axes.plot(x,np.dot(L,_x))

output_12_0.png

サイコロをふると数字がでるように、関数がランダムにでてきます。
また、似た値からは似た値が出てくるというのがガウス過程の直感的な性質だそうです。

2. ガウス過程回帰

$N$個の測定値があります。簡単のためyの平均値は0です。
$${\mathcal D} = {(x_1,y_1),(x_2,y_2),...,(x_N,y_N)}$$
このとき、$\mathbf x$と$y$の間に
$$y=f({\mathbf x})$$
の関係があり、この関数$f$が平均0のガウス過程
$$f \sim {\rm GP}(0,k(x,x^{'}))$$
から生成されているとします。${\mathbf y}=(y_1,y_2,...,y_N)^{T}$とおけば、この$\mathbf y$はガウス分布に従い、入力のすべてのペアについてカーネル関数$k$を用いて
$${\rm K}\ =k(x,x^{'})$$
で与えられるカーネル行列$\rm K$を使って、
$$y \sim {\mathcal N}(0,{\rm K})$$
が成り立ちます。

ここから、データに含まれない$x^{\ast}$での$y^{\ast}$の値の求め方を示します。
$\mathbf y$に$y^{\ast}$を加えたものを${\mathbf y}^{'}=(y_1,y_2,...,y_N,y^{\ast})^{T}$とします。$(x_1,x_2,...,x_N,x^{\ast})^{T}$から計算されるカーネル行列を$\rm K^{'}$をすれば、これら全体もガウス分布に従うので
$$y^{'} \sim {\mathcal N}(0,{\rm K}^{'})$$
となります。すなわち
$$\begin{pmatrix} {\mathbf y} \\ y^{\ast} \end{pmatrix} \sim {\mathcal N}
\begin{pmatrix}0, {\begin{pmatrix}{\rm K} & {k_{\ast}} \\ {k_{\ast}^T} & {k_{\ast \ast}}\end{pmatrix}}\end{pmatrix}$$
が成り立ちます。これは$y^{\ast}$と$\mathbf y$の同時分布の式なので、$\mathbf y$が与えられたときの$y^{\ast}$の条件付き確率は、ガウス分布の要素間の条件付き確率から求められます。その条件付き確率はこうなります。
$$p(y^{\ast}|x^{\ast},{\mathcal D}) = {\mathcal N}(k_{\ast}^T {\rm K}^{-1} {\mathbf y}, k_{\ast \ast}-k_{\ast}^T {\rm K}^{-1} k_{\ast})$$

上式に従ってコードを書きます。

def rbf(x_train,s,x_pred=None):
    """
    Radial Basis Function, RBF : 動径基底関数
    """
    s1,s2,s3 = s[0],s[1],s[2]

    if x_pred is None:
        x = x_train
        X, X_ = np.meshgrid(x,x)
        K = s1**2 * np.exp(-((X-X_)/s2)**2) + np.identity(len(x)) * s3**2 #カーネル行列
        return K
    else:
        x = np.append(x_train,x_pred)
        X, X_ = np.meshgrid(x,x)
        K_all = s1**2 * np.exp(-((X-X_)/s2)**2) + np.identity(len(x)) * s3**2 #カーネル行列
        K = K_all[:len(x_train),:len(x_train)]
        K_s = K_all[:len(x_train),len(x_train):]
        K_ss = K_all[len(x_train):,len(x_train):]
        return K,K_s,K_ss
    
def pred(x_train,y_train,x_pred,s):
    K,K_s,K_ss = rbf(x_train,s,x_pred=x_pred)
    K_inv = np.linalg.inv(K) #逆行列
    y_pred_mean = np.dot(np.dot(K_s.T,K_inv), y_train) #yの期待値
    y_pred_cov = K_ss - np.dot(np.dot(K_s.T,K_inv), K_s) #分散共分散行列
    y_pred_std = np.sqrt(np.diag(y_pred_cov)) #標準偏差
    return y_pred_mean,y_pred_std

適当な関数を作り、サンプリングして測定値とします。

def func(x):
    return np.sin(2.0*np.pi*0.2*x) + np.sin(2.0*np.pi*0.1*x)

pred_N = 100 #予測点の数
N = 10 #測定点の数
x_train = np.random.uniform(0.0,10.0,N) #測定点:x
y_train = func(x_train) + np.random.normal(0,0.1,1) #測定点:y

x_true = x_pred = np.linspace(-2.0,12.0,pred_N)
fig,axes = plt.subplots()
axes.plot(x_true, func(x_true), label="True")
axes.scatter(x_train,y_train, label="Measured")
axes.legend(loc="best")

output_19_1.png

ガウス過程回帰をしてみます。

s = [1.0,1.0,0.1]
y_pred_mean,y_pred_std = pred(x_train,y_train,x_pred,s)

fig,axes = plt.subplots(figsize=(8,6))
axes.set_title("s1=1.0, s2=1.0, s3=0.1")
axes.plot(x_true, func(x_true), label="True")
axes.scatter(x_train,y_train, label="Measured")
axes.plot(x_pred, y_pred_mean, label="Predict")
plt.fill_between(x_pred,y_pred_mean+y_pred_std,y_pred_mean-y_pred_std,facecolor="b",alpha=0.3)
axes.legend(loc="best")
axes.set_xlim(-2.0,12.0)

output_20_1.png

それっぽい回帰ができました。測定点数が少ない箇所では分布が広がっています。今、カーネル関数のハイパーパラメータは手で与えました。これを変えるとどうなるか見てみます。
output_22_1.png

結果が大きく変わってしまいました。ハイパーパラメータも推定したいところです。

ハイパーパラメータをまとめて$\theta$と置きます。カーネル行列は$\theta$に依存し、このとき、学習データの確率は

$$p({\mathbf y}|{\mathbf x},\theta) = \mathcal{N}({\mathbf y}|0,{\rm K})
= \frac{1}{(2\pi)^{N/2}} \frac{1}{|{\rm K}|^{1/2}} \exp(-\frac{1}{2} {\mathbf y}^T {\rm K}_{\theta}^{-1} {\mathbf y}) $$

対数をとれば
$$\ln{p({\mathbf y}|{\mathbf x},\theta)} \propto -\ln{|{\rm K}|}-{\mathbf y}^T {\rm K}_{\theta}^{-1} {\mathbf y} + const.$$
となり、上式を最大にする$\theta$を求めます。

勾配法では局所解にはまりやすいのでMCMC法を使用します。

import emcee
def objective(s): #目的関数
    K = rbf(x_train,s)
    return -np.linalg.slogdet(K)[1] - y_train.T.dot(np.linalg.inv(K)).dot(y_train) #エビデンス

ndim = 3
nwalker = 6
s0 = np.random.uniform(0.0,5.0,[nwalker,ndim]) #初期位置
sampler = emcee.EnsembleSampler(nwalker,ndim,objective) #サンプラーを作る
sampler.run_mcmc(s0,5000) #サンプリング開始

s_dist = sampler.flatchain #サンプリングでできた結果を取得
s = s_dist[sampler.flatlnprobability.argmax()]
y_pred_mean,y_pred_std = pred(x_train,y_train,x_pred,s)

fig,axes = plt.subplots(figsize=(8,6))
axes.plot(x_true, func(x_true), label="True")
axes.scatter(x_train,y_train, label="Measured")
axes.plot(x_pred, y_pred_mean, label="Predict")
plt.fill_between(x_pred,y_pred_mean+y_pred_std,y_pred_mean-y_pred_std,facecolor="b",alpha=0.3)
axes.legend(loc="best")
axes.set_xlim(-2.0,12.0)

output_26_2.png

いい感じになりました。

3. GPyライブラリを用いたガウス過程回帰

ガウス過程回帰の便利なライブラリがあります。カーネルの種類が豊富で、可視化も容易なのでおすすめです。

import GPy
import GPy.kern as gp_kern

kern = gp_kern.RBF(input_dim=1)
gpy_model = GPy.models.GPRegression(X=x_train.reshape(-1, 1), Y=y_train.reshape(-1, 1), kernel=kern, normalizer=None)

fig,axes = plt.subplots(figsize=(8,6))
axes.plot(x_true, func(x_true), c="k", label="True")
gpy_model.optimize()
gpy_model.plot(ax=axes)
axes.legend(loc="best")
axes.set_xlim(-2.0,12.0)

output_29_1.png

まとめ

ガウス過程回帰によって不確かさも含めた推定が行えました。今回の投稿では基礎理論部分のみしか考慮できませんでしたが、ガウス過程回帰は計算量が問題となり、解決方法ついて様々な議論がされています。また、カーネル関数は1種類だけではなく多種多様で、分析対象に合わせて適宜選択したり組み合わせたりする必要があります。

12
12
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
12
12

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?