Qiita Teams that are logged in
You are not logged in to any team

Log in to Qiita Team
Community
OrganizationEventAdvent CalendarQiitadon (β)
Service
Qiita JobsQiita ZineQiita Blog
4
Help us understand the problem. What are the problem?

More than 3 years have passed since last update.

posted at

updated at

{rstan} Rstanでガウス過程の実装

一次元のN個の入力$x$と対応する出力$y$を想定する。
出力が未知の入力$x_2$を考え、これに対応する$y_2$を推定する。

承前:ガウス過程の考え方

参考:
ガウス過程シリーズ1 概要
ガウス過程シリーズ2 高速化&フルベイズ
『StanとRでベイズ統計モデリング』
『Gaussian Processes for Machine Learning』
『パターン認識と機械学習 (下)』

0. sample data

まずはRでサンプルデータを作ろう。

y \sim N(sin(x),\ ε)

で、いっか。

set.seed(123)
N <- 20
x <- runif(N, -5, 5)
y <- sin(x) + rnorm(N, 0, 0.01)

x2も準備しておく。

N2 <- 75
x2 <- seq(-5, 5, length= N2)

1. stan code

ではstan。

以下、ブロックごとに順番に書いていけば良い。

1-1. data

上3行は既知のxとそれに対応するyの入力。
下2行は、予測したい定義域x2を格納。

data{
 int<lower = 1> N;
 vector[N] x;
 vector[N] y; 

 int<lower = 1> N2;
 vector[N2] x2;
}

1-2. transformed data

y\sim N_n(\mu,\ cov)\tag{1}

としたい(コード中で $\mu$ はmu)。
データyを事前に標準化する事で、$\mu=0$と仮定する。
$\mu$の長さは当然$y$と等しいのでN。

transformed data{
 vector[N] mu;
 for(i in 1:N)
   mu[i] = 0;
}

1-3. parameters

式(1)の$cov$を推定したい。
kernel関数を用い、

cov[i, j] = k(x_i,x_j)
=\eta^2exp\big(-\rho^2(x_i-x_j)^2\big)+σ^2δ_{ij} \tag{2}

($\delta_{ij}$はクロネッカのデルタ)

従って、推定すべきハイパーパラメータは3つ。
$\eta^2$(eta_sq), $\rho^2$(rho_sq), $\sigma^2$(sigma_sq)。

parameters{
 real<lower = 0> eta_sq;
 real<lower = 0> rho_sq;
 real<lower = 0> sigma_sq;
}

1-4. transformed parameters

式(2)をそのまま書く。
$cov$はCov

transformed parameters{
 matrix[N, N] Cov;

 for(i in 1:(N-1)){
   for(j in (i+1):N){
     Cov[i, j] = eta_sq * exp(- rho_sq * pow(x[i] - x[j], 2));
     Cov[j, i] = Cov[i, j];
   }
 }

 for(k in 1:N)
   Cov[k, k] = sigma_sq;
}

1-5. model

モデルは式(1)そのもの。
$Cauchy$分布を使って裾野広い事前分布を与えておく。

model{
 y ~ multi_normal(mu, Cov);

 eta_sq ~ cauchy(0, 5);
 rho_sq ~ cauchy(0, 5);
 sigma_sq ~ cauchy(0, 5); 
}

1-6. generated quantities

で、$y$の推定結果に基づいて、いよいよ、$y_2$の推定。
式(1)で与えられる$y$の分布による条件付き多変量正規分布を推定する。
$\mu_2$(mu2)と$cov_2$(Cov2)が分かれば、$y_2$(y2)が推定できる。

y_2 \sim N_{n_2}(\mu_2,\ cov_2)\tag{3}

$y$と$y_2$が同一の多変量正規分布から派生したとして、

\begin{bmatrix}y \\y_2 \end{bmatrix} \sim N_{n+n_2}(\hat{\mu},\ \hat{cov})\tag{4}

標準化を前提として、$\hat{\mu}=0$。

\hat{cov} = \begin{bmatrix}cov & K \\ K^t & \Sigma\end{bmatrix} \tag{5}

この時、$K$(K)は、

K(x_{i}, x_{2j})=\eta^2exp\big(-\rho^2(x_{i}-x_{2j})^2\big) \tag{6}

$x$と$x_2$の長さが違う事に注意。(Kは正方行列では無い。)

また、$\Sigma$(Sigma)は、

\Sigma(x_{2i}, x_{2j})=\eta^2exp\big(-\rho^2(x_{2i}-x_{2j})^2\big)+σ^2δ_{ij} \tag{7}

求めたい$\mu_2$(mu2)と$cov_2$(Cov2)は、条件付き多変量正規分布の性質から、

\begin{eqnarray}
\mu_2&=& 0+K^tcov^{-1}(y-\mu)=K^tcov^{-1}y\\
cov_2&=&\Sigma-K^t cov^{-1} \tag{8}
\end{eqnarray}

$\mu_2$の初項のゼロは何かというと、$\hat{\mu}=(\mu, \hat{\mu_2})$の$\hat{\mu_2}=0$。
記号がややこしいのであえて書かなかった。詳細は前回記事

これを実装。
stanでは、$K^t$は、K'と書くことに注意して、書き下す。

generated quantities{
  vector[N2] y2;
  vector[N2] mu2;
  matrix[N2, N2] Cov2;
  matrix[N, N2] K;
  matrix[N2, N2] Sigma;
  matrix[N2, N] K_t_Cov;

// 式(6)
  for (i in 1:N)
    for (j in 1:N2)
      K[i, j] = eta_sq * exp(-rho_sq * pow(x[i] - x2[j],2));

// 式(7)
  for(i in 1:(N2-1)){
    for(j in (i+1):N2){
      Sigma[i, j] = eta_sq * exp(-rho_sq * pow(x2[i] - x2[j], 2));
      Sigma[j, i] = Sigma[i,j];
    }
  }

  for(k in 1:N2)
    Sigma[k, k] = sigma_sq;

// 式(8)-1
  K_t_Cov = K' / Cov;  
  mu2 = K_t_Cov * y;

// 式(8)-2
  Cov2 = Sigma - K_t_Cov * K;

  for(i in 1:N2)
    for(j in (i+1):N2)
      Cov2[i, j] = Cov2[j, i];

// 式(3)
  y2 = multi_normal_rng(mu2, Cov2);
}

※ 最後の式(3)の実装部分。
stanコード中の$\sim$は、sampling statement formという。
内部では対数確率の足し上げをしている。

今回は乱数を生成したいので、multi_normal_rngを使う。
multi_normal_rngをコレスキー分解を使って回避して、normal_rngのみで計算する事も可能です。
cf. 続ガウス過程実装:コレスキー分解を使った表記

2. R code

これは、{rstan}経験者であれば特に問題ないでしょう。
最初のセクションで用意したデータをlistにしておいて、fit関数に保存した.stanファイル名を指定するだけです。

library("rstan")
rstan_options(auto_write=TRUE)
options(mc.cores=parallel::detectCores())

dat <- list(N = N, x = x, y = y, N2 = N2, x2 = x2)
fit <- stan(file = "gp.stan", data = dat, seed =123,
            chain = 3, iter = 2000, warmup = 500)

2-3行は、並列計算のおまじない。
rstan::stanのオプションはさっさと計算するための指定。
手元の環境で、上の設定だと1chainあたり200-700秒ぐらい(遅)。
色々警告が出ますが、収束しました。

可視化は、まぁ、こんな感じでしょうか。

A <- extract(fit)

y2_med <- apply(A$y2, 2, median)
y2_max <- apply(A$y2, 2, quantile, probs = 0.05)
y2_min <- apply(A$y2, 2, quantile, probs = 0.95)

dat_g <- data.frame(x2, y2_med, y2_max, y2_min)
dat_g2 <- data.frame(x, y)

ggplot(dat_g, aes(x2, y2_med))+
  theme_classic()+
  geom_ribbon(aes(ymax = y2_max, ymin = y2_min), alpha = 0.2)+
  geom_line()+
  geom_point(data = dat_g2, aes(x, y))+
  xlab("x") + ylab("y")

環境

sessionInfo()

> R version 3.4.0 (2017-04-21)
> Platform: x86_64-apple-darwin15.6.0 (64-bit)
> Running under: macOS Sierra 10.12.4
> 
> attached base packages:
> [1] stats     graphics  grDevices utils     datasets  methods   base     
> 
> other attached packages:
> [1] rstan_2.15.1         StanHeaders_2.15.0-1 ggplot2_2.2.1       
> 
> colorspace_1.3-2
> [17] gridExtra_2.2.1  tibble_1.3.1
Why not register and get more from Qiita?
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away
4
Help us understand the problem. What are the problem?