1
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

ガウス過程に基づく回帰モデルの学習ノート

Last updated at Posted at 2019-04-06

はじめに

  • 日本語の書籍が入手できなかったので、以下の資料を参考にしました。
  • Gaussian Processes for Machine Learning (http://www.gaussianprocess.org/gpml/chapters/)
  • 「2 Regression」の再現コードとなります。

Weight-space View

入力マトリクスを$X$、出力ベクトルを$y$、パラメータを$w$とする。

y=X^tw+\epsilon,\quad \epsilon\sim N(0,\sigma_n^2)

テストデータ$x_*$における予測値$y_*$の分布を求める。
$w$の事前分布を$N(0,\Sigma_p)$とする。$w$の事後分布を求めて、予測モデルから積分消去する。

\begin{align}
p(y|X,w)&=\prod N(x_i^tw,\sigma_n^2)=N(X^tw,\sigma_n^2 I)\\
p(w|X,y)&\sim p(y|X,w) p(w)
\sim N(X^tw,\sigma_n^2 I)N(0,\Sigma_p)\\
&=N(\sigma_n^{-2}A^{-1}Xy,A^{-1})\\
A&=\sigma_n^{-2}XX^t+\Sigma_p^{-1}\\
p(y_*|x_*,X,y)&=\int p(y_*|x_*,w)p(w|X,y)dw\\
&=N(x_*^t\sigma_n^{-2}A^{-1}Xy,\sigma_n^2+x_*^tA^{-1}x_*)
\end{align}

パラメータの事前分布と尤度、事後分布、予測値の分布のグラフを作成。
予測値の分布では、予測値の「平均値」の分布を表示している。分散が$\sigma_n^2$だけ小さくなる。
$\Sigma_p=I_2,\sigma_n=1$とする。

par(mfrow=c(2,2),mar=c(4,4,2,2))

my_draw_par = function(FUN,main,w1,w2){
  levels = FUN(w1,w2) - c(1,4)/2
  xgrid = seq(from=-2,to=2,by=0.01)
  ygrid = seq(from=-2,to=2,by=0.01)
  value = outer(xgrid,ygrid,FUN)
  contour(xgrid,ygrid,value,main=main,levels=levels,drawlabels=FALSE)
  abline(v=c(-2,-1,0,1,2),h=c(-2,-1,0,1,2),lty=2)
  points(w1,w2,pch=4)
}

my_prior = function(w1,w2){
  -(w1^2+w2^2) / 2
}
my_draw_par(my_prior,"prior",0,0)

xp = c(-5, 2, 5)
yp = c(-5.5, 0.5, 4.5)
cp = rep(1,length(xp))
X = rbind(cp,xp)
sigma_n = 1
A = X %*% t(X) / sigma_n^2 + diag(2)
w_Sigma = solve(A)
(w_mean = w_Sigma %*% X %*% yp / sigma_n^2)
(w_mle = solve(X %*% t(X)) %*% X %*% yp)

plot(xp,yp,type="p",pch=4,xlim=c(-6,6),ylim=c(-6,6),main="forecast")
abline(v=c(-5,0,5),h=c(-5,0,5),lty=2)
abline(a=w_mean[1],b=w_mean[2])

my_forecast = function(x,score){
  xb = rbind(rep(1,length(x)),x)
  t(xb) %*% w_mean + sqrt(colSums(xb * (w_Sigma %*% xb))) * score
}
curve(my_forecast(x,2), lty=2, add=T)
curve(my_forecast(x,-2), lty=2, add=T)

my_likelihood = function(w1,w2){
  res = t(X) %*% rbind(w1,w2) - yp
  -colSums(res * res) / 2 / sigma_n^2 - log(2*pi*sigma_n^2) * length(yp) / 2
}
my_draw_par(my_likelihood,"likelihood",w_mle[1],w_mle[2])

my_posterior = function(w1,w2){
  my_likelihood(w1,w2) + my_prior(w1,w2)
}
my_draw_par(my_posterior,"posterior",w_mean[1],w_mean[2])

regression_1.png

Function-space View

共分散行列がカーネル関数で表現される場合、予測値$f_*$の条件付き分布は以下のようになる。

\begin{align}
\binom{y}{f_*}&=N\left(0,\begin{bmatrix}
K(X,X)+\sigma_n^2I&K(X,X_*)\\
K(X_*,X)&K(X_*,X_*)\end{bmatrix}\right)\\
E[f_*]&=K_*^t(K+\sigma_n^2I)^{-1}y\\
Var(f_*)&=K(X_*,X_*)-K_*^t(K+\sigma_n^2I)^{-1}K_*\\
\end{align}

入力$x$をある写像$\phi()$で変換してから線形回帰分析することと同値。
変換後の内積$\phi(a)^t\phi(b)$がカーネル関数$K(a,b)$となる。

カーネル関数$K(a,b)=\exp(-(a-b)^2/2)$、ノイズの分散$\sigma_n^2=0$とする。
事前分布と事後分布について、サンプリングと共分散行列のグラフを作成。

par(mfrow=c(2,2),mar=c(4,4,2,2))

set.seed(1)

my_kernel = function(x1,x2){
  outer(x1,x2,function(a,b){exp(-(a-b)^2/2)})
}

my_draw_fun = function(x,mu,Sigma,main){
  f = MASS::mvrnorm(n=3, mu=mu, Sigma=Sigma)
  ymin = min(f,-2)
  ymax = max(f,2)
  plot(x=x,y=f[1,],type="n",ylim=c(ymin,ymax),xlab="",ylab="y",main=main)
  w = sqrt(pmax(diag(Sigma),0)) * 2
  polygon(x=c(x,rev(x)),y=c(mu+w,rev(mu-w)),border=NA,col="grey80")
  cols = rainbow(nrow(f))
  for(i in seq(nrow(f))){
    lines(x=x,y=f[i,],col=cols[i])
  }
}

xgrid = seq(from=-5,to=5,by=0.2)

mu_prior = rep(0,length(xgrid))
Sigma_prior = my_kernel(xgrid,xgrid)

my_draw_fun(xgrid,mu_prior,Sigma_prior,"prior")

xp = c(-4,-3,-1,0,2)
yp = c(-2,0,1,2,-1)

sigma_n = 0
K11 = my_kernel(xp,xp) + diag(length(xp)) * sigma_n^2
K12 = my_kernel(xp,xgrid)
K22 = my_kernel(xgrid,xgrid)
K11_inv = solve(K11)

mu_posterior = t(K12) %*% K11_inv %*% yp
Sigma_posterior = K22 - t(K12) %*% K11_inv %*% K12

my_draw_fun(xgrid,mu_posterior,Sigma_posterior,"posterior")
points(xp,yp,pch=4)

ids = c(16,31,41)
xgrid[ids]

my_draw_cov = function(x,covar,main){
  ymin = min(covar)
  ymax = max(covar)
  plot(x=x,y=covar[1,],type="n",ylim=c(ymin,ymax),xlab="",ylab="covariance",main=main)
  cols = rainbow(nrow(covar))
  for(i in seq(nrow(covar))){
    lines(x=x,y=covar[i,],col=cols[i])
  }
  legend("topleft",legend=paste("x",x[ids],sep="="),lty=1,col=cols)
}

my_draw_cov(xgrid,Sigma_prior[ids,],"prior")
my_draw_cov(xgrid,Sigma_posterior[ids,],"posterior")

regression_2.png

Varying the Hyperparameters

カーネル関数$K(a,b)=\sigma_k^2\exp(-(a-b)^2/2/s^2)$とする。
ハイパーパラメータ$(s,\sigma_k,\sigma_n)$を変化させた場合のグラフを作成。
モデルの良さを周辺尤度で評価する。

\begin{align}
\log p(y|X)&=-\frac{1}{2}y^t(K+\sigma_n^2I)^{-1}y-\frac{1}{2}\log|K+\sigma_n^2I|-\frac{n}{2}\log2\pi
\end{align}
my_kernel = function(x1,x2,k_sigma,k_scale){
  outer(x1,x2,function(a,b){k_sigma^2*exp(-(a-b)^2/2/k_scale^2)})
}

my_draw_xy = function(x,mu,Sigma,main){
  plot(x=x,y=mu,type="n",ylim=c(-3,3),xlab="x",ylab="y",main=main)
  sd = sqrt(pmax(diag(Sigma),0)) * 2
  polygon(x=c(x,rev(x)),y=c(mu+sd,rev(mu-sd)),border=NA,col="grey80")
  lines(x=x,y=mu,col="blue")
}

xp = c(-7.3,-6.3,-6.3,-5.9,-4.8,-4,-3.7,-2.8,-2.2,-0.9,0.5,0.7,1,2.3,2.4,4.2,4.3,4.9,5.9,6.1)
yp = c(-1.8,-0.1,0,0.2,-0.8,-1.3,-1.2,0.3,1.4,1.6,-0.1,-0.8,-1.1,-2.7,-2.3,-1.3,-1.1,-1.6,-1.1,-0.9)

my_plot_xy = function(k_scale,k_sigma,sigma_n){
  xgrid = seq(from=-7.4,to=7.4,by=0.2)

  K11 = my_kernel(xp,xp,k_sigma,k_scale) + diag(length(xp)) * sigma_n^2
  K12 = my_kernel(xp,xgrid,k_sigma,k_scale)
  K22 = my_kernel(xgrid,xgrid,k_sigma,k_scale)
  K11_inv = solve(K11)

  evidence = t(yp) %*% K11_inv %*% yp + log(det(K11)) + log(2*pi) * length(yp)
  evidence = -drop(evidence) / 2
  print(evidence)

  mu_posterior = t(K12) %*% K11_inv %*% yp
  Sigma_posterior = K22 - t(K12) %*% K11_inv %*% K12

  main = paste(sprintf("%.3g",c(k_scale,k_sigma,sigma_n)),collapse=",")
  my_draw_xy(xgrid,mu_posterior,Sigma_posterior,main=main)
  legend("topright",legend=sprintf("evidence=%.3g",evidence),bty="n")
  points(xp,yp,pch=4)
}

par(mfrow=c(2,2),mar=c(4,4,2,2))

my_plot_xy(1,1,0.1)
my_plot_xy(0.3,1.08,0.00005)
my_plot_xy(3,1.16,0.89)

regression_3.png

Marginal likelihood Maximization

周辺尤度を最大化するハイパーパラメータ$(s,\sigma_k,\sigma_n)$を求める。

par(mfrow=c(2,2),mar=c(4,4,2,2))

my_evidence = function(k_scale,k_sigma,sigma_n){
  K11 = my_kernel(xp,xp,k_sigma,k_scale) + diag(length(xp)) * sigma_n^2
  K11_inv = solve(K11)
  evidence = t(yp) %*% K11_inv %*% yp + log(det(K11)) + log(2*pi) * length(yp)
  -drop(evidence) / 2
}

my_evidence(1,1,0.1)
my_evidence(0.3,1.08,0.00005)
my_evidence(3,1.16,0.89)

my_evidence_optim = function(par){
  my_evidence(exp(par[1]),exp(par[2]),exp(par[3]))
}
rt = optim(c(0,0,0),my_evidence_optim,control=list(fnscale=-1))
rt
(par = exp(rt$par))
my_plot_xy(par[1],par[2],par[3])

my_plot_param = function(i,xlab){
  param = par[i] * seq(from=0.7,to=1.3,by=0.01)
  value = sapply(param,function(p){
    if(i==1){
      my_evidence(p,par[2],par[3])
    }else if(i==2){
      my_evidence(par[1],p,par[3])
    }else{
      my_evidence(par[1],par[2],p)
    }
  })
  plot(param,value,type="l",xlab=xlab,ylab="evidence")
  abline(v=par[i],h=rt$value,lty=2)
  legend(x=par[i],y=mean(value),legend=sprintf("%.3g",par[i]),bty="n")
}
my_plot_param(1,"k_scale")
my_plot_param(2,"k_sigma")
my_plot_param(3,"sigma_n")

regression_5.png

Smoothing, Weight Functions and Equivalent Kernels

ガウス過程をスムージングと考えた場合の重み係数のグラフを作成。

\begin{align}
\bar{f}_*&=K_*^t(K+\sigma_n^2I)^{-1}y\\
\end{align}

詳細なグリッドを仮想的に設定することで重み係数の滑らかなグラフを作成する。
トレーニングデータの数を$n$、詳細なグリッドにおけるデータの数を$n_{\mathrm{grid}}$とする。
ノイズの分散$\sigma_n^2$の大きさを以下の式に従って調整する。

\begin{align}
\sigma_{\mathrm{grid}}^2&=\sigma_n^2\frac{n_{\mathrm{grid}}}{n}\\
\bar{f}_*&=K_*^t(K+\sigma_{\mathrm{grid}}^2I)^{-1}y_{\mathrm{grid}}\\
\end{align}
par(mfrow=c(2,2),mar=c(4,4,2,2))

my_weight = function(xt,xa,k_scale,k_sigma,sigma_n,num){
  K11 = my_kernel(xt,xt,k_sigma,k_scale) + diag(length(xt)) * sigma_n^2 * (length(xt) / num)
  K12 = my_kernel(xt,xa,k_sigma,k_scale)
  K11_inv = solve(K11)
  t(K12) %*% K11_inv
}

my_draw_kernel = function(xp,xa,k_scale,k_sigma,sigma_n){
  weight = my_weight(xp,xa,k_scale,k_sigma,sigma_n,length(xp))
  xgrid = seq(from=0,to=1,by=0.005)
  weight_equ = my_weight(xgrid,xa,k_scale,k_sigma,sigma_n,length(xp))
  weight_equ = weight_equ / max(weight_equ) * max(weight)

  kernel = my_kernel(xgrid,xa,k_sigma,k_scale)
  kernel = kernel / max(kernel) * max(weight)

  ymin = min(weight,weight_equ,kernel)
  ymax = max(weight,weight_equ,kernel)
  main = paste(sprintf("%.3g",c(k_scale,k_sigma,sigma_n)),collapse=",")
  plot(x=xp,y=weight,type="p",pch=20,ylim=c(ymin,ymax),xlab="",ylab="",main=main)
  lines(xgrid,weight_equ,type="l",lty=1)
  lines(xgrid,kernel,type="l",lty=2)
  legend("topright",legend=c("t-points","equ_k","org_k"),pch=c(20,NA,NA),lty=c(0,1,2))
}

set.seed(1)
xp = runif(50)

k_scale = sqrt(0.004)
k_sigma = 1
sigma_n = sqrt(0.1)
my_draw_kernel(xp,0.5,k_scale,k_sigma,sigma_n)
my_draw_kernel(xp,0.05,k_scale,k_sigma,sigma_n)
sigma_n = sqrt(10)
my_draw_kernel(xp,0.5,k_scale,k_sigma,sigma_n)

my_draw_kernel2 = function(xa,k_scale,k_sigma,sigma_n){
  xgrid = seq(from=0,to=1,by=0.01)
  weight_a = my_weight(xgrid,xa,k_scale,k_sigma,sigma_n,10)
  weight_b = my_weight(xgrid,xa,k_scale,k_sigma,sigma_n,250)

  ymin = min(weight_a,weight_b)
  ymax = max(weight_a,weight_b)
  main = paste(sprintf("%.3g",c(k_scale,k_sigma,sigma_n)),collapse=",")
  plot(x=xgrid,y=weight_a,type="l",ylim=c(ymin,ymax),lty=2,xlab="",ylab="",main=main)
  lines(x=xgrid,y=weight_b,type="l",lty=1)
  legend("topright",legend=c(10,250),lty=c(2,1))
}
k_scale = sqrt(0.004)
k_sigma = 1
sigma_n = sqrt(0.1)
my_draw_kernel2(0.5,k_scale,k_sigma,sigma_n)

regression_4.png

グラフ上の点は、50個の一様乱数をトレーニングデータとする重み係数。
実線は、トレーニングデータの数を増やして一様に分布させた場合の重み係数。
形状が一致するようにノイズの分散$\sigma_n^2$の大きさを調整している。
破線は、元々のカーネル関数の値。
最大値が一致するようにグラフの高さを調整している。
右下のグラフは、トレーニングデータの数の違いによる重み係数の比較。

1
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?