#はじめに
- 変分ベイズ法の考え方のメモです。
- 参考文献 変分法をごまかさずに変分ベイズの説明をする
- 参考文献 変分推論法(変分ベイズ法)(PRML第10章)
- 参考文献 パターン認識と機械学習の学習 普及版
#変分ベイズ法の考え方
- 変分ベイズ法は、パラメータの事後確率分布$p(v,w|X)$を確率分布の積$q(v)q(w)$で近似する手法。
- 近似は、KL情報量を最小化する分布とする。
- 対数周辺尤度$\log p(X)$は変分下限$L(q)$とKL情報量の和に分解される。
\begin{align}
\log p(X)&=L(q)+KL(q\parallel p)\\
L(q)&=\iint q(v)q(w)\log\frac{p(X,v,w)}{q(v)q(w)}\mathrm{d}v\mathrm{d}w\\
KL(q\parallel p)&=\iint q(v)q(w)\log\frac{q(v)q(w)}{p(v,w|X)}\mathrm{d}v\mathrm{d}w
\end{align}
- KL情報量を直接計算することは困難なので、代わりに変分下限を最大化する。
- 条件付き確率分布$p(v,w|X)$の計算は困難だが、同時確率分布$p(X,v,w)$の計算は可能。
- 変分法により$q(v)$で最大化すると、$\log q(v)\sim\int q(w)\log p(X,v,w)\mathrm{d}w$となる。
- 変分法により$q(w)$で最大化すると、$\log q(w)\sim\int q(v)\log p(X,v,w)\mathrm{d}v$となる。
- $q(v)\rightarrow q(w)\rightarrow q(v)\rightarrow\cdots$と計算を繰り返すと、変分下限は増加していき、KL情報量は減少していく。
#変分法による変分下限の最大化条件の導出
- $q(v)$による変分下限の変分を計算する。変分が計算しやすいように変形する。
\begin{align}
L(q)&=\iint q(v)q(w)\log\frac{p(X,v,w)}{q(v)q(w)}\mathrm{d}w\mathrm{d}v\\
&=\int q(v)\left[f(v)-\log q(v)\right]\mathrm{d}v-\int q(w)\log q(w)\mathrm{d}w\\
f(v)&=\int q(w)\log p(X,v,w)\mathrm{d}w
\end{align}
- $q(v)$の変分を$\delta q(v)$とする。一般に関数$F$の変分は$\frac{\partial F}{\partial q(v)}\delta q(v)$となる。
- $q(w)$のみに依存する項は変化しないので除外。ラグランジュの未定乗数法を適用。
\begin{align}
L^*(q)&=\int q(v)\left[f(v)-\log q(v)\right]\mathrm{d}v
+\lambda\left[\int q(v)\mathrm{d}v-1\right]\\
\delta L^*(q)&=\int\left[f(v)-\log q(v)-1\right]\delta q(v)\mathrm{d}v
+\lambda\int \delta q(v)\mathrm{d}v\\
&=\int\left[f(v)-\log q(v)-1+\lambda \right]\delta q(v)\mathrm{d}v
\end{align}
- 変分下限を最大化する分布$q(v)$の条件は、任意の$\delta q(v)$について変分がゼロとなることである。
- 従って括弧の中がゼロとなる。$\exp(-1+\lambda)$は$q(v)$の正規化定数に含まれる。
\begin{align}
\log q(v)&=f(v)-1+\lambda\\
\log q(v)&\sim\int q(w)\log p(X,v,w)\mathrm{d}w
\end{align}
- 求めたいパラメータ以外のパラメータについて対数同時確率分布$\log p(X,\cdots)$の平均値を計算すると、変分下限を最大化するパラメータの対数確率分布が得られる。
#問題設定
- 全てのパラメータは分布を持つ。
- データ$X[x_n]$について混合正規分布を仮定する。
- 計算が簡単になるように隠れ変数$Z[z_{nk}]$を導入する。事前確率分布として共役事前分布を仮定する。
- 事後確率分布$p(X,Z,\pi,\mu,\Lambda)$は以下のようになる。
- $\alpha_0,\beta_0,m_0,\nu_0,W_0$はハイパーパラメータ。$D=\mathrm{dim}(x)$
\begin{align}
p(x_n|\pi,\mu,\Lambda)&=\sum_{k=1}^K \pi_k N(x_n|\mu_k,\Lambda_k^{-1})
\quad(-\infty<x<\infty)\\
p(x_n,z_n|\pi,\mu,\Sigma)&=\prod_{k=1}^K (\pi_k N(x_n|\mu_k,\Lambda_k^{-1}))^{z_{nk}}\\
p(X,Z|\pi,\mu,\Lambda)&=\prod_{n=1}^N \prod_{k=1}^K (\pi_k N(x_n|\mu_k,\Lambda_k^{-1}))^{z_{nk}}\\
p(\pi)&=C(\alpha_0)\prod_{k=1}^K\pi_k^{\alpha_0-1}\quad ディリクレ分布\\
p(\Lambda)&=\prod_{k=1}^KB(W_0,\nu_0)|\Lambda_k|^{\frac{\nu_0-d-1}{2}}\exp\left(-\frac{1}{2}\mathrm{tr}(W_0^{-1}\Lambda_k)\right)\quad ウィッシャート分布\\
p(\mu|\Lambda)&=\prod_{k=1}^K N(\mu_k|m_0,(\beta_0\Lambda_k)^{-1})\quad 多次元正規分布\\
p(X,Z,\pi,\mu,\Lambda)&=p(X,Z|\pi,\mu,\Lambda)p(\pi)p(\mu|\Lambda)p(\Lambda)\\
\end{align}
- パラメータの事後確率分布$p(Z,\pi,\mu,\Lambda|X)$を$q(Z)q(\pi)q(\mu,\Lambda)$で近似する。
#具体的な計算手順
結果のみ記す。導出についは参考文献を参照。
- $\alpha_k,\beta_k,m_k,\nu_k,W_k \rightarrow r_{nk}$
\begin{align}
r_{nk}&=\frac{\rho_{nk}}{\sum_{k=1}^K\rho_{nk}}\\
\log\rho_{nk}=\textstyle{}E[\log\pi_k]+\frac{1}{2}E[\log|\Lambda_k|]&-\frac{D}{2}\log(2\pi)
-\frac{1}{2}E[(x_n-\mu_k)^t\Lambda_k(x_n-\mu_k)]\\
E[\log\pi_k]&=\textstyle{}\psi(\alpha_k)-\psi(\sum\alpha_k)\\
E[\log|\Lambda_k|]&=\textstyle{}\sum_{i=1}^D\psi(\frac{\nu_k+1-i}{2})+D\log 2+\log|W_k|\\
E[(x_n-\mu_k)^t\Lambda_k(x_n-\mu_k)]&=D\beta_k^{-1}+\nu_k(x_n-m_k)^tW_k(x_n-m_k)
\end{align}
- $q(Z)$の分布
\begin{align}
\log q(Z)&\sim E_{\pi\mu\Lambda}[\log p(X,Z,\pi,\mu,\Lambda)]\\
q(Z)&=\prod_{n=1}^N\prod_{k=1}^K r_{nk}^{z_{nk}}
\end{align}
- $r_{nk} \rightarrow \alpha_k,\beta_k,m_k,\nu_k,W_k$
\begin{align}
N_k=\sum_{n=1}^Nr_{nk},\quad\bar{x}_k&=\frac{1}{N_k}\sum_{n=1}^Nr_{nk}x_n,\quad
\bar{\Sigma}_k=\frac{1}{N_k}\sum_{n=1}^N r_{nk}(x_n-\bar{x}_k)(x_n-\bar{x}_k)^t\\
\alpha_k=\alpha_0+N_k,\quad\beta_k&=\beta_0+N_k,\quad
m_k=(\beta_0m_0+N_k\bar{x}_k)\beta_k^{-1}\\
\nu_k=\nu_0+N_k,\quad
W_k^{-1}&=W_0^{-1}+N_k\bar{\Sigma}_k+\frac{\beta_0 N_k}{\beta_0+N_k}(\bar{x}_k-m_0)(\bar{x}_k-m_0)^t
\end{align}
- $q(\pi)q(\mu,\Lambda)$の分布
\begin{align}
\log q(\pi)&\sim E_{Z\mu\Lambda}[\log p(X,Z,\pi,\mu,\Lambda)]\\
\log q(\mu,\Lambda)&\sim E_{Z\pi}[\log p(X,Z,\pi,\mu,\Lambda)]\\
q(\pi)&=C(\alpha)\prod_{k=1}^K \pi_{k}^{\alpha_k-1}\quad ディリクレ分布\\
q(\mu,\Lambda)&=q(\mu|\Lambda)q(\Lambda)\\
q(\Lambda)&=\prod_{k=1}^K B(W_k,\nu_k)|\Lambda_k|^{\frac{\nu_k-D-1}{2}}\exp\left(-\frac{1}{2}\mathrm{tr}(W_k^{-1}\Lambda_k)\right)\quad ウィッシャート分布\\
q(\mu|\Lambda)&=\prod_{k=1}^K N(\mu_k|m_k,(\beta_k\Lambda_k)^{-1})\quad 多次元正規分布
\end{align}
- 変分下限の計算式
\begin{align}
L(q)&=\int q(Z)q(\pi)q(\mu,\Lambda)\log\frac{p(X,Z,\pi,\mu,\Lambda)}
{q(Z)q(\pi)q(\mu,\Lambda)}
\mathrm{d}Z\mathrm{d}\pi\mathrm{d}\mu\mathrm{d}\Lambda\\
&=\log \frac{C(\alpha_0)}{C(\alpha)}
+\frac{D}{2}\sum_{k=1}^K\log\frac{\beta_0}{\beta_k}
+\sum_{k=1}^K\log \frac{B(W_0,\nu_0)}{B(W_k,\nu_k)}\\
&\quad-\sum_{n=1}^N\sum_{k=1}^Kr_{nk}\log r_{nk}-N\frac{D}{2}\log(2\pi)
\end{align}
#Rプログラムと実行結果
# Multivariate Normal Distribution
my_mvrnorm = function(n, mm, sig){
ar = t(chol(sig))
m = length(mm)
y = ar %*% matrix(rnorm(n*m),nrow=m,ncol=n) + mm
t(y)
}
my_mvdnorm = function(xn, mm, sig){
sig_ = solve(sig)
xm = t(xn) - mm
ss = colSums(xm * (sig_ %*% xm))
f1 = log(det(sig)) + length(mm)*log(2*pi)
f2 = (f1 + ss) * 0.5
exp(-f2)
}
# Normalizing constant of Dirichlet distribution
my_log_C = function(a){
lgamma(sum(a))-sum(lgamma(a))
}
# Normalizing constant of Wishart distribution
my_log_B = function(W,nu){
d = nrow(W)
-nu/2*log(det(W)) - nu*d/2*log(2) - d*(d-1)/4*log(pi) - sum(lgamma((nu+1-1:d)/2))
}
my_plot_ellipse = function(mm, sig){
ar = t(chol(sig))
theta = seq(-pi, pi, length=100)
y = matrix(c(cos(theta),sin(theta)),nrow=2,ncol=100,byrow=T)
x = ar %*% y + mm
polygon(x[1,], x[2,])
}
# Data generation
N = 100 # number of points
D = 2 # dimension
K = 3 # number of cluster
set.seed(10)
rt = (function(){
kr = c(1,2,1); kr = kr/sum(kr) # allocation ratio
kd = c(0.3,0.5,1)/2 # diagonal term of variance
mu = matrix(c(0,0,0,2,2,1),nrow=K,ncol=D,byrow=T) # mean
sigma = lapply(1:K,function(k){ # variance-covariance matrix
diag(kd[k],nrow=D,ncol=D)
})
nk = as.integer(kr*N+0.5)
nk[1] = N - sum(nk[-1])
xn = NULL # data points
gn = NULL # cluster id
for(k in 1:K){
xn = rbind(xn,my_mvrnorm(nk[k],mu[k,],sigma[[k]]))
gn = c(gn,rep(k,nk[k]))
}
list(xn=xn,gn=gn,mu=mu,sigma=sigma)
})()
xn = rt$xn
cols = rainbow(K)
plot(rt$xn,xlab="x",ylab="y",type="p",pch=20,col=cols[rt$gn],asp=1,cex.axis=1.5,cex.lab=1.5)
for(k in 1:K){
my_plot_ellipse(rt$mu[k,], rt$sigma[[k]])
}
# Variational Bayesian methods
N = nrow(xn)
D = ncol(xn) # dimension
K = 3 # number of cluster
cols = rainbow(K)
# Hyperparameter
a0 = 1
b0 = 1
m0 = rep(0,D)
nu0 = D
W0 = diag(D)
W0_inv = solve(W0)
# Initial parameter
ak = rep(a0,K) + N / K
bk = rep(b0,K) + N / K
nuk = rep(nu0,K) + N / K
set.seed(10)
mk = matrix(runif(K*D,min=0,max=2),nrow=K,ncol=D)
Wk = rep(list(W0),K)
log_lik = numeric()
for(iter in 1:20){
# E step
E_log_pi = digamma(ak) - digamma(sum(ak))
E_log_lambda = sapply(1:K,function(k){
sum(digamma((nuk[k]+1-(1:D))/2))+D*log(2)+log(det(Wk[[k]]))
})
log_rho = sapply(1:K,function(k){
xnm = t(xn) - mk[k,]
E_xn_lambda = D/bk[k] + nuk[k] * colSums(xnm * (Wk[[k]] %*% xnm))
E_log_pi[k]+0.5*E_log_lambda[k]-0.5*D*log(2*pi)-0.5*E_xn_lambda
})
rho = t(apply(log_rho,1,function(a){exp(a-max(a))}))
rnk = rho / rowSums(rho)
# M step
nk = colSums(rnk)
xbk = (t(rnk) %*% xn) / nk
sgk = lapply(1:K,function(k){
s1 = t(t(xn) - xbk[k,])
s2 = s1 * rnk[,k]
(t(s2) %*% s1) / nk[k]
})
ak = nk + a0
bk = nk + b0
mk = t(sapply(1:K,function(k){(b0 * m0 + nk[k] * xbk[k,])/bk[k]}))
nuk = nk + nu0
Wk = lapply(1:K,function(k){
xm = xbk[k,] - m0
solve(W0_inv + nk[k] * sgk[[k]] + (xm %*% t(xm)) * b0 * nk[k] / bk[k])
})
log_BW = sapply(1:K,function(k){my_log_B(W0,nu0) - my_log_B(Wk[[k]],nuk[k])})
L = sum(log_BW)
L = L + my_log_C(rep(a0,K)) - my_log_C(ak)
L = L + D/2*sum(log(b0/bk))
L = L - N*D/2*log(2*pi) - sum(rnk*log(rnk))
log_lik[iter] = L
msg = sprintf("iter=%d, log_lik=%.3f",iter,L)
gnk = apply(rho,1,which.max)
plot(xn,main=msg,xlab="x",ylab="y",type="p",pch=20,col=cols[gnk],asp=1,cex.axis=1.5,cex.lab=1.5)
for(k in unique(gnk)){
sig = solve(bk[k] * Wk[[k]])
my_plot_ellipse(mk[k,], sig)
}
Sys.sleep(1)
}
data.frame(log_lik,diff=diff(c(NA,log_lik)))
log_lik diff
1 -300.9549 NA
2 -293.7659 7.189060e+00
3 -292.0074 1.758448e+00
4 -291.0707 9.367485e-01
5 -290.4214 6.493015e-01
6 -289.7014 7.200212e-01
7 -288.6599 1.041495e+00
8 -286.8410 1.818830e+00
9 -283.4597 3.381360e+00
10 -280.4321 3.027625e+00
11 -279.6208 8.112251e-01
12 -279.5314 8.942274e-02
13 -279.5247 6.709431e-03
14 -279.5242 5.365955e-04
15 -279.5241 5.353788e-05
16 -279.5241 7.483725e-06
17 -279.5241 1.595826e-06
18 -279.5241 4.672365e-07
19 -279.5241 1.581595e-07
20 -279.5241 5.630500e-08