2
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

Rでガウス過程回帰の基礎

Last updated at Posted at 2020-11-18

記事の目的

この記事では、ガウス過程でいろいろなカーネルを使ってサンプリングします。
また、予測分布も可視化してみます。
参考: ガウス過程と機械学習 (機械学習プロフェッショナルシリーズ)

目次

No. 目次
1 いろいろなカーネルの計算
2 ガウス過程の可視化
3 ガウス過程回帰の予測分布
4 予測分布の可視化

1. いろいろなカーネルの計算

IMG_0202.jpeg

#データ
X <- seq(-4,4, 0.1)
N <- length(X)
#ガウスカーネル計算
K1 <- matrix(NA, nrow=N, ncol=N)
K2 <- matrix(NA, nrow=N, ncol=N)
K3 <- matrix(NA, nrow=N, ncol=N)
K4 <- matrix(NA, nrow=N, ncol=N)
theta1.1 <- 1
theta1.2 <- 1
theta3 <- 1
theta4.1 <- 1
theta4.2 <- 0.5
for(i in 1:N){
  for(j in 1:N){
    #ガウスカーネル計算
    K1[i,j] <- exp(-(X[i]-X[j])^2/theta1.2)/theta1.1
    #線形カーネル計算 (1次元ならfor文必要なし)
    K2 <- X%*%t(X)
    #指数カーネル計算
    K3[i,j] <- exp(-abs(X[i]-X[j])/theta3)
    #周期カーネル計算
    K4[i,j] <- exp(theta4.1*cos(abs(X[i]-X[j])/theta4.2))
  }
}

2. ガウス過程の可視化

library(mvtnorm)
library(ggplot2)
library(gridExtra)
set.seed(10)
kernel.plot <- function(K, name){
  y1 <- rmvnorm(1, rep(0,N), K)
  y2 <- rmvnorm(1, rep(0,N), K)
  y3 <- rmvnorm(1, rep(0,N), K)
  y4 <- rmvnorm(1, rep(0,N), K)
  ggplot(NULL,aes(x=X)) +
    geom_line(aes(y=y1)) + geom_line(aes(y=y2)) + 
    geom_line(aes(y=y3)) + geom_line(aes(y=y4)) + 
    ylim(-6, 6) + labs(x="x",y="y", title=name)
}
p1 <- kernel.plot(K1, "ガウスカーネル")
p2 <- kernel.plot(K2, "線形カーネル")
p3 <- kernel.plot(K3, "指数カーネル")
p4 <- kernel.plot(K4, "周期関数")
grid.arrange(p1, p2, p3, p4)

image.png

3. ガウス過程回帰の予測分布

IMG_0201.jpeg

#データ
set.seed(100)
N <- 10
M <- 101
X <- append(rnorm(N, 0, 5), seq(-5,5,0.1))
#カーネルの計算
K <- matrix(NA, nrow=N+M, ncol=N+M)
theta1 <- 1
theta2 <- 1
for(i in 1:(N+M)){
  for(j in 1:(N+M)){
    #ガウスカーネル
    K[i,j] <- exp(-(X[i]-X[j])^2/theta2)/theta1
  }
}
K <- K + 0.1*diag(nrow(K))
y <- rmvnorm(1, rep(0,N+M), K)
#予測分布の計算
mu <- K[(N+1):(N+M),1:(N)] %*% solve(K[1:N,1:N]) %*% y[1:N]
var <- K[(N+1):(N+M), (N+1):(N+M)] - 
  K[(N+1):(N+M),1:(N)] %*% solve(K[1:N,1:N]) %*% K[1:N, (N+1):(N+M)]

4. 予測分布の可視化

sigma.2 <- diag(var)
ggplot(NULL,aes(x=X[(1+N):(N+M)])) +
  geom_line(aes(y=mu), size=1.2, color="blue") + 
  geom_ribbon(aes(ymin=mu+2*sigma.2,ymax=mu-2*sigma.2), alpha=0.1, fill="blue") + 
  geom_point(aes(x=X[1:N],y=y[1:N]), color="blue", shape=1, size=5) +
  xlim(-5, 5)+ ylim(-2.5,2.5) +
  labs(x="x",y="y", title="ガウス過程回帰の予測分布")

image.png

2
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?