1
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

ベイズ猛者向け:インド料理過程モデルをStanでガチ実装してみた

Posted at

はじめに

こんにちは、事業会社で働いているデータサイエンティストです。

皆さんは、「インド料理過程」というモデルを聞いたことがありますか? 私の考えですが、このモデルは正直なところ、提案者であるGriffiths and Ghahramani(2011, p. 1198)が住んでいたロンドンの特定の地域の文化的背景に基づいた比喩であり、他の地域に住む私たちにとっては直感的に理解しにくい表現かもしれません。そこで、本記事ではビュッフェ形式のインド料理の比喩(?)を用いるのではなく、アカデミアやビジネスの現場にも馴染みやすい形で概念を説明していきます。

また、インド料理過程は、その実装の難しさでも知られています。主な理由は、モデルに離散確率変数が含まれているため、直接微分できない点にあります。もちろん、離散確率変数を積分消去すれば対応できますが、単純に実装すると、ネストされた膨大なforループが発生し、計算のボトルネックになってしまいます。しかし、このボトルネックは、工夫次第で巧妙に計算負荷を軽減することが可能です。

では、早速詳細を説明に入ります!

モデルの概念

インド料理過程の考え方は非常にシンプルです。要するに、

観測値に対応する、観測できない無限のフラグ(ダミー変数)が存在し、それらの正体や影響を推定したい

ということです。例えば、すべてのデータが観測可能であれば、「取締役フラグ」や「首都出身フラグ」が年収に与える影響を推定できます。しかし、これらのフラグが欠損している場合や、事前に仮説を置かずデータから推定したい場合、それらは観測不能な要素となります。

説明だけだとわかりづらいので、他の似たようなモデルと比較します。

ディリクレ過程という手法は、インド料理過程とすごく似ています。

ただ、ディリクレ過程は観測値が一つのグループに属すると仮定するのに対して、インド料理過程は観測値が複数のフラグを持ちうると仮定します。

したがって、例えば、履歴書の分析の際に、弁護士を経てコンサルタントに転身したAさんは、インド料理過程モデルでは「弁護士歴ありフラグ」と「コンサル歴ありフラグ」が同時に立っている人として判定されます。一方で、ディリクレ過程では、「弁護士歴とコンサル歴を同時に持っている人のグループ」を形成するか、「専門職グループ」に分類されます。このように、観測値の背後に一つのグループしか存在しないと仮定すると、ふわっとした分類(「専門職グループ」)か、細かすぎる分類(「弁護士歴とコンサル歴を同時に持っている人のグループ」)になってしまいます。一方、観測値の背後に複数のフラグが存在することを仮定すれば、分析において有益な個々の特性をより的確に捉えることができ、実際のデータ構造を反映した柔軟な分析が可能になります。

では、LDAの構造を用いたトピックモデルと比較すると、両者とも観測値が複数のトピックやフラグを持つことが許容されていますが、トピックモデルの場合、一つの観測値のトピック分布には足して1になるという制約があります。そのため、「トピック1の割合が少し増えて、他のトピックの割合が不変の場合の効果」などは論理上推定できません。具体的には、「トピック1の割合」が0.0001でも増加すると、他のトピックの割合は数学的に必然として減少するからです。一方、インド料理過程モデルにはそのような制約がなく、フラグが全く立っていない、すべてのフラグが立っている、または一部のフラグしか立っていないといった状況に柔軟に対応しています。このため、理論的に正当化された「フラグ1が立って、他のフラグが変わらない時の効果」を推定することが可能です。

では、早速具体的なモデルの形を確認したいと思います!

モデル説明

ここでは、Fong and Grimmer(2016)などを参考に、インド料理過程モデルを定式化します。

まず、ディリクレ過程とは少し異なる形の棒折り過程を用いて、各フラグの出現頻度$\pi$をサンプリングします:

$$
\eta_{k} \sim Beta(\alpha, 1)
$$

$$
\pi_{k} = \prod_{z = 1}^{k}\eta_{k}
$$

各観測値におけるフラグ$z_{j,k}$は、$\pi_k$の確率で1になるベルヌーイ分布からサンプリングされます:

$$
z_{j,k} \sim Bernoulli(\pi_{k})
$$

最後に、例えば観測値$j$のフラグをすべて集めてベクトル$Z_j$にすると、

$$
y_{j} \sim Normal(Z_{j} '* \beta, \sigma)
$$

のように、目に見えないフラグを用いた回帰モデルが構築できます。

一見単純そうに思えるかもしれませんが、$z_{j,k}$がパラメータであるにもかかわらず、離散確率分布に従っています。Stanは離散確率分布に従うパラメータには対応していないため、実装が困難です。離散確率分布に従うパラメータを推定する方法が存在するかもしれませんが、離散変数は微分できないため、根本的に難しいのです。

実装のコツ

では、どうすればよいのでしょうか?離散確率変数を扱う際には、出てくる全パターンを列挙して計算する方法が、ディリクレ過程やより単純なクラスタリングアルゴリズムでよく使われます。ただし、Pythonのようなコードで表現すると、クラスタリング系のモデルは次のように書けます。

for i in range(k):
  case_when[i] = log(pi[i]) + normal_log_likelihood(y, beta[i], sigma)

これに対して、インド料理過程モデルでは、次のように理論上、フラグの全ての組み合わせ($2^{\text{フラグ数}}$)を列挙する必要があります。

for first_flag in [0, 1]:
  for second_flag in [0, 1]:
    for third_flag in [0, 1]:
      for fourth_flag in [0, 1]:
        for fifth_flag in [0, 1]:
          ...

このようにforループを用いると、組み合わせの数が急増し、計算量が膨大になってしまいます。

ここで重要なコツがあります!

組み合わせパターンをあらかじめ計算してデータフレームに格納し、それをStanに渡せば良いのです!

具体的には、R言語を使用する場合、以下のコードで六つのフラグの全ての組み合わせを計算できます。これはStanを用いたモデル推定の際に固定なので、直接データとしてStanに渡せば良いのです。この方法により、Stan側のパターン管理や生成の負担が軽減され、コードの可読性も向上します。

> tibble::tibble(
     one = c(0, 1),
     two = c(0, 1),
     three = c(0, 1),
     four = c(0, 1),
     five = c(0, 1),
     six = c(0, 1)
 ) |>
     tidyr::complete(one, two, three, four, five, six)
# A tibble: 64 × 6
     one   two three  four  five   six
   <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
 1     0     0     0     0     0     0
 2     0     0     0     0     0     1
 3     0     0     0     0     1     0
 4     0     0     0     0     1     1
 5     0     0     0     1     0     0
 6     0     0     0     1     0     1
 7     0     0     0     1     1     0
 8     0     0     0     1     1     1
 9     0     0     1     0     0     0
10     0     0     1     0     0     1
# ℹ 54 more rows
# ℹ Use `print(n = ...)` to see more rows

このような入力データを前提としたStanコードはこちらになります:

indian_buffet_process.stan
data {
  int latent_flag_type;

  int latent_flag_combination_type;
  array[latent_flag_combination_type, latent_flag_type] int latent_flag_combination_int;
  array[latent_flag_combination_type] vector[latent_flag_type] latent_flag_combination_vector;

  int N;
  array[N] real y;
}
parameters {
  real<lower=0> alpha;
  vector<lower=0,upper=1>[latent_flag_type] eta;

  vector<lower=0>[latent_flag_type] sigma_beta;
  vector[latent_flag_type] beta;

  real intercept;
  real<lower=0> sigma_y;
}
transformed parameters {
  vector[latent_flag_type] latent_group_probability;
  for (i in 1:latent_flag_type){
    latent_group_probability[i] = prod(eta[1:i]);
  }
}
model {
  alpha ~ gamma(0.01, 0.01);
  eta ~ beta(alpha, 1);

  sigma_beta ~ gamma(0.01, 0.01);
  beta ~ double_exponential(0, sigma_beta);

  intercept ~ normal(0, 10);
  sigma_y ~ gamma(0.01, 0.01);

  for (i in 1:N){
    vector[latent_flag_combination_type] case_vector;
    for (p in 1:latent_flag_combination_type){
      case_vector[p] = bernoulli_lpmf(latent_flag_combination_int[p] | latent_group_probability) +
                       normal_lpdf(y[i] | intercept + latent_flag_combination_vector[p] '* beta, sigma_y);
    }
    target += log_sum_exp(case_vector);
  }
}

特に、bernoulli_lpmf(latent_flag_combination_int[p] | latent_group_probability)の部分にご留意ください。Stanは理論上、離散変数のパラメータに対応していませんが、その離散変数の値をデータとして渡すことで、自然なコードで表現できるようになります。

モデル推定

今回の記事では、性能確認を目的としてシミュレーションを用いてデータを生成します。データ生成過程の正解として、二つの隠れたフラグが存在し、フラグ1が立った場合には目的変数に+3.8の効果があり、フラグ2が立った場合には目的変数に-3.8の効果がある状況で分析を行います。

ただ、フラグ数の推定も重要なので、インド料理過程モデルには、最大でフラグ数6まで考慮してもらって、そこから正しい結果を再現できるかを確認します;

set.seed(123)
df <- 5000 |>
  rnorm(n = _) |>
  tibble::tibble(
    e = _
  ) |>
  dplyr::mutate(
    one = rbinom(dplyr::n(), 1, 0.5),
    two = rbinom(dplyr::n(), 1, 0.5),
    three = rbinom(dplyr::n(), 1, 0.5),
    four = rbinom(dplyr::n(), 1, 0.5),
    five = rbinom(dplyr::n(), 1, 0.5),
    six = rbinom(dplyr::n(), 1, 0.5),
    y = 0.8 + 3.8 * one - 3.8 * two + e
  )

他のデータの準備とモデルコンパイルも実施します:

latent_flag_combination <- tibble::tibble(
  one = c(0, 1),
  two = c(0, 1),
  three = c(0, 1),
  four = c(0, 1),
  five = c(0, 1),
  six = c(0, 1)
  ) |>
  tidyr::complete(one, two, three, four, five, six)


data_list <- list(
  latent_flag_type = 6,
  latent_flag_combination_type = nrow(latent_flag_combination),
  latent_flag_combination_int = latent_flag_combination |> as.matrix(),
  latent_flag_combination_vector = latent_flag_combination |> as.matrix(),

  N = nrow(df),
  y = df$y
)

m_ibp_init <- cmdstanr::cmdstan_model("indian_buffet_process.stan")

では、推定を変分推論で実施します:

> m_ibp_estimate <- m_ibp_init$variational(
     data = data_list
 )
------------------------------------------------------------ 
EXPERIMENTAL ALGORITHM: 
  This procedure has not been thoroughly tested and may be unstable 
  or buggy. The interface is subject to change. 
------------------------------------------------------------ 
Gradient evaluation took 0.089992 seconds 
1000 transitions using 10 leapfrog steps per transition would take 899.92 seconds. 
Adjust your expectations accordingly! 
Begin eta adaptation. 
Iteration:   1 / 250 [  0%]  (Adaptation) 
Iteration:  50 / 250 [ 20%]  (Adaptation) 
Iteration: 100 / 250 [ 40%]  (Adaptation) 
Iteration: 150 / 250 [ 60%]  (Adaptation) 
Iteration: 200 / 250 [ 80%]  (Adaptation) 
Success! Found best value [eta = 1] earlier than expected. 
Begin stochastic gradient ascent. 
  iter             ELBO   delta_ELBO_mean   delta_ELBO_med   notes  
   100       -12409.314             1.000            1.000 
   200       -11905.261             0.521            1.000 
   300       -11856.618             0.349            0.042 
   400       -11848.499             0.262            0.042 
   500       -11845.030             0.209            0.004   MEDIAN ELBO CONVERGED 
Drawing a sample of size 1000 from the approximate posterior...  
COMPLETED. 
Finished in  65.0 seconds.

60秒くらいで終わりました。

結果を保存します:

m_ibp_summary <- m_ibp_estimate$summary()

結果の可視化

可視化の前に、改めてシミュレーションで生成したyの分布を確認しましょう:

df$y |> hist("scott", col = "blue", main = "simulation")

simulation.png

三つのクラスターが存在すると判断される可能性がある構造ですが、真ん中の群に関しては、フラグ1とフラグ2が同時に立っているだけなので、クラスターとして推定するとデータ生成過程としては誤った結果になります。

クラスタリング手法では、真ん中の群にクラスターを設定する可能性が高いです。一方で、インド料理過程モデルでは、真ん中の群がフラグ1とフラグ2が同時に1であることを正しく識別できるのでしょうか?

早速確認しましょう:

> m_ibp_summary |>
     dplyr::filter(stringr::str_detect(variable, "^latent_group_probability\\["))
# A tibble: 6 × 7
  variable                       mean   median      sd      mad         q5     q95
  <chr>                         <dbl>    <dbl>   <dbl>    <dbl>      <dbl>   <dbl>
1 latent_group_probability[1] 0.407   0.407    0.0211  0.0212   0.372      0.442  
2 latent_group_probability[2] 0.402   0.402    0.0213  0.0214   0.367      0.437  
3 latent_group_probability[3] 0.0550  0.0368   0.0532  0.0356   0.00597    0.160  
4 latent_group_probability[4] 0.0124  0.00550  0.0192  0.00649  0.000382   0.0471 
5 latent_group_probability[5] 0.00432 0.00123  0.00923 0.00164  0.0000422  0.0188 
6 latent_group_probability[6] 0.00179 0.000297 0.00490 0.000417 0.00000458 0.00738

インド料理過程は二つのフラグしか推定していません。正解です!

次に、フラグの効果量を確認すると:

> m_ibp_summary |>
     dplyr::filter(stringr::str_detect(variable, "^beta\\["))
# A tibble: 6 × 7
  variable     mean   median     sd    mad     q5    q95
  <chr>       <dbl>    <dbl>  <dbl>  <dbl>  <dbl>  <dbl>
1 beta[1]  -3.61    -3.61    0.0333 0.0335 -3.67  -3.56 
2 beta[2]   4.03     4.03    0.0342 0.0336  3.97   4.08 
3 beta[3]   0.0151   0.0142  0.122  0.125  -0.178  0.217
4 beta[4]   0.00822  0.00112 0.501  0.495  -0.830  0.806
5 beta[5]   0.0918   0.102   0.695  0.673  -1.04   1.20 
6 beta[6]   0.0188  -0.0111  0.685  0.671  -1.03   1.16 

beta[1]が-3.8付近、beta[2]が3.8付近であるため、+3.8の効果を持つフラグと-3.8の効果を持つフラグがそれぞれ一つずつ存在する真のデータ構造をある程度再現できています。しかし、両方の信用区間が真の値をカバーしておらず、若干ずれているのは少し残念です。

結論

いかがでしたか?

本記事で紹介したモデルを活用することで、データに隠された変数を発見し、それを因果推論などの意思決定に有益な形で転用できます。次回の記事では、インド料理過程を用いてテキストデータから処置変数を発見し、その効果を推定する方法を紹介します。

最後に、私たちと一緒に、データサイエンスの力で社会を改善したい方はこちらをご確認ください:

参考文献

Fong, Christian, and Justin Grimmer. "Discovery of treatments from text corpora." Proceedings of the 54th Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers). 2016.

Griffiths, Thomas L., and Zoubin Ghahramani. "The Indian buffet process: An introduction and review." Journal of Machine Learning Research 12.4 (2011).

1
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?