1
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

RとStanで実装するガウス過程回帰モデル

Posted at

はじめに

こんにちは。事業会社で働いているカレー好きなデータサイエンティストです。

今回は、マーケティングデータによる売上予測をガウス過程回帰で行います。ガウス過程の強みは、入力$x$が高次元であっても柔軟なモデリングが可能となるところです。

データが複雑な変動を見せている、かつデータ数が大きくなければ、ガウス過程によるモデリングを試してみるのもよいと思います。

ガウス過程回帰モデル

学習データの組$\lbrace x_i, y_i \rbrace _{i=1}^N$について、

y_i = f(x_i) + \epsilon_i, \quad \epsilon_i \sim N(0, \sigma ^2)

なる関係が成立しているとする。

いま、$y = (y_1, \cdots, y_N)$の確率分布は

p(y|x)=N(f, \sigma ^2 I)  

である(ただし、$f=f(x_1), \cdots, f(x_N)$)。

ガウス過程を$f\sim N(\mu, K)$で表すとすると、一般に

p(y|x) =N(\mu, K + \sigma ^2 I).

ただし、カーネル関数$k'(x,x')$は

k'(x,x') = k(x,x')+\sigma ^2 I

で、$k(x,x')$は

k(x,x')=\tau^2\cdot \exp(-\frac{\lVert x - x'\rVert^2}{2 h^2}) \quad \cdots (*)

を満たす。

以上の設定のもとで、予測したい点$x^* = (x_1^*, \cdots, x_M ^ *)'$に対応する出力$y ^ * = (y _ 1 ^ *, \cdots, y _M ^ *)'$を予測するために、$(N+M)\times (N+M)$のカーネル行列$K'$を考える。

ここで、

\left(
\begin{array}{c} 
y \\\ y* 
\end{array} 
\right) \sim N(
\left(
\begin{array}{c} 
0 \\\ 0
\end{array} 
\right), 
\left(
\begin{array}{c} 
K & k_* \\\ k_*' & k_{**}
\end{array} 
\right)) 

であり、

$k_{*} = (k'(x ^ *, x_1), k'(x ^ *, x_2), \cdots, k'(x ^ *, x_N))'$
$k_{**} = k(x ^ *, x ^ *)$

を満たす。

一般に、

\left(
\begin{array}{c} 
y_1 \\\ y_2 
\end{array} 
\right) \sim N(
\left(
\begin{array}{c} 
\mu_1 \\\ \mu_2
\end{array} 
\right), 
\left(
\begin{array}{c} 
\Sigma_{11} & \Sigma_{12} \\\ \Sigma_{21} & \Sigma_{22}
\end{array} 
\right)) 

のとき、ベクトルの一部$y_1$が与えられたときの残りの$y_2$は

p(y_2|y_1) = N(\mu_2 + \Sigma_{21}\Sigma_{11}^{-1}(y_1 - \mu_1), \Sigma_{22}-\Sigma_{11}^{-1}\Sigma_{12})

となることから、$\mu_1 = 0, \mu_2 = 0$より

p(y_2|y_1) = N(\mu_2 + \Sigma_{21}\Sigma_{11}^{-1}y_1, \Sigma_{22}-\Sigma_{11}^{-1}\Sigma_{12})

したがって、ガウス過程の予測分布は以下のようになる。

p(y^*|x^*,\mathfrak D) = N(k_*'K^{-1}y, k_{**}'K^{-1}k_*)

ただし、$\mathfrak D = \lbrace x_i, y_i \rbrace _{i=1}^N.$

RとStanで実装

ここでは70週にわたる架空の売上データを扱います。2種類のメディアとサービス(プロダクト)の検索数を共変量として、最後の5週について予測することを考えます。

set.seed(123) 

data <- tibble::tibble(
  amount = runif(70, min=1000, max=5000),
  media1 = runif(70, min=0, max=500),
  media2 = runif(70, min=0, max=500),
  search = round(runif(70, min=0, max=1000))
) |>
  dplyr::mutate(month = seq(from = as.Date("2024-01-01"),
                            length.out = 70,
                            by = "week")) |>
  dplyr::mutate(dplyr::across(c(amount, media1, media2, search), ~ mean(.), .names = "{.col}_mean")) |>
  dplyr::mutate(dplyr::across(c(amount, media1, media2, search), ~ sd(.), .names = "{.col}_sd")) |>
  dplyr::mutate(dplyr::across(c(amount, media1, media2, search), ~ (. - mean(.))/sd(.), .names = "{.col}_normalized")) 

売上の変動を可視化します。

ggplot2::ggplot(data) +
  ggplot2::geom_line(ggplot2::aes(x = month, y = amount), color = "red") +
  ggplot2::scale_x_date(date_breaks = "2 month", date_labels = "%Y-%m") +
  ggplot2::theme_bw() +
  ggplot2::theme(axis.text.x = ggplot2::element_text(angle = 45, hjust = 1))

売上.png

続いてStanファイルの実装。
$(*)$内の$\tau$はmagnitude, $h$はlength_scaleとしています。

functions {
  vector gp_pred_rng(vector[] x_pred,
                     vector y, 
                     vector[] x,
                     real magnitude, 
                     real length_scale) {

  int N1 = rows(y);
  int N2 = size(x_pred);
  vector[N2] y_pred;
  
  {
  matrix[N1, N1] K = cov_exp_quad(x, magnitude, length_scale);
  matrix[N1, N1] L_K = cholesky_decompose(K);
  vector[N1] L_K_div_y = mdivide_left_tri_low(L_K, y);
  vector[N1] K_div_y = mdivide_right_tri_low(L_K_div_y', L_K)';
  matrix[N1, N2] k_x_x_pred = cov_exp_quad(x, x_pred, magnitude, length_scale);
  vector[N2] y_pred_mu = (k_x_x_pred' * K_div_y);
  matrix[N1, N2] v_pred = mdivide_left_tri_low(L_K, k_x_x_pred);
  matrix[N2, N2] cov_y_pred = cov_exp_quad(x_pred, magnitude, length_scale) - v_pred' * v_pred + diag_matrix(rep_vector(1e-6, N2));

y_pred = multi_normal_rng(y_pred_mu, cov_y_pred);
  }
return y_pred;
 }
}

data {
  int<lower=1> N;
  int<lower=1> D;
  vector[D] x[N];
  vector[N] y;
  int<lower=1> N_pred;
  vector[D] x_pred[N_pred];
}

transformed data {
  vector[N] mu;
  mu = rep_vector(0, N);
}

parameters {
  real<lower=0> magnitude;
  real<lower=0> length_scale;
  real<lower=0> eta;
}

transformed parameters {
  matrix[N, N] L_K;
  {
  matrix[N, N] K;
  K = gp_exp_quad_cov(x, magnitude, length_scale);
  K = add_diag(K, square(eta));
  L_K = cholesky_decompose(K);
  }
}

model {
  magnitude ~ normal(0, 2);
  length_scale ~ inv_gamma(5, 5);
  y ~ multi_normal_cholesky(mu, L_K);
}

generated quantities {
  vector[N_pred] f_pred = gp_pred_rng(x_pred, y, x, magnitude, length_scale);
  vector[N_pred] y_pred;
  for (n in 1:N_pred) {
    y_pred[n] = normal_rng(f_pred[n], 1.0);
  }
}


Stanファイルをコンパイルし、変分推定を行います。そのあと、事後分布からサンプルを抽出し、信用区間を含むデータフレームを作ります。

data_list <- list(
  N = nrow(data) - 5,
  D = 3,
  x = data[1:(nrow(data)-5),15:17],
  y = data$amount_normalized[1:(nrow(data)-5)],
  N_pred = nrow(data[1:nrow(data),]),
  x_pred = data[1:(nrow(data)),15:17]
)

stan_file <- "gp.stan"
stan_model <- rstan::stan_model(file = stan_file)

fit_vb <- rstan::vb(stan_model, data = data_list, iter = 10000, tol_rel_obj = 1e-3)

sample_pred <- rstan::extract(fit_vb)
y_predicted <- sample_pred$y_pred


prediction_df <- tibble::tibble(
  pred_normalized = asplit(y_predicted, 2) |> 
    purrr::map_dbl(mean),
  pred_normalized_lower = asplit(y_predicted, 2) |> 
    purrr::map_dbl(~quantile(., probs = 0.175)),
  pred_normalized_upper = asplit(y_predicted, 2) |>
    purrr::map_dbl(~quantile(., probs = 0.825))
) |>
  dplyr::bind_cols(data) |>
  dplyr::mutate(
    predicted = pred_normalized * amount_sd + amount_mean,         
    Lower65CI = pred_normalized_lower * amount_sd + amount_mean,
    Upper65CI = pred_normalized_upper * amount_sd + amount_mean,
    error_rate = (predicted - amount)/amount*100
  ) 

実際の値と予測値の比較。

実績値と予測値.png

信用区間から外れている実績値もありますが、概ね信用区間内に収まっています。

おわりに

ガウス過程の強みは、入力$x$が高次元であっても柔軟なモデリングが可能となるところです。これは線形回帰モデルのパラメータについて期待値をとり、積分消去するところから来ています。

一方、入力の数が増えれば増えるほど、分散-共分散行列のサイズが大きくなるため、MCMCや変分推定の計算にも負担がかかります。

データが複雑な変動を見せているときに、データ数が今回のように大きくなければ、ガウス過程によるモデリングを行うことも手段の一つではないでしょうか(とくに階層モデリングでは、ユニット間で共通のトレンド成分や季節成分をガウス過程でモデリングすることが有効だと思います。)

参考文献

持橋大地・大羽成征.2019.「ガウス過程と機械学習」.講談社サイエンティフィク.
Stan User's Guide.https://mc-stan.org/docs/stan-users-guide/gaussian-processes.html(2025年6月3日最終閲覧).

1
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?