3
4

More than 5 years have passed since last update.

変分ベイズ法による混合正規分布の推定を理解したい

Last updated at Posted at 2018-04-21

はじめに

変分ベイズ法の考え方

  • 変分ベイズ法は、パラメータの事後確率分布$p(v,w|X)$を確率分布の積$q(v)q(w)$で近似する手法。
  • 近似は、KL情報量を最小化する分布とする。
  • 対数周辺尤度$\log p(X)$は変分下限$L(q)$とKL情報量の和に分解される。
\begin{align}
\log p(X)&=L(q)+KL(q\parallel p)\\
L(q)&=\iint q(v)q(w)\log\frac{p(X,v,w)}{q(v)q(w)}\mathrm{d}v\mathrm{d}w\\
KL(q\parallel p)&=\iint q(v)q(w)\log\frac{q(v)q(w)}{p(v,w|X)}\mathrm{d}v\mathrm{d}w
\end{align}
  • KL情報量を直接計算することは困難なので、代わりに変分下限を最大化する。
  • 条件付き確率分布$p(v,w|X)$の計算は困難だが、同時確率分布$p(X,v,w)$の計算は可能。
  • 変分法により$q(v)$で最大化すると、$\log q(v)\sim\int q(w)\log p(X,v,w)\mathrm{d}w$となる。
  • 変分法により$q(w)$で最大化すると、$\log q(w)\sim\int q(v)\log p(X,v,w)\mathrm{d}v$となる。
  • $q(v)\rightarrow q(w)\rightarrow q(v)\rightarrow\cdots$と計算を繰り返すと、変分下限は増加していき、KL情報量は減少していく。

変分法による変分下限の最大化条件の導出

  • $q(v)$による変分下限の変分を計算する。変分が計算しやすいように変形する。
\begin{align}
L(q)&=\iint q(v)q(w)\log\frac{p(X,v,w)}{q(v)q(w)}\mathrm{d}w\mathrm{d}v\\
&=\int q(v)\left[f(v)-\log q(v)\right]\mathrm{d}v-\int q(w)\log q(w)\mathrm{d}w\\
f(v)&=\int q(w)\log p(X,v,w)\mathrm{d}w
\end{align}
  • $q(v)$の変分を$\delta q(v)$とする。一般に関数$F$の変分は$\frac{\partial F}{\partial q(v)}\delta q(v)$となる。
  • $q(w)$のみに依存する項は変化しないので除外。ラグランジュの未定乗数法を適用。
\begin{align}
L^*(q)&=\int q(v)\left[f(v)-\log q(v)\right]\mathrm{d}v
+\lambda\left[\int q(v)\mathrm{d}v-1\right]\\
\delta L^*(q)&=\int\left[f(v)-\log q(v)-1\right]\delta q(v)\mathrm{d}v
+\lambda\int \delta q(v)\mathrm{d}v\\
&=\int\left[f(v)-\log q(v)-1+\lambda \right]\delta q(v)\mathrm{d}v
\end{align}
  • 変分下限を最大化する分布$q(v)$の条件は、任意の$\delta q(v)$について変分がゼロとなることである。
  • 従って括弧の中がゼロとなる。$\exp(-1+\lambda)$は$q(v)$の正規化定数に含まれる。
\begin{align}
\log q(v)&=f(v)-1+\lambda\\
\log q(v)&\sim\int q(w)\log p(X,v,w)\mathrm{d}w
\end{align}
  • 求めたいパラメータ以外のパラメータについて対数同時確率分布$\log p(X,\cdots)$の平均値を計算すると、変分下限を最大化するパラメータの対数確率分布が得られる。

問題設定

  • 全てのパラメータは分布を持つ。
  • データ$X[x_n]$について混合正規分布を仮定する。
  • 計算が簡単になるように隠れ変数$Z[z_{nk}]$を導入する。事前確率分布として共役事前分布を仮定する。
  • 事後確率分布$p(X,Z,\pi,\mu,\Lambda)$は以下のようになる。
  • $\alpha_0,\beta_0,m_0,\nu_0,W_0$はハイパーパラメータ。$D=\mathrm{dim}(x)$
\begin{align}
p(x_n|\pi,\mu,\Lambda)&=\sum_{k=1}^K \pi_k N(x_n|\mu_k,\Lambda_k^{-1})
\quad(-\infty<x<\infty)\\
p(x_n,z_n|\pi,\mu,\Sigma)&=\prod_{k=1}^K (\pi_k N(x_n|\mu_k,\Lambda_k^{-1}))^{z_{nk}}\\
p(X,Z|\pi,\mu,\Lambda)&=\prod_{n=1}^N \prod_{k=1}^K (\pi_k N(x_n|\mu_k,\Lambda_k^{-1}))^{z_{nk}}\\
p(\pi)&=C(\alpha_0)\prod_{k=1}^K\pi_k^{\alpha_0-1}\quad ディリクレ分布\\
p(\Lambda)&=\prod_{k=1}^KB(W_0,\nu_0)|\Lambda_k|^{\frac{\nu_0-d-1}{2}}\exp\left(-\frac{1}{2}\mathrm{tr}(W_0^{-1}\Lambda_k)\right)\quad ウィッシャート分布\\
p(\mu|\Lambda)&=\prod_{k=1}^K N(\mu_k|m_0,(\beta_0\Lambda_k)^{-1})\quad 多次元正規分布\\
p(X,Z,\pi,\mu,\Lambda)&=p(X,Z|\pi,\mu,\Lambda)p(\pi)p(\mu|\Lambda)p(\Lambda)\\
\end{align}
  • パラメータの事後確率分布$p(Z,\pi,\mu,\Lambda|X)$を$q(Z)q(\pi)q(\mu,\Lambda)$で近似する。

具体的な計算手順

結果のみ記す。導出についは参考文献を参照。

  • $\alpha_k,\beta_k,m_k,\nu_k,W_k \rightarrow r_{nk}$
\begin{align}
r_{nk}&=\frac{\rho_{nk}}{\sum_{k=1}^K\rho_{nk}}\\
\log\rho_{nk}=\textstyle{}E[\log\pi_k]+\frac{1}{2}E[\log|\Lambda_k|]&-\frac{D}{2}\log(2\pi)
-\frac{1}{2}E[(x_n-\mu_k)^t\Lambda_k(x_n-\mu_k)]\\
E[\log\pi_k]&=\textstyle{}\psi(\alpha_k)-\psi(\sum\alpha_k)\\
E[\log|\Lambda_k|]&=\textstyle{}\sum_{i=1}^D\psi(\frac{\nu_k+1-i}{2})+D\log 2+\log|W_k|\\
E[(x_n-\mu_k)^t\Lambda_k(x_n-\mu_k)]&=D\beta_k^{-1}+\nu_k(x_n-m_k)^tW_k(x_n-m_k)
\end{align}
  • $q(Z)$の分布
\begin{align}
\log q(Z)&\sim E_{\pi\mu\Lambda}[\log p(X,Z,\pi,\mu,\Lambda)]\\
q(Z)&=\prod_{n=1}^N\prod_{k=1}^K r_{nk}^{z_{nk}}
\end{align}
  • $r_{nk} \rightarrow \alpha_k,\beta_k,m_k,\nu_k,W_k$
\begin{align}
N_k=\sum_{n=1}^Nr_{nk},\quad\bar{x}_k&=\frac{1}{N_k}\sum_{n=1}^Nr_{nk}x_n,\quad
\bar{\Sigma}_k=\frac{1}{N_k}\sum_{n=1}^N r_{nk}(x_n-\bar{x}_k)(x_n-\bar{x}_k)^t\\
\alpha_k=\alpha_0+N_k,\quad\beta_k&=\beta_0+N_k,\quad
m_k=(\beta_0m_0+N_k\bar{x}_k)\beta_k^{-1}\\
\nu_k=\nu_0+N_k,\quad
W_k^{-1}&=W_0^{-1}+N_k\bar{\Sigma}_k+\frac{\beta_0 N_k}{\beta_0+N_k}(\bar{x}_k-m_0)(\bar{x}_k-m_0)^t
\end{align}
  • $q(\pi)q(\mu,\Lambda)$の分布
\begin{align}
\log q(\pi)&\sim E_{Z\mu\Lambda}[\log p(X,Z,\pi,\mu,\Lambda)]\\
\log q(\mu,\Lambda)&\sim E_{Z\pi}[\log p(X,Z,\pi,\mu,\Lambda)]\\
q(\pi)&=C(\alpha)\prod_{k=1}^K \pi_{k}^{\alpha_k-1}\quad ディリクレ分布\\
q(\mu,\Lambda)&=q(\mu|\Lambda)q(\Lambda)\\
q(\Lambda)&=\prod_{k=1}^K B(W_k,\nu_k)|\Lambda_k|^{\frac{\nu_k-D-1}{2}}\exp\left(-\frac{1}{2}\mathrm{tr}(W_k^{-1}\Lambda_k)\right)\quad ウィッシャート分布\\
q(\mu|\Lambda)&=\prod_{k=1}^K N(\mu_k|m_k,(\beta_k\Lambda_k)^{-1})\quad 多次元正規分布
\end{align}
  • 変分下限の計算式
\begin{align}
L(q)&=\int q(Z)q(\pi)q(\mu,\Lambda)\log\frac{p(X,Z,\pi,\mu,\Lambda)}
{q(Z)q(\pi)q(\mu,\Lambda)}
\mathrm{d}Z\mathrm{d}\pi\mathrm{d}\mu\mathrm{d}\Lambda\\
&=\log \frac{C(\alpha_0)}{C(\alpha)}
+\frac{D}{2}\sum_{k=1}^K\log\frac{\beta_0}{\beta_k}
+\sum_{k=1}^K\log \frac{B(W_0,\nu_0)}{B(W_k,\nu_k)}\\
&\quad-\sum_{n=1}^N\sum_{k=1}^Kr_{nk}\log r_{nk}-N\frac{D}{2}\log(2\pi)
\end{align}

Rプログラムと実行結果

# Multivariate Normal Distribution
my_mvrnorm = function(n, mm, sig){
  ar = t(chol(sig))
  m = length(mm)
  y = ar %*% matrix(rnorm(n*m),nrow=m,ncol=n) + mm
  t(y)
}
my_mvdnorm = function(xn, mm, sig){
  sig_ = solve(sig)
  xm = t(xn) - mm
  ss = colSums(xm * (sig_ %*% xm))
  f1 = log(det(sig)) + length(mm)*log(2*pi)
  f2 = (f1 + ss) * 0.5
  exp(-f2)
}

# Normalizing constant of Dirichlet distribution
my_log_C = function(a){
  lgamma(sum(a))-sum(lgamma(a))
}
# Normalizing constant of Wishart distribution
my_log_B = function(W,nu){
  d = nrow(W)
  -nu/2*log(det(W)) - nu*d/2*log(2) - d*(d-1)/4*log(pi) - sum(lgamma((nu+1-1:d)/2))
}

my_plot_ellipse = function(mm, sig){
  ar = t(chol(sig))
  theta = seq(-pi, pi, length=100)
  y = matrix(c(cos(theta),sin(theta)),nrow=2,ncol=100,byrow=T)
  x = ar %*% y + mm
  polygon(x[1,], x[2,])
}

# Data generation
N = 100 # number of points
D = 2  # dimension
K = 3  # number of cluster
set.seed(10)

rt = (function(){
  kr = c(1,2,1); kr = kr/sum(kr)  # allocation ratio
  kd = c(0.3,0.5,1)/2  # diagonal term of variance
  mu = matrix(c(0,0,0,2,2,1),nrow=K,ncol=D,byrow=T) # mean
  sigma = lapply(1:K,function(k){ # variance-covariance matrix
    diag(kd[k],nrow=D,ncol=D)
  })
  nk = as.integer(kr*N+0.5)
  nk[1] = N - sum(nk[-1])
  xn = NULL # data points
  gn = NULL # cluster id
  for(k in 1:K){
    xn = rbind(xn,my_mvrnorm(nk[k],mu[k,],sigma[[k]]))
    gn = c(gn,rep(k,nk[k]))
  }
  list(xn=xn,gn=gn,mu=mu,sigma=sigma)
})()
xn = rt$xn

cols = rainbow(K)
plot(rt$xn,xlab="x",ylab="y",type="p",pch=20,col=cols[rt$gn],asp=1,cex.axis=1.5,cex.lab=1.5)
for(k in 1:K){
  my_plot_ellipse(rt$mu[k,], rt$sigma[[k]])
}

# Variational Bayesian methods

N = nrow(xn)
D = ncol(xn)  # dimension
K = 3  # number of cluster
cols = rainbow(K)

# Hyperparameter
a0 = 1
b0 = 1
m0 = rep(0,D)
nu0 = D
W0 = diag(D)
W0_inv = solve(W0)

# Initial parameter
ak = rep(a0,K) + N / K
bk = rep(b0,K) + N / K
nuk = rep(nu0,K) + N / K
set.seed(10)
mk = matrix(runif(K*D,min=0,max=2),nrow=K,ncol=D)
Wk = rep(list(W0),K)

log_lik = numeric()
for(iter in 1:20){
  # E step
  E_log_pi = digamma(ak) - digamma(sum(ak))
  E_log_lambda = sapply(1:K,function(k){
    sum(digamma((nuk[k]+1-(1:D))/2))+D*log(2)+log(det(Wk[[k]]))
  })
  log_rho = sapply(1:K,function(k){
    xnm = t(xn) - mk[k,]
    E_xn_lambda = D/bk[k] + nuk[k] * colSums(xnm * (Wk[[k]] %*% xnm))
    E_log_pi[k]+0.5*E_log_lambda[k]-0.5*D*log(2*pi)-0.5*E_xn_lambda
  })
  rho = t(apply(log_rho,1,function(a){exp(a-max(a))}))
  rnk = rho / rowSums(rho)

  # M step
  nk = colSums(rnk)
  xbk = (t(rnk) %*% xn) / nk
  sgk = lapply(1:K,function(k){
    s1 = t(t(xn) - xbk[k,])
    s2 = s1 * rnk[,k]
    (t(s2) %*% s1) / nk[k]
  })
  ak = nk + a0
  bk = nk + b0
  mk = t(sapply(1:K,function(k){(b0 * m0 + nk[k] * xbk[k,])/bk[k]}))
  nuk = nk + nu0
  Wk = lapply(1:K,function(k){
    xm = xbk[k,] - m0
    solve(W0_inv + nk[k] * sgk[[k]] + (xm %*% t(xm)) * b0 * nk[k] / bk[k])
  })

  log_BW = sapply(1:K,function(k){my_log_B(W0,nu0) - my_log_B(Wk[[k]],nuk[k])})
  L = sum(log_BW)
  L = L + my_log_C(rep(a0,K)) - my_log_C(ak)
  L = L + D/2*sum(log(b0/bk))
  L = L - N*D/2*log(2*pi) - sum(rnk*log(rnk))
  log_lik[iter] = L

  msg = sprintf("iter=%d, log_lik=%.3f",iter,L)
  gnk = apply(rho,1,which.max)
  plot(xn,main=msg,xlab="x",ylab="y",type="p",pch=20,col=cols[gnk],asp=1,cex.axis=1.5,cex.lab=1.5)
  for(k in unique(gnk)){
    sig = solve(bk[k] * Wk[[k]])
    my_plot_ellipse(mk[k,], sig)
  }
  Sys.sleep(1)
}
data.frame(log_lik,diff=diff(c(NA,log_lik)))

vb.png

     log_lik         diff
1  -300.9549           NA
2  -293.7659 7.189060e+00
3  -292.0074 1.758448e+00
4  -291.0707 9.367485e-01
5  -290.4214 6.493015e-01
6  -289.7014 7.200212e-01
7  -288.6599 1.041495e+00
8  -286.8410 1.818830e+00
9  -283.4597 3.381360e+00
10 -280.4321 3.027625e+00
11 -279.6208 8.112251e-01
12 -279.5314 8.942274e-02
13 -279.5247 6.709431e-03
14 -279.5242 5.365955e-04
15 -279.5241 5.353788e-05
16 -279.5241 7.483725e-06
17 -279.5241 1.595826e-06
18 -279.5241 4.672365e-07
19 -279.5241 1.581595e-07
20 -279.5241 5.630500e-08
3
4
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
3
4