1
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

ベイズ棒折り過程基底関数モデルで複雑な関数を推定してみた

Posted at

はじめに

こんにちは、事業会社で働いているデータサイエンティストです。

今回の記事では、ガウス過程の代用品として、棒折り過程を活用して、無限にある基底関数から自動で有用なものをピックアップして、複雑な関数を推定する手法を提案してみます。

ガウス過程の問題点

ガウス過程は、基底関数法の拡張として見なすことができ、目的は今回提案する手法と似ています。

詳細はこちらを参照してください:

ただ、ガウス過程はとにかく推定に時間がかかります。元々、会社の業務で検討しているモデルにガウス過程を組み込もうとしていたんですが(詳細は話せません)、一回の推定に8時間以上がかかります。モデルの根幹のパーツでもない部分にこれだけの時間を持っていかれると、モデル定式化のPDCAサイクルが回らないし、自分の工数ももったいないと判断して、ガウス過程の導入を断念しました。

モデルの詳細は話せないですが、ガウス過程の代わりに簡易的なアプローチを使った結果、推定時間が40分程度に短縮されました。どうしてこんなに時間がかかるのかというと、ガウス過程が丸暗記をしているからです。ガウス過程は裏で巨大な共分散行列を保持して、その逆行列も計算しています。

丸暗記的なやり方は、そもそも計算のボトルネックになるだけでなく、そもそも統計学・機械学習的な観点から見て、学習データからの情報を要約・圧縮してなくね?おかしくね?と思いました。

そこで、原点に立ち戻って、そもそも基底関数を本当に無限の関数から自動で一部抽出して利用できないかと考え、棒折り過程を活用してみようと思いました。

では次は具体的なモデル構造の話をします。

モデル説明

まず、結果変数と独立変数の関係の定式化から説明します。

$$
Y_{i} \sim Normal(\alpha+\sum_{e = 1}^{\infty}\beta_{e}b_{e}(x_{i}), \sigma)
$$

$\beta_{e}$は基底係数で、$b_{e}$は基底関数です。基底関数は「管轄範囲」を決めるとイメージした方がわかりやすいです。xが-1 ~ 2だったら、$\beta_{1}$が管轄する、xが1.5 ~ 5だったら、$\beta_{2}$が管轄する、を判断する関数です。ただ、実際の管轄範囲は何かしらの距離を基準に計算された「近さ」で判断されることが多いため、ルールベース的な感じで管轄を区切るのではなく、管轄範囲がかぶることは許容されます。

次に、基底関数の構造を説明します。

まず、無限次元の確立分布(simplex)の配列を棒折り過程で作成します。

棒折り過程を作ることで、無限にある基底関数の可能性を肯定しつつ、データから見ると実は役に立たない基底関数に0の重みを付与することで変数選択できます。

$$
\alpha \sim Gamma(1, 1)
$$

eを1から無限大までループ:

$$
\pi_{e}\sim Beta(1, \alpha)
$$

$$
p_{e} = \pi_{e} \prod\limits_{l=1}^{e - 1} (1 - \pi_{l})
$$

ここで$p_{e}$とe番目の基底関数の「管轄範囲」を示す正規分布を組み合わせれば$b_{e}$の完成です:

$$
b_{e}(x) = p_{e}\space \frac{1}{\sqrt{2 \pi \sigma_{e}^2}} \exp \left( -\frac{(x - \mu_{e})^2}{2 \sigma_{e}^2} \right)
$$

最後の基底係数の事前分布に一般化ダブルパレート事前分布(generalized double Pareto prior distribution)を指定します。

$$
\rho_{global} \sim Gamma(0.01, 0.01)
$$

eを1から無限大までループ:

$$
\rho_{e} \sim Exponential(0.5*\rho_{global}^2)
$$

$$
\beta_{e} \sim Normal(0, \rho_{e}^2)
$$

モデル実装

Stanでの実装はこちらになります:

dirichlet_process_basis.stan
functions {
  vector stick_breaking(vector breaks){
    int length = size(breaks) + 1;
    vector[length] result;
    
    result[1] = breaks[1];
    real summed = result[1];
    for (d in 2:(length - 1)) {
      result[d] = (1 - summed) * breaks[d];
      summed += result[d];
    }
    result[length] = 1 - summed;
    
    return result;
  }
  real partial_sum_lpdf(
    array[] real y,
    
    int start, int end,
    
    int section_type,
    
    array[] real x,
    
    vector section,
    vector midpoint_x, vector spread_x,
    
    real sigma,
    real intercept,
    vector beta
  ){
    vector[end - start + 1] log_likelihood;
    int count = 1;
    for (i in start:end){
      vector[section_type] case_vector;
      for (j in 1:section_type){
        case_vector[j] = log(section[j]) +
                         normal_lpdf(x[i] | midpoint_x[j], spread_x[j]^2);
      }
      log_likelihood[count] = normal_lpdf(
        y[count] | intercept + softmax(case_vector) '* beta, sigma
      );
      count += 1;
    }
    return sum(log_likelihood);
  }
}
data {
  int section_type;
  
  int N;
  array[N] real x;
  array[N] real y;
  
  int val_N;
  array[val_N] real val_x;
  array[val_N] real val_y;
}
parameters {
  real<lower=0> section_alpha;
  vector<lower=0, upper=1>[section_type - 1] section_breaks;
  
  vector[section_type] midpoint_x;
  vector<lower=0>[section_type] spread_x;
  
  real<lower=0> rho_global;
  real<lower=0> sigma;
  real intercept;
  vector<lower=0>[section_type] rho_beta;
  vector[section_type] beta;
}
transformed parameters {
  simplex[section_type] section;
  
  section = stick_breaking(section_breaks);
}
model {
  section_alpha ~ gamma(1, 1);
  section_breaks ~ beta(1, section_alpha);
  
  midpoint_x ~ normal(0, 1);
  spread_x ~ inv_gamma(0.01, 0.01);
  
  rho_global ~ gamma(0.01, 0.01);
  rho_beta ~ exponential(0.5 * (rho_global)^2);
  sigma ~ inv_gamma(0.01, 0.01);
  intercept ~ normal(0, 1);
  beta ~ normal(0, rho_beta^2);
  
  int grainsize = 1;
  
  target += reduce_sum(
    partial_sum_lupdf, y,
    
    grainsize,
    
    section_type,
    
    x,
    
    section,
    midpoint_x, spread_x,
    
    sigma,
    intercept,
    beta
  );
}
generated quantities {
  vector[val_N] predict;
  vector[val_N] signal;
  for (i in 1:val_N){
    vector[section_type] case_vector;
    for (j in 1:section_type){
      case_vector[j] = log(section[j]) +
                       normal_lpdf(val_x[i] | midpoint_x[j], spread_x[j]^2);
    }
    signal[i] = intercept + softmax(case_vector) '* beta;
    predict[i] = normal_rng(intercept + softmax(case_vector) '* beta, sigma);
  }
}

モデル推定

では、このモデルはノイズがある中でsin関数の形をうまく推定できるかを確認しましょう!

sin_df <- seq(-7, 7, by = 0.00005) |>
  tibble::tibble(
    x = _
  ) |>
  dplyr::mutate(
    e = rnorm(dplyr::n()),
    x_std = (x - mean(x))/sd(x),
    y = sin(x) + e
  )

set.seed(12345)
val_id <- sin_df |>
  nrow() |>
  seq_len() |>
  sample(x = _, size = 3000)

data_list <- list(
  # 最大で考慮する基底関数の数
  # 本当は無限大を設定したいですができないので10で近似する
  section_type = 10,
  
  N = nrow(sin_df[-val_id,]),
  x = sin_df$x_std[-val_id],
  y = sin_df$y[-val_id],
  
  val_N = nrow(sin_df[val_id,]),
  val_x = sin_df$x_std[val_id],
  val_y = sin_df$y[val_id]
)

サンプル数は27万件です:

> data_list$N
[1] 277001

次に、モデルコンパイルして:

m_dpb_init <- cmdstanr::cmdstan_model("dirichlet_process_basis.stan",
                                                   cpp_options = list(
                                                     stan_threads = TRUE
                                                   )
)

モデル推定を実施します:

> m_dpb_estimate <- m_dpb_init$variational(
     seed = 12345,
     threads = 20,
     iter = 10000,
     data = data_list
 )
------------------------------------------------------------ 
EXPERIMENTAL ALGORITHM: 
  This procedure has not been thoroughly tested and may be unstable 
  or buggy. The interface is subject to change. 
------------------------------------------------------------ 
Gradient evaluation took 0.062539 seconds 
1000 transitions using 10 leapfrog steps per transition would take 625.39 seconds. 
Adjust your expectations accordingly! 
Begin eta adaptation. 
Iteration:   1 / 250 [  0%]  (Adaptation) 
Iteration:  50 / 250 [ 20%]  (Adaptation) 
Iteration: 100 / 250 [ 40%]  (Adaptation) 
Iteration: 150 / 250 [ 60%]  (Adaptation) 
Iteration: 200 / 250 [ 80%]  (Adaptation) 
Success! Found best value [eta = 1] earlier than expected. 
Begin stochastic gradient ascent. 
  iter             ELBO   delta_ELBO_mean   delta_ELBO_med   notes  
   100      -653579.287             1.000            1.000 
   200      -446911.217             0.731            1.000 
   300      -462161.982             0.498            0.462 
   400      -436466.842             0.389            0.462 
   500      -425826.024             0.316            0.059 
   600      -423224.219             0.264            0.059 
   700      -420147.710             0.228            0.033 
   800      -407494.900             0.203            0.033 
   900      -395970.004             0.184            0.031 
  1000      -394789.591             0.166            0.031 
  1100      -394190.081             0.066            0.029 
  1200      -394082.809             0.020            0.025 
  1300      -394835.709             0.016            0.007   MEDIAN ELBO CONVERGED 
Drawing a sample of size 1000 from the approximate posterior...  
COMPLETED. 
Finished in  152.9 seconds.

27万件のデータでも152秒で推定が終わりました。ガウス過程だったら152年くらいですかね、、、

結果の可視化

まずは機械学習がよくやるように、精度を確認しましょう:

> m_dpb_summary |> 
     dplyr::filter(stringr::str_detect(variable, "predict")) |>
     dplyr::bind_cols(
         val_y = data_list$val_y
     ) |>
     dplyr::mutate(
         ape = abs((mean - val_y)/val_y)
     ) |>
     dplyr::pull(ape) |>
     mean()
[1] 3.492835

MAPEは相当悪いですね、、、この数字を見て、モデルの学習がうまくいかなかったと、機械学習の人が判断するかもしれません。

ただ、学習データにも検証データにも正規分布からサンプリングされた予測不可能なノイズが注入されており、当たらない部分があるのは当たり前です。また、アカデミアとビジネスにおいて、価値があって悪手なぶるなインサイトに繋がるのは、往々にして目に見えない傾向です。画像に猫が入っているかを判定したい、信号が赤なのかを判定したい、などの純粋な工学的なタスクでなければ、予測精度をモデルの良し悪しを判定する唯一の指標にするのは危険です

なので、モデルがうまくsin関数の形を推定できたかを見てみましょう:

m_dpb_estimate$draws("signal") |>
  dplyr::as_tibble() |>
  dplyr::mutate(
    dplyr::across(
      dplyr::everything(),
      ~ as.numeric(.)
    ),
    id = dplyr::row_number()
  ) |>
  tidyr::pivot_longer(!id, names_to = "sample") |>
  split(~ id) |>
  purrr::map(
    \(df){
      df |>
        dplyr::bind_cols(
          x = sin_df$x[val_id]
        )
    }
  ) |>
  dplyr::bind_rows() |>
  ggplot2::ggplot() + 
  ggplot2::geom_line(ggplot2::aes(x = x, y = value, group = id), color = ggplot2::alpha("blue", 0.1))

sin.png

うまくできています!

なので、ベイズ棒折り過程基底関数モデルはデータの裏に隠された傾向を正しく抽出できたと判断できます。

結論

いかがでしょうか?このモデルは、ガウス過程と比べて圧倒的に高速で、同様の効果を発揮します。低い計算コストで任意の(?)関数を推定する際にぜひこの手法をご活用ください!

最後に、私たちと一緒に、データサイエンスの力で社会を改善したい方はこちらをご確認ください:

1
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?