webサイトのログ分析のような、「疎」なデータをいじっていると、大体 *分布なんだけど0のデータが多すぎるよん、という事はよく有る。そういう時は、zero inflated * distribution というのの出番らしい。
これは、要するに、0に確率が集中したデルタ分布と、普通の ** 分布との混合分布で、EMアルゴリズムで最尤推定出来る。
参考文献:Murphy, Kevin P. Machine learning: a probabilistic perspective. MIT Press, 2012. の、Chap.11
zero inflated geometric distribution の例
ソース
zi_geom.r
num <- 1000
eps <- 0.3
prob <- 0.2
z <- ifelse(runif(num) < eps, 1, 0)
y <- ifelse(z==1, rgeom(num, prob), 0)
geom.est <- function(x, weight=1){
a <- sum(weight)
b <- sum(weight * (1 + x))
a / b
}
e.step <- function(obs, p.hat, eps.hat){
num <- length(obs)
ret <- matrix(nrow=num, ncol=2) ## Prob(z=0 | y[i]), Prob(z=1 | y[i])
flag <- obs > 0
ret[flag, 1] <- 0
ret[flag, 2] <- 1
ret[!flag, 1] <- 1 - eps.hat
ret[!flag, 2] <- eps.hat * pgeom(0, p.hat)
c <- apply(ret, 1, sum)
ret / c
}
m.step <- function(obs, post){
num <- length(obs)
eps.hat <- sum(post[,2]) / num
p.hat <- geom.est(obs, post[,2])
list(eps=eps.hat, p=p.hat)
}
eps.hat <- 0.5
prob.hat <- 0.5
for(i in 1:20){
cat(sprintf("eps=%f, prob=%f\n", eps.hat, prob.hat))
post <- e.step(y, prob.hat, eps.hat)
m.ret <- m.step(y, post)
eps.hat <- m.ret$eps
prob.hat <- m.ret$p
}
結果
> source("zi_geom.r")
eps=0.500000, prob=0.500000
eps=0.491333, prob=0.301370
eps=0.409031, prob=0.264227
eps=0.354965, prob=0.237599
eps=0.325227, prob=0.222115
eps=0.310784, prob=0.214366
eps=0.304253, prob=0.210811
eps=0.301402, prob=0.209249
eps=0.300179, prob=0.208576
eps=0.299657, prob=0.208289
eps=0.299435, prob=0.208167
eps=0.299341, prob=0.208115
eps=0.299301, prob=0.208093
eps=0.299284, prob=0.208084
eps=0.299277, prob=0.208080
eps=0.299274, prob=0.208078
eps=0.299273, prob=0.208078
eps=0.299272, prob=0.208077
eps=0.299272, prob=0.208077
eps=0.299272, prob=0.208077