2
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

RとStanで実装する有限混合モデル

Last updated at Posted at 2025-09-08

はじめに

現実のデータは、しばしば単一の分布ではうまく説明できない。
例えば次のようなケースが考えられる:

  • 顧客の行動が複数のタイプに分かれている(価格に敏感な層とそうでない層、ヘビーユーザーとライトユーザーなど)。

  • 観測データが多峰性(複数の山を持つ分布)を示す(テストの点数が「できる層」「苦手な層」に分かれる)。

  • 平均や分散の異なる集団が混ざっている(異なる地域や条件下で計測された測定値)。

有限混合モデルでは、こうした異質性を明示的にモデル化できる。
確率的には「潜在的に異なる分布から生成されたデータが観測されている」と考え、それぞれの分布をクラスター、クラスターが選ばれる確率を混合比率として表す。

一方で、クラスター数が恣意的に決定されたり、パラメータの識別性が弱かったりとデメリットも抱えている。今回は、有限混合モデルをRとStanで実装することを通して、その限界を確認する。

データ生成過程

観測データの生成過程は以下のとおり。

$$p(y_i) = \sum_{k=1}^{K}\theta_k \mathcal N(y_i|\mu_k, \sigma_k^2)$$
ただし、各データ点$i=1, \cdots,N$に対し、
$$y_i |z_i \sim \mathcal N(\mu_{z_i}, \sigma_{z_i}^2) $$
$$z_i \sim Categorical(\theta_1, \theta_2, \theta_3) $$

また、事前分布として以下を設定する。

$$ \theta = (\theta_1, \theta_2, \theta_3) \sim Dirichlet(1,1,1),\quad \sum_{k=1}^{K}\theta_k = 1, \quad \theta_k \geq0 $$
$$ \mu_k \sim \mathcal N(0,1),\quad \mu_1 \leq \mu_2 \leq \cdots \leq \mu_K$$
$$ \sigma_k \sim LogNormal(0,2), \quad \sigma_k \geq 0$$

Rコードの実装

set.seed(123)

# Parameters
K <- 3      
N <- 10000     

# Priors
theta <- as.numeric(gtools::rdirichlet(1, rep(1, K))) 
mu <- sort(rnorm(K, 0, 1))                    
sigma <- rlnorm(K, meanlog = 0, sdlog = 2)            

df <- tibble::tibble(
  z = sample(1:K, size = N, replace = TRUE, prob = theta),
  y = rnorm(N, mean = mu[z], sd = sigma[z])  
)      

各クラスターごとのサンプルの分布は以下のようになる。

ggplot2::ggplot(df, ggplot2::aes(x = y, fill = factor(z))) +
  ggplot2::geom_histogram(ggplot2::aes(y = ..density..), bins = 30, alpha = 0.6, position = "identity") +
  ggplot2::facet_wrap(~ z, scales = "free_y") +  
  ggplot2::scale_fill_brewer(palette = "Set1", name = "Cluster") +
  ggplot2::labs(title = "分布の比較 (各Kごと)", x = "y", y = "Density") +
  ggplot2::theme_minimal() +
  ggplot2::theme(legend.position = "top")

image.png

上記の各分布が事前分布で設定した比率で混合されたものが、以下のプロットである。

ggplot2::ggplot(df, ggplot2::aes(x = y)) +
  ggplot2::geom_histogram(ggplot2::aes(y = ..density..), bins = 30, fill = "grey80", color = "white") 

image.png

Stanファイルの実装

data {
  int<lower=1> K;          
  int<lower=1> N;          
  array[N] real y;      
}

parameters {
  simplex[K] theta;          
  ordered[K] mu;            
  vector<lower=0>[K] sigma;  
}

model {
  vector[K] log_theta = log(theta);  
  sigma ~ inv_gamma(5, 5);
  mu ~ normal(5, 10);
  for (n in 1:N) {
    vector[K] lps = log_theta;
    for (k in 1:K) {
      lps[k] += normal_lpdf(y[n] | mu[k], sigma[k]);
    }
    target += log_sum_exp(lps);
  }
}

パラメータの事後分布の推定

stan_data <- list(
  K = K,
  N = N,
  y = df$y
)

stan_model <- rstan::stan_model("finite_mixture.stan")
vb_fit <- rstan::vb(stan_model,
             data = stan_data,
             output_samples = 1000, 
             iter = 10000,          
             tol_rel_obj = 0.01)    
             
vb_samples <- rstan::extract(vb_fit)

samples <- as.data.frame(vb_fit) 

mu_df <- samples |>
  dplyr::select(dplyr::starts_with("mu")) |>
  tidyr::pivot_longer(dplyr::everything(), names_to = "component", values_to = "mu") |>
  dplyr::mutate(component = readr::parse_number(component))

sigma_df <- samples |>
  dplyr::select(dplyr::starts_with("sigma")) |>
  tidyr::pivot_longer(dplyr::everything(), names_to = "component", values_to = "sigma") |>
  dplyr::mutate(component = readr::parse_number(component))

theta_df <- samples |>
  dplyr::select(starts_with("theta")) |>
  tidyr::pivot_longer(dplyr::everything(), names_to = "component", values_to = "theta") |>
  dplyr::mutate(component = readr::parse_number(component))

パラメータ$\mu$, $\sigma$, $\theta$それぞれ、クラスターごとの事後サンプリングを1000個ずつ格納したデータフレームができた。

生成したデータの分布と事後サンプリングされたパラメータから構成した分布を比較する。

# Take posterior means as representative parameters
mu_est <- mu_df |>
  dplyr::group_by(component) |>
  dplyr::summarise(mu = mean(mu))

sigma_est <- sigma_df |>
  dplyr::group_by(component) |>
  dplyr::summarise(sigma = mean(sigma))

theta_est <- theta_df |> 
  dplyr::group_by(component) |>
  dplyr::summarise(theta = mean(theta))

# Mixture density function
mix_density <- function(y, mu, sigma, theta) {
  rowSums(sapply(1:length(mu), function(k) theta[k] * dnorm(y, mu[k], sigma[k])))
}

y_grid <- seq(min(df$y), max(df$y), length.out = 200)
density_df <- tibble::tibble(
  y = y_grid,
  density = mix_density(y_grid, mu_est$mu, sigma_est$sigma, theta_est$theta)
)

ggplot2::ggplot(df, ggplot2::aes(x = y)) +
  ggplot2::geom_histogram(ggplot2::aes(y = ..density..), bins = 30, fill = "grey80", color = "white") +
  ggplot2::geom_line(data = density_df, ggplot2::aes(x = y, y = density), color = "red", size = 1.2) +
  ggplot2::labs(title = "Observed data and estimated mixture density",
       x = "y", y = "Density")

image.png

分布はうまく再現できているが、そもそも事後パラメータはうまく推定できているのだろうか。$\mu$, $\sigma$, $\theta$それぞれ可視化する。

ggplot2::ggplot(mu_df, ggplot2::aes(x = mu, fill = factor(component))) +
  ggplot2::geom_density(alpha = 0.5) +
  ggplot2::geom_vline(xintercept = mu, linetype = "dashed") +
  ggplot2::labs(title = "Posterior of mixture means (μ)",
       x = "mu", y = "Density", fill = "Component")


ggplot2::ggplot(theta_df, ggplot2::aes(x = theta, fill = factor(component))) +
  ggplot2::geom_density(alpha = 0.5) +
  ggplot2::geom_vline(xintercept = theta, linetype = "dashed") +
  ggplot2::labs(title = "Posterior of mixture weights (θ)",
       x = "Mixing proportion", y = "Density", fill = "Component")

image.png

$\mu$の事後平均と事前に設定した値の平均の比較。

image.png

$\theta$の事後平均と事前に設定した値の平均の比較。

以上より、サンプルの分布はうまく再現できているものの、事後パラメータが真のパラメータからずれている。つまり、識別に問題がある。

考察

一般に、有限混合モデルにおいてパラメータが識別できていない原因としては、以下が挙げられる。

  • ラベルスイッチング問題

混合分布の成分は入れ替えても同じ分布となる。例えば「成分1の平均が2、成分2の平均が5」と「成分1の平均が5、成分2の平均が2」は、全体の分布としては同一となる。

こちらについてはparametersブロックにおいてordered[K] mu;としているので、平均値が自動的に昇順に並び、ラベルスイッチング問題はほぼ防げているはず。

  • 識別性の弱さ

混合モデルは「複数のパラメータ組み合わせで同じ分布」を表せてしまう。よって、データ数が少ない場合や分布が近い場合、異なるパラメータの組み合わせがほぼ同じ尤度を持つことがあり、これが事後分布の「平坦さ」となって推定を不安定する。

今回は標準正規分布に従って各クラスターの成分をサンプリングしたこともあり、各クラスターの平均が似た値となっているところが問題だと考えられる。

混合モデルの推定目的が「分布の形をうまく再現する」ことであればまだしも、「真のパラメータを推定する」ことであるならばディリクレ過程など、成分数を事前に固定せずにベイズ非パラメトリックで推定することでクラスターの任意性に影響されにくいモデルを使うことも考えられる。

参考文献

2
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?