3
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

グループ間の相関を考慮した階層ベイズモデル

Last updated at Posted at 2024-10-31

はじめに

こんにちは、事業会社で働いている新卒データサイエンティストです。

今回は、階層ベイズモデルでグループ間の相関を入れることを考えます。

具体的には、一つのサービスのもとで複数のセクションが存在している状況などがこれにあたります。この場合、各セクションは独立しているわけではなく、互いに何らかの相関を持っている可能性があります。こうした状況でセクションごとの売上予測を推定することを考えます。

売上データの生成

説明の簡単のため、売上がトレンド成分と季節成分から構成されると考えます。2020年4月から2024年9月までの月次データを生成させました。今回は、4つのセクションがあると考えます。

df_sales <- tibble::tibble(date = rep(seq(from = as.Date("2020-04-01"),
                                 to = as.Date("2024-09-01"),
                                 by = "month"), each = 4)) |>
  #セクションの設定
  dplyr::mutate(section = rep(c("A", "B", "C", "D"), length.out = dplyr::n())) |>
  #季節成分とトレンド成分の設定
  dplyr::group_by(section) |>
  dplyr::mutate(day_of_year = lubridate::yday(date),
                seasonality = 0.5*(-sin(2*2*pi*day_of_year/365.5 + 
                                          cos(1*2*pi*day_of_year/365.5)) 
                                   + rnorm(n = dplyr::n(), mean = 0, sd = 1))) |>
  dplyr::mutate(trend = (seq(from = 1000,
                             to = 2000,
                             length.out = dplyr::n()) ^(1/4) - 1 + rnorm(n = dplyr::n(), mean = 0, sd = 1))) |>
  dplyr::ungroup() |>
  #誤差項を追加
  dplyr::mutate(epsilon = rnorm(n = dplyr::n(), mean = 0, sd = 0.25)) |>
  #季節成分・トレンド成分・誤差項を足し合わせて売上とする
  dplyr::mutate(sales = trend + seasonality + epsilon)  |>
  dplyr::group_by(section) |>
  #売上を標準化
  dplyr::mutate(sales_mean = mean(sales),
                sales_sd = sd(sales),
                sales_normalized = (sales - sales_mean)/sales_sd) |>
  dplyr::ungroup()

各セクションについて、以下のプロットで示されるような売上データが生成されました(本当は、もう少しセクション別に売上額の幅が異なるデータになるとよいのですが)。

ggplot2::ggplot(df_sales, ggplot2::aes(x = date, y = sales)) +
  ggplot2::geom_line() +
  ggplot2::facet_wrap(~section, nrow = 2) +
  ggplot2::theme_bw()

image.png

状態空間モデルによる推定

以下のような線形ガウス状態空間モデルを考えます。

$\mu_{s,t} - \mu_{s,t}= \mu_{s,t-1} - \mu_{s,t-2} + \xi_t, \quad where \quad \xi_t \sim N(0, \sigma_{\xi}^2) \quad \cdots ①$
$y_{s,t} = \mu_{s,t} + \gamma_t + r + \epsilon_t, \quad where \quad \epsilon_t \sim N(0, \sigma_{\epsilon}^2) \quad \cdots ②$

①が状態方程式、②が観測方程式にあたります。

ただし、
$s$:セクション$\in \lbrace1,\cdots, S \rbrace$
$t$:時間$\in \lbrace1,\cdots, T \rbrace$
$\mu_{s,t}$:セクション$s$の$t$期の水準(状態)
$\gamma_s$:セクション$s$の季節成分
$y_{s,t}$:セクション$s$の$t$期の売上
$\xi_t$:状態方程式の誤差項
$\epsilon_t$:観測方程式の誤差項
とします。

ここで、状態$\mu_{s,t}$がセクション間で相関を持つことを考えます。

${\bf \mu_{t}} \sim MVN(2{\bf \mu_{t-1}-\mu_{t-2}}, \Sigma)$

ただし、

$ \Sigma = \left(
\begin{array}{cccc}
w_{11} & w_{12} & \cdots & w_{1S}
\\w_{21} & w_{22} & \cdots & w_{2S}
\\・ & \vdots & \ddots & \vdots
\\w_{S1} & w_{S2} & \cdots & w_{SS}
\end{array}
\right)$

で、とくに

$ w = \left(
\begin{array}{cccc}
w_{21}
\\w_{31} & w_{32}
\\・ & \vdots & \ddots
\\w_{S1} & w_{S2} & \cdots & w_{S,S-1}
\end{array}
\right),$

$ \tau = \left(
\begin{array}{cccc}
w_{11}
\\ & w_{22}
\\ & & \ddots &
\\ & & & w_{SS}
\end{array}
\right)$
とします。

ここでは
$w \sim InvWishart(S,I_S) \quad (I_SはS\times Sの単位行列),$
$\tau \sim Cauchy(0,5)$
となるように設定しました。

Stanで実装

以下のstanファイルを作成します。

local_linear_trend_correlation.stan
data {
   int time_type;                            // データ取得期間の長さ
   int pred_type;                            // 予測期間の長さ
   int section_type;                         // セクションの数
   matrix[time_type, section_type] outcome;  // 観測値
 }
 
 parameters {
   matrix[section_type, time_type] mu;       // 水準+ドリフト成分の推定値
   vector[time_type] gamma;            // 季節成分の推定値
   vector[section_type] r;            // セクションのランダム効果
   real<lower=0> s_v;                         // 観測誤差の標準偏差
   real<lower=0> s_s;                         // 季節成分の標準偏差
   real<lower=0> s_r;                         // ランダム効果の標準偏差
   vector<lower=0>[section_type] tau;         // 共分散
   corr_matrix[section_type] omega;           // 相関
 }
 
 transformed parameters {
   cov_matrix[section_type] Sigma;
   Sigma = quad_form_diag(omega, tau);       
 
   matrix[time_type, section_type] alpha;  // 各成分の和
   
   for (i in 1:time_type) {
     for (j in 1:section_type) {
         alpha[i, j] = mu[j,i] + gamma[i] + r[j];
     }
   }
 }
 
 model {
   // Adjusted priors for stability
   gamma ~ normal(0, 1);
   s_v ~ cauchy(0, 2);
   s_s ~ cauchy(0, 2);
   s_r ~ cauchy(0, 2);
   tau ~ cauchy(0, 5); 
   omega ~ inv_wishart(section_type, identity_matrix(section_type));  // 相関行列の事前分布
   
   // ランダム効果
   r ~ normal(0, s_r);
 
   // 水準+ドリフト成分
   for (i in 3:time_type) {
     mu[:,i] ~ multi_normal(2 * mu[:,i-1] - mu[:,i-2], Sigma);
   }


   // 季節成分
   for (i in 12:time_type) {
     gamma[i] ~ normal(-sum(gamma[(i - 11):(i - 1)]), s_s);
   }
   
   // 観測方程式に従い、観測値が得られる
   for (i in 1:time_type) { 
     for (j in 1:section_type) {
       outcome[i, j] ~ normal(alpha[i, j], s_v);
     }
   }
 }
 
 generated quantities {
   matrix[section_type, time_type + pred_type] mu_pred;       // 予測値も含めた状態の推定値
   vector[time_type + pred_type] gamma_pred;             // 予測値も含めた季節成分の推定値
   matrix[time_type + pred_type, section_type] alpha_pred;    // 予測値も含めた状態と季節成分の推定値
 
   // データ取得期間においては、状態推定値muと同じ
   for (i in 1:time_type) {
     mu_pred[:,i] = mu[:,i];
     gamma_pred[i] = gamma[i];
     for (j in 1:section_type) {
       alpha_pred[i, j] = alpha[i, j];
     }
   }
 
   // データ取得期間を超えた部分を予測
   for (i in 1:pred_type) {
     int t = time_type + i;
     mu_pred[:,t] = multi_normal_rng(2 * mu_pred[:,t - 1] - mu_pred[:,t - 2], Sigma);
     gamma_pred[t] = normal_rng(-sum(gamma_pred[(t - 11):(t - 1)]), s_s);
     for (j in 1:section_type) {
       alpha_pred[t, j] = mu_pred[j, t] + gamma_pred[t] + r[j];
     }
   }
 }

変分推論

まず、1か月先の予測値を格納するためのデータフレームを作成します。あとでdf_salesと結合できるように、同じカラム名で作成しました。

df_sales_pred <- tibble::tibble(date = rep(df_date$date[nrow(df_date)] |> lubridate::`%m+%`(months(1)), each = 4)) |>
  dplyr::mutate(section = c("A", "B", "C", "D"),
                day_of_year = lubridate::yday(date),
                seasonality = rep(NA, each = 4),
                trend = rep(NA, each = 4),
                epsilon = rep(NA, each = 4),
                sales = rep(NA, each = 4),
                sales_mean = c(df_sales_merged$sales_mean[1],
                                   df_sales_merged$sales_mean[2],
                                   df_sales_merged$sales_mean[3],
                                   df_sales_merged$sales_mean[4]),
                sales_sd = c(df_sales_merged$sales_sd[1],
                                 df_sales_merged$sales_sd[2],
                                 df_sales_merged$sales_sd[3],
                                 df_sales_merged$sales_sd[4]),
                sales_normalized = rep(NA, each = 4),
                ) 

df_sales_predをdf_salesと結合させます。

df_sales_merged <- df_sales |>
  dplyr::bind_rows(df_sales_pred)

いま、df_sales_mergedはロング形式のデータです。stanに渡す便宜上、ワイド形式のデータに変換しておきます。

df_sales_merged_wider <- df_sales_merged |>
  dplyr::select(c(date, section,  sales_normalized)) |>
  tidyr::pivot_wider(names_from = "section",
                     values_from = "sales_normalized")

続いて、

  • stanファイルの指定
  • stanコンパイル
  • データリストの作成
    を行います。
#stanファイルの指定
ssm <- "local_linear_trend_correlation.stan"

#コンパイル
stan_model <- rstan::stan_model(file = ssm)

#データリストの作成
data_list <- list(
  time_type = length(df_sales_merged_wider$date[-nrow(df_sales_merged_wider)]),
  pred_type = 1,
  section_type = 4,
  outcome = df_sales_merged_wider[-nrow(df_sales_merged_wider), 2:5]
)

変分ベイズを実行し、サンプリングされたパラメータを取り出します。

vb <- rstan::vb(stan_model, data = data_list, iter = 10000, tol_rel_obj = 1e-3)
sample_pred <- rstan::extract(vb)

ここで、取り出したサンプルから信用区間を構成できるように関数を組んでおくと便利です({cmdstanr}だとsummary()を使って簡単に信用区間を構成できるのですが、{rstan}では上手くいかず…)。

summary_matrix <- function(sample_preds, probs, value_names) {
  #結果をリストとして格納
  results_list <- list() 
  
  #各サンプルについて以下の処理を行う
  for (i in seq_along(sample_preds)) { 
    #各分位について以下の処理を行う
    for (j in seq_along(probs)) { 
      prob <- as.numeric(probs[j])
      
      result <- apply(sample_preds[[i]], 2:3, quantile, probs = prob) |> #MCMCサンプリングの平均をとる
        as.data.frame() |>
        dplyr::rename("A" = "V1",
                      "B" = "V2",
                      "C" = "V3",
                      "D" = "V4") |>
        tidyr::pivot_longer(cols = c("A","B","C","D"),
                            names_to = "section",
                            values_to = paste0(value_names[i], "_", prob))
      
      #各結果を格納する
      results_list[[paste0(value_names[i], "_", prob)]] <- result[[paste0(value_names[i], "_", prob)]]
    }
  }
  
  #各結果をcbind()で一つのマトリックスに格納する
  results_matrix <- do.call(cbind, results_list)
  
  #マトリックスの各列に名前をつける
  colnames(results_matrix) <- names(results_list)
  
  return(results_matrix)
}

sample_pred_list <- list(sample_pred$alpha_pred)
value_name_list <- c("pred")
probs_list <- list(0.5, 0.025, 0.975)  # 必要に応じて適切な確率を設定

result_matrix <- summary_matrix(sample_preds = sample_pred_list, probs = probs_list, value_names = value_name_list) |> 
  as.data.frame()

推定されたパラメータは標準化されているので、売上の平均と標準偏差を使って元の大きさに戻します。

prediction_df <- df_sales_merged |>
  dplyr::bind_cols(result_matrix) |>
  dplyr::mutate(predicted = pred_0.5 * sales_sd + sales_mean,
                Lower95CI = pred_0.025 * sales_sd + sales_mean,
                Upper95CI = pred_0.975 * sales_sd + sales_mean
  ) |>
  dplyr::mutate(date = as.Date(date)) 

最後に、予測値を視覚化します。

ggplot2::ggplot() +
  ggplot2::geom_point(data = prediction_df,ggplot2:: aes(x = date, y = sales, color = section)) +
  ggplot2::geom_line(data = prediction_df, ggplot2::aes(x = date, y = predicted, color = section)) +
  ggplot2::geom_ribbon(data = prediction_df, ggplot2::aes(x = date, ymin = Lower95CI, ymax = Upper95CI), alpha = 0.1) +
  ggplot2::facet_wrap(~section, nrow = 2) +
  ggplot2::scale_x_date(date_breaks = "2 month", date_labels = "%Y-%m") +
  ggplot2::theme_bw() +
  ggplot2::theme(axis.text.x = ggplot2::element_text(angle = 45, hjust = 1)) 

image.png

おわりに

今回は、グループ間の相関を考慮した階層モデルとして、状態空間モデルを扱いました。時系列データで複数のグループに関するデータを扱う時、階層ベイズモデルは非常に有用だと思います。

3
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
3
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?