ベータ二項分布の解析的性質を用いて Stan を高速化する

  • 6
    Like
  • 0
    Comment
More than 1 year has passed since last update.

確率変数 $K$ がベータ二項分布に従うとする。
ベータ二項分布というのは、基本的に成功確率 $p \in [0, 1]$ の二項分布であるが、この $p$ が一定ではなく、ベータ分布に従うという確率分布である。

すなわち、

\begin{align}
p &\sim \textrm{Beta}(\alpha, \beta) \\
k &\sim \textrm{Binomial}(n, p)
\end{align}

である。

このとき、ベータ分布のパラメータ $\alpha, \beta$ を推定したい。

1. 単純な方法

これを単純に推定するなら、Stan コードは次のようになる。

beta_binom_model1.stan
data {
  int<lower=0> N;
  int<lower=0> k[N];
  int<lower=1> n[N];
}
parameters {
  real<lower=0, upper=1> p[N];
  real<lower=0> alpha;
  real<lower=0> beta;
}
model {
  p ~ beta(alpha, beta);
  k ~ binomial(n, p);
}

擬似データを作ってパラメータ推定してみよう。

generate_data <- function(size, alpha=2, beta=5, lambda=3) {
  p <- rbeta(size, alpha, beta)
  n <- rpois(size, lambda) + 1
  k <- rbinom(size, n, p)
  data.frame(k, n)
}

N <- 1000
data <- generate_data(N, alpha = 2, beta = 5)

library(rstan)
model1 <- stan_model("beta_binom_model1.stan")
fit1 <- sampling(model1, data = list(N = N, k = data$k, n = data$n), 
                pars = c("alpha", "beta"), chains = 3)
print(fit1)
          mean se_mean    sd     2.5%      25%      50%      75%    97.5% n_eff Rhat
alpha     1.90    0.02  0.27     1.44     1.71     1.88     2.06     2.53   163 1.01
beta      4.74    0.05  0.69     3.58     4.25     4.67     5.16     6.32   170 1.01
lp__  -3487.24    5.27 62.88 -3609.53 -3530.53 -3486.78 -3446.07 -3359.79   142 1.01

うまく推定できたようだ。

このとき、私のパソコンでは 1 チェインの計算に 1 分以上かかった。

#  Elapsed Time: 61.599 seconds (Total)
#  Elapsed Time: 63.545 seconds (Total)
#  Elapsed Time: 64.889 seconds (Total)

2. 解析的性質に基づく Stan の高速化

ところで、ベータ二項分布の尤度は $p$ を積分消去して解析的に求めることができる。

二項分布の確率質量関数は

p(k | n, p) = {}_nC_k p^k (1-p)^{n-k}

であり、ベータ分布の確率密度関数は、$B(\alpha, \beta)$ をベータ関数として、

p(p | \alpha, \beta) = \frac{p^{\alpha-1}(1-p)^{\beta-1}}{B(\alpha, \beta)}

であるため、ベータ二項分布の確率質量関数は、

\begin{align}
p(k | n, \alpha, \beta) &= \int p(k|n,p)p(p|\alpha,\beta) dp \\
&= \int {}_nC_k p^k (1-p)^{n-k} \frac{p^{\alpha-1}(1-p)^{\beta-1}}{B(\alpha, \beta)} dp \\
&= {}_nC_k \int \frac{p^{k+\alpha-1}(1-p)^{n-k+\beta-1}}{B(\alpha, \beta)} dp \\
&= {}_nC_k \frac{B(k+\alpha, n-k+\beta)}{B(\alpha, \beta)} \int \frac{p^{k+\alpha-1}(1-p)^{n-k+\beta-1}}{B(k+\alpha, n-k+\beta)} dp \\
&= {}_nC_k \frac{B(k+\alpha, n-k+\beta)}{B(\alpha, \beta)}
\end{align}

となる。

したがって、ベータ二項分布の対数尤度は

\textrm{LogLik} = \sum_{(k,n)}  \bigl( \log({}_nC_k) +  \log(B(k+\alpha, n-k+\beta)) - \log(B(\alpha, \beta)) \bigr)

と求まる。

これにより、Stan コードは次のように書ける。

beta_binom_model2.stan
functions {
  // Returns the natural logarithm of the number of k-combinations from n.
  real lcombination(int n, int k) {
    return -lbeta(k+1., n-k+1.) - log(n+1.);
  }
}
data {
  int<lower=0> N;
  int<lower=0> k[N];
  int<lower=1> n[N];
}
parameters {
  real<lower=0> alpha;
  real<lower=0> beta;
}
model {
  real ldenom;
  ldenom <- lbeta(alpha, beta);
  for(i in 1:N)
    increment_log_prob(
      lcombination(n[i], k[i]) + 
      lbeta(k[i] + alpha, n[i] - k[i] + beta) -
      ldenom
    );
}

ここで、Stan には組み合わせ ${}_nC_k$ を計算する関数が無いようなので、

\begin{align}
 {}_nC_k &= \frac{1}{(n+1)\textrm{Beta}(k+1, n-k+1)} \\
 \log({}_nC_k) &= -\log(\textrm{Beta}(k+1, n-k+1)) - \log(n+1)
\end{align}

を用いてユーザ定義関数 lcombination(n, k) を作成した。

同様にこれを実行してみる。

model2 <- stan_model("beta_binom_model2.stan")
fit2 <- sampling(model2, data = list(N = N, k = data$k, n = data$n), 
                 pars = c("alpha", "beta"), chains = 3)
#  Elapsed Time: 37.055 seconds (Total)
#  Elapsed Time: 32.951 seconds (Total)
#  Elapsed Time: 36.859 seconds (Total)

1 チェインが 35 秒程度になった。
ベータ二項分布の解析的な性質を利用することで、約 2 倍の高速化を実現できた。

こうした、確率分布の解析的な性質をもとに Stan を高速化するための知識については、下記ブログ記事がまとまっている。

3. 組み込み関数

よく調べてみると、Stan にはベータ二項分布に従う確率変数を扱うための組み込み関数 beta_binomial() が存在する。

これを用いて Stan コードを書き換えると次のようになる。

beta_binom_model3.stan
data {
  int<lower=0> N;
  int<lower=0> k[N];
  int<lower=1> n[N];
}
parameters {
  real<lower=0> alpha;
  real<lower=0> beta;
}
model {
  k ~ beta_binomial(n, alpha, beta);
}

これも同様に実行してみる。

model3 <- stan_model("beta_binom_model3.stan")
fit3 <- sampling(model3, data = list(N = N, k = data$k, n = data$n), 
                 pars = c("alpha", "beta"), chains = 3)
#  Elapsed Time: 20.437 seconds (Total)
#  Elapsed Time: 21.549 seconds (Total)
#  Elapsed Time: 22.208 seconds (Total)

1 チェイン 20 秒程度になった。

4. まとめ

確率分布の解析的性質をもとに Stan の高速化が可能である。
また、組み込み関数がある場合はそちらを使った方が速いのでやる前にちゃんと調べましょう(自戒)。