search
LoginSignup
4

More than 5 years have passed since last update.

posted at

updated at

{rstan} Rstanでガウス過程の実装

一次元のN個の入力$x$と対応する出力$y$を想定する。
出力が未知の入力$x_2$を考え、これに対応する$y_2$を推定する。

承前:ガウス過程の考え方

参考:
ガウス過程シリーズ1 概要
ガウス過程シリーズ2 高速化&フルベイズ
『StanとRでベイズ統計モデリング』
『Gaussian Processes for Machine Learning』
『パターン認識と機械学習 (下)』

0. sample data

まずはRでサンプルデータを作ろう。

y \sim N(sin(x),\ ε)

で、いっか。

set.seed(123)
N <- 20
x <- runif(N, -5, 5)
y <- sin(x) + rnorm(N, 0, 0.01)

x2も準備しておく。

N2 <- 75
x2 <- seq(-5, 5, length= N2)

1. stan code

ではstan。

以下、ブロックごとに順番に書いていけば良い。

1-1. data

上3行は既知のxとそれに対応するyの入力。
下2行は、予測したい定義域x2を格納。

data{
 int<lower = 1> N;
 vector[N] x;
 vector[N] y; 

 int<lower = 1> N2;
 vector[N2] x2;
}

1-2. transformed data

y\sim N_n(\mu,\ cov)\tag{1}

としたい(コード中で $\mu$ はmu)。
データyを事前に標準化する事で、$\mu=0$と仮定する。
$\mu$の長さは当然$y$と等しいのでN。

transformed data{
 vector[N] mu;
 for(i in 1:N)
   mu[i] = 0;
}

1-3. parameters

式(1)の$cov$を推定したい。
kernel関数を用い、

cov[i, j] = k(x_i,x_j)
=\eta^2exp\big(-\rho^2(x_i-x_j)^2\big)+σ^2δ_{ij} \tag{2}

($\delta_{ij}$はクロネッカのデルタ)

従って、推定すべきハイパーパラメータは3つ。
$\eta^2$(eta_sq), $\rho^2$(rho_sq), $\sigma^2$(sigma_sq)。

parameters{
 real<lower = 0> eta_sq;
 real<lower = 0> rho_sq;
 real<lower = 0> sigma_sq;
}

1-4. transformed parameters

式(2)をそのまま書く。
$cov$はCov

transformed parameters{
 matrix[N, N] Cov;

 for(i in 1:(N-1)){
   for(j in (i+1):N){
     Cov[i, j] = eta_sq * exp(- rho_sq * pow(x[i] - x[j], 2));
     Cov[j, i] = Cov[i, j];
   }
 }

 for(k in 1:N)
   Cov[k, k] = sigma_sq;
}

1-5. model

モデルは式(1)そのもの。
$Cauchy$分布を使って裾野広い事前分布を与えておく。

model{
 y ~ multi_normal(mu, Cov);

 eta_sq ~ cauchy(0, 5);
 rho_sq ~ cauchy(0, 5);
 sigma_sq ~ cauchy(0, 5); 
}

1-6. generated quantities

で、$y$の推定結果に基づいて、いよいよ、$y_2$の推定。
式(1)で与えられる$y$の分布による条件付き多変量正規分布を推定する。
$\mu_2$(mu2)と$cov_2$(Cov2)が分かれば、$y_2$(y2)が推定できる。

y_2 \sim N_{n_2}(\mu_2,\ cov_2)\tag{3}

$y$と$y_2$が同一の多変量正規分布から派生したとして、

\begin{bmatrix}y \\y_2 \end{bmatrix} \sim N_{n+n_2}(\hat{\mu},\ \hat{cov})\tag{4}

標準化を前提として、$\hat{\mu}=0$。

\hat{cov} = \begin{bmatrix}cov & K \\ K^t & \Sigma\end{bmatrix} \tag{5}

この時、$K$(K)は、

K(x_{i}, x_{2j})=\eta^2exp\big(-\rho^2(x_{i}-x_{2j})^2\big) \tag{6}

$x$と$x_2$の長さが違う事に注意。(Kは正方行列では無い。)

また、$\Sigma$(Sigma)は、

\Sigma(x_{2i}, x_{2j})=\eta^2exp\big(-\rho^2(x_{2i}-x_{2j})^2\big)+σ^2δ_{ij} \tag{7}

求めたい$\mu_2$(mu2)と$cov_2$(Cov2)は、条件付き多変量正規分布の性質から、

\begin{eqnarray}
\mu_2&=& 0+K^tcov^{-1}(y-\mu)=K^tcov^{-1}y\\
cov_2&=&\Sigma-K^t cov^{-1} \tag{8}
\end{eqnarray}

$\mu_2$の初項のゼロは何かというと、$\hat{\mu}=(\mu, \hat{\mu_2})$の$\hat{\mu_2}=0$。
記号がややこしいのであえて書かなかった。詳細は前回記事

これを実装。
stanでは、$K^t$は、K'と書くことに注意して、書き下す。

generated quantities{
  vector[N2] y2;
  vector[N2] mu2;
  matrix[N2, N2] Cov2;
  matrix[N, N2] K;
  matrix[N2, N2] Sigma;
  matrix[N2, N] K_t_Cov;

// 式(6)
  for (i in 1:N)
    for (j in 1:N2)
      K[i, j] = eta_sq * exp(-rho_sq * pow(x[i] - x2[j],2));

// 式(7)
  for(i in 1:(N2-1)){
    for(j in (i+1):N2){
      Sigma[i, j] = eta_sq * exp(-rho_sq * pow(x2[i] - x2[j], 2));
      Sigma[j, i] = Sigma[i,j];
    }
  }

  for(k in 1:N2)
    Sigma[k, k] = sigma_sq;

// 式(8)-1
  K_t_Cov = K' / Cov;  
  mu2 = K_t_Cov * y;

// 式(8)-2
  Cov2 = Sigma - K_t_Cov * K;

  for(i in 1:N2)
    for(j in (i+1):N2)
      Cov2[i, j] = Cov2[j, i];

// 式(3)
  y2 = multi_normal_rng(mu2, Cov2);
}

※ 最後の式(3)の実装部分。
stanコード中の$\sim$は、sampling statement formという。
内部では対数確率の足し上げをしている。

今回は乱数を生成したいので、multi_normal_rngを使う。
multi_normal_rngをコレスキー分解を使って回避して、normal_rngのみで計算する事も可能です。
cf. 続ガウス過程実装:コレスキー分解を使った表記

2. R code

これは、{rstan}経験者であれば特に問題ないでしょう。
最初のセクションで用意したデータをlistにしておいて、fit関数に保存した.stanファイル名を指定するだけです。

library("rstan")
rstan_options(auto_write=TRUE)
options(mc.cores=parallel::detectCores())

dat <- list(N = N, x = x, y = y, N2 = N2, x2 = x2)
fit <- stan(file = "gp.stan", data = dat, seed =123,
            chain = 3, iter = 2000, warmup = 500)

2-3行は、並列計算のおまじない。
rstan::stanのオプションはさっさと計算するための指定。
手元の環境で、上の設定だと1chainあたり200-700秒ぐらい(遅)。
色々警告が出ますが、収束しました。

可視化は、まぁ、こんな感じでしょうか。

A <- extract(fit)

y2_med <- apply(A$y2, 2, median)
y2_max <- apply(A$y2, 2, quantile, probs = 0.05)
y2_min <- apply(A$y2, 2, quantile, probs = 0.95)

dat_g <- data.frame(x2, y2_med, y2_max, y2_min)
dat_g2 <- data.frame(x, y)

ggplot(dat_g, aes(x2, y2_med))+
  theme_classic()+
  geom_ribbon(aes(ymax = y2_max, ymin = y2_min), alpha = 0.2)+
  geom_line()+
  geom_point(data = dat_g2, aes(x, y))+
  xlab("x") + ylab("y")

環境

sessionInfo()

> R version 3.4.0 (2017-04-21)
> Platform: x86_64-apple-darwin15.6.0 (64-bit)
> Running under: macOS Sierra 10.12.4
> 
> attached base packages:
> [1] stats     graphics  grDevices utils     datasets  methods   base     
> 
> other attached packages:
> [1] rstan_2.15.1         StanHeaders_2.15.0-1 ggplot2_2.2.1       
> 
> colorspace_1.3-2
> [17] gridExtra_2.2.1  tibble_1.3.1

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
What you can do with signing up
4