記事の目的
Rで混合ガウスモデルのギブスサンプリングを実装します。
For文をできるだけ使用せず、早く学習するようにしました。
参考: ノンパラメトリックベイズ 点過程と統計的機械学習の数理
目次
No. | 目次
1 |
モデルの説明
|
2 |
データとライブラリ
|
3 |
実装
|
4 |
確認
|
|
1. モデルの説明
2. データとライブラリ
データはirisのデータセットを使用します。
X <- iris[,1:4]
D <- ncol(X)
N <- nrow(X)
library(mvtnorm)
library(MCMCpack)
library(cluster)
3. 実装
library(mvtnorm)
library(MCMCpack)
# (1)Kを求める
K <- 3
# (2)muを乱数で初期化
set.seed(100)
mu <- matrix(rep(apply(X,2,mean),each=K)+rnorm(K*D,0,0.1), nrow=K)
# (3)
a0 <- 1
b0 <- 1
alpha0 <- 1
# (4)
pi <- rep(1/K, K)
t <- 1
# (5)(ⅰ)-(ⅳ)を繰り返す
max.iter <- 5
a <- a0 + N*D/2
for(s in 1:max.iter){
#(ⅰ)zのサンプリング
tmp <- t(apply(mu, 1, function(x) dmvnorm(X, x, diag(D)/t)))*as.vector(pi)
z <- apply(tmp, 2, function(x) which.max(rmultinom(1,3,x)))
#(ⅱ)mu(µ)のサンプリング
n <- tapply(z, z, length)
x.k <- apply(X, 2, function(x) tapply(x, z, mean))
mu <- t(apply(cbind(x.k, n), 1,
function(x) rmvnorm(1, x[D+1]*x[1:(D)]/(x[D+1]+1), diag(D)/(t*(x[D+1]+1)))))
#(ⅲ)t(τ)のサンプリング
b <- b0 + sum(unlist(apply(X, 2, function(x) tapply(x, z, function(x) (x-mean(x))^2))))/2 +
sum(as.vector(n/(2*(1+n)))*((x.k%*%t(x.k))*diag(K)))
t <- rgamma(1, a, b)
#(ⅳ)pi(π)のサンプリング
alpha <- alpha0 + n
pi <- rdirichlet(1, alpha)
}
4. 確認
左が正解で、右が実装の結果です。
library(cluster)
par(mfrow=c(1,2))
clusplot(X, iris[,5], color=TRUE, shade=FALSE, labels=4, lines=0)
clusplot(X, z, color=TRUE, shade=FALSE, labels=4, lines=0)