2
1

回帰モデルをいっぱい作った方がいいかな?そんな時は有限混合モデルに任せて!

Posted at

はじめに

突然だが、こんなデータを考えよう。

Rplot01.png

これは二つの線形モデルを作ればうまく表現できそうだが、無理やり同じ線形モデルで表現しようとすると、おそらく傾きがゼロのモデルが推定される。

こんな時な、もちろん手作業で左下から右上のところのデータを切り出してモデル1を作り、左上から右下のところのデータを切り出してモデル2を作る方法があるが、どうしても恣意性が入ってしまう。

また、二次元の時は目視で観測値をグループ分けすることもできるが、データの次元がこれ以上増えると、可視化して手作業で対処するのは不可能と言っても過言ではない。

そこで、本記事では、Imai and Tingley(2012)の論文を参考に、ベイズ有限混合モデルを使って二つのモデルを同時に推定する方法を紹介する。

データ生成過程

ここでは、上記のデータの生成方法を紹介する

set.seed(12345)

df <- 100000 |>
  # 正規乱数を生成する
  rnorm() |>
  tibble::tibble(
    # 正規乱数をxにしてtibbleに入れる
    x = _
  ) |>
  dplyr::mutate(
    # 撹乱項を生成する
    e = rnorm(n = dplyr::n(), mean = 0, sd = 0.5),
    # 傾きが1のモデル0で生成するyを作る
    y_0 = 1 * x + e,
    # 傾きが-1のモデル0で生成するyを作る
    y_1 = -1 * x + e,
    # モデル1が生成したyを観測するか、モデル0が生成したyを観測するかを決めるフラグを作る
    flag = rbinom(n = dplyr::n(), size = 1, prob = 0.5),
    # 観測されたyを作る
    y_obs = dplyr::case_when(
      flag == 1 ~ y_1,
      TRUE ~ y_0
    )
  ) 

要するに、傾きが1のデータと-1のデータを半々で生成する。中間変数を削減する観点ですべての処理をパイプで繋げた。

次に、下記のコードを実行すれば、はじめにのところの図が描画される

df |>
  ggplot2::ggplot() + 
  ggplot2::geom_point(ggplot2::aes(x = x, y = y_obs), color = ggplot2::alpha("blue", 0.3))

モデル記述

ここではまず、数式で有限混合モデルを記述する:

z \sim beta(0.5, 0.5)
sigma \sim inverse\space gamma(5, 5)
beta_1 \sim normal(0,1)
beta_2 \sim normal(0,1)

すべての観測値iについて

coin_i \sim bernoulli(z)

if coin_i == 1

y_i \sim normal(beta_1 * x_i, sigma)

if coin_i == 0

y_i \sim normal(beta_2 * x_i, sigma)

言葉で説明すると、すべての観測値がzの確率で傾きが$beta_1$のモデルから生成され、1 - zの確率で$beta_2$のモデルから生成される。他は一般的なベイズ線形回帰とあまり変わらないため詳細の説明は割愛する。

このモデルをstanで記述すると

simple_finite_mixture.stan
data {
  int<lower=1> N;
  array[N] real x;
  array[N] real y;
}
parameters {
  real<lower=0,upper=1> z;
  real<lower=0> sigma;
  vector[2] beta;
}
model {
  z ~ beta(0.5, 0.5);
  sigma ~ inv_gamma(5, 5);
  beta ~ normal(0, 1);
  for (n in 1:N){
    // 二つの状況の対数尤度を格納するためのベクトルを作成
    vector[2] mu;
    // coin == 1の場合
    mu[1] = log(z) + normal_lupdf(y[n] | beta[1] * x[n], sigma);
    // coin == 0の場合
    mu[2] = log1m(z) + normal_lupdf(y[n] | beta[2] * x[n], sigma);
    target += log_sum_exp(mu);
  }
}

で表現できる。stanは離散変数を扱えないため、離散変数の全ての実現可能なパターンを列挙し、ベクトルに格納して、最後にlog_sum_exp関数に入れる処理が必要である。

これは要するに、$coin_i$を積分除去する処理と見なしても良い。

モデル推定

では、早速データを整理して、モデルに入れよう。

data_list <- list(
  N = nrow(df),
  x = df$x,
  y = df$y_obs
)

m_init <- cmdstanr::cmdstan_model("simple_finite_mixture.stan")


m_estimate <- m_init$variational(
  data = data_list
)


m_summary <- m_estimate$summary()

変分推論のため36.3秒で推定が終了した。

次に、パラメータを確認する

> m_summary
# A tibble: 6 × 7
  variable           mean      median       sd      mad          q5         q95
  <chr>             <dbl>       <dbl>    <dbl>    <dbl>       <dbl>       <dbl>
1 lp__        -114135.    -114134     15.2     14.8     -114163     -114113    
2 lp_approx__      -2.05       -1.70   1.50     1.28         -4.98       -0.317
3 z                 0.501       0.501  0.00196  0.00199       0.498       0.504
4 sigma             0.506       0.506  0.00212  0.00195       0.503       0.510
5 beta[1]          -1.01       -1.01   0.00393  0.00388      -1.02       -1.01 
6 beta[2]           1.01        1.01   0.00270  0.00274       1.01        1.02 

まず、半々でデータを生成したため、zの値が0.5に近いのは正しい。sigmaもデータ生成過程のところで指定した0.5とほぼ一致している。

傾きのbetaの値に関して、1が先に来るか-1が先に来るかは重要ではないので、1に近い値と-1に近い値がそれぞれ正確に出たので問題なく、パラメータの値を推定できたといえよう。

所属グループの推定

ここで各グループの傾きと全体の中の割合を推定できたが、肝心の$coin_i$が積分除去された影響でモデルの事後分布として出てこない。

よって、このままだとどの観測値がどのグループのデータなのかを判断できない。

そこで、ここではImai and Tingley(2012)の式(8)を参考に、全ての$coin_i$の事後分布を推定する。

発想は簡単で、先ほど推定したzの事後分布を事前分布に、観測値の尤度を利用してさらに更新していくイメージである。数式で書くと

\hat{coin_iの確率} = \frac{z * 正規分布の確率密度関数(y_i, beta_1 * x_i, sigma)}{z * 正規分布の確率密度関数(y_i, beta_1 * x_i, sigma) + (1 - z) * 正規分布の確率密度関数(y_i, beta_2 * x_i, sigma)}\

になる。

ではRで計算する前に、まずは必要なパラメータを取り出す

z <- m_summary |>
  dplyr::filter(variable == "z") |>
  dplyr::pull(mean)

beta <- m_summary |> 
  dplyr::filter(stringr::str_detect(variable, "beta")) |>
  dplyr::pull(mean)

sigma <- m_summary |>
  dplyr::filter(variable == "sigma") |>
  dplyr::pull(mean)

次に、$coin_i$の事後分布を推定する処理を行って、ggplot2で可視化する

df |>
  dplyr::mutate(
    s = purrr::map_dbl(
      seq_len(dplyr::n()),
      # dnormで確率密度を計算する
      ~ (z * dnorm(y_obs[.x], mean = x[.x] * beta[1], sd = sigma))/(z * dnorm(y_obs[.x], mean = x[.x] * beta[1], sd = sigma) + (1 - z) * dnorm(y_obs[.x], mean = x[.x] * beta[2], sd = sigma))
    ),
    # 描画のため、0.5を閾値にグループ分けする
    s_flag = dplyr::case_when(
      s > 0.5 ~ "1",
      TRUE ~ "0"
    )
  ) |>
  ggplot2::ggplot() + 
  ggplot2::geom_point(ggplot2::aes(x = x, y = y_obs, color = s_flag)) + 
  ggplot2::scale_color_manual(
    values = c(
      "0" = ggplot2::alpha("blue", 0.3),
      "1" = ggplot2::alpha("red", 0.3)
    )
  )

Rplot02.png

これで、真ん中こそ難しいが、$coin_i$の事後分布の推定は概ね問題なくできたといえよう。

分析の問題点と意味

さて、データから複数の線形モデルを発見して、かつどの観測値がどのモデルから生成されたのかの確率も推定した。ただ、今のままだと、このモデルは未知の観測値がどのグループに所属しているかを判断できず、汎化性能を持っていないといえる。未知の観測値の結果変数がまだ観測されていないからである。

だが、このような形の有限混合モデルを少し改造することで、ビジネスの意思決定やレコメンドシステムのバックエンドロジックで利用できるように進化させられる。

弊社が運営しているアルバイト求人サイトバイトルでいうと、詳細は絶対に言わないが想像してみてください:サイトを再訪問するユーザーもいる。

再訪問は何を意味するのかというと、すでに何回かバイトルに来ており、かつ弊社としてはある程度ログ(応募した求人など)を取得できる。

なので、観測値についてではなく、ユーザーについてImai and Tingley(2012)の式(11)を参考に$coin_{user_i}$の事後分布を計算すれば、ユーザーに対して質の高いパーソナライズされたレコメンドを行うことが可能になる。

具体的な推定方法はまた別の機会で紹介する。

結論

本記事では、有限混合モデルについて説明した。モデル数を明示的に指定する必要がある有限混合モデルに対し、モデル数を指定する必要のない無限混合モデルもある。

詳細は筆者の以前の記事を是非確認してください。

参考文献

Imai, Kosuke, and Dustin Tingley. "A statistical method for empirical testing of competing theories." American Journal of Political Science 56.1 (2012): 218-236.

2
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
1