PRML図8.51の木構造因子グラフについて、積和アルゴリズム(sum-productアルゴリズム)により、各変数ノードxnの周辺分布p(xn)を求めます。
まず、素朴に式8.61の通り、xが取り得る状態全ての場合について同時分布を求めて、それらの合計により周辺分布を求めます。(なお、この実装ではグラフを解析しているわけではなく、今回の例に合わせて因子の積を決めうちで計算しています。)
一方、積和アルゴリズムを用いて、木の葉から根方向と、根から葉方向へのメッセージパッシングによって全てのメッセージを求め、式8.63の通りメッセージの積として、同じ結果が得られることを確認します。
なお、各変数ノードxnは3個の状態をとるものとし、各因子ノードの因子f(x1,x2)は、x1==x2のとき0.8、そうでないとき0.1としています。
また、変数ノードx1に値1を観測したとき、および変数ノードx2に値1を観測したときの周辺分布も求めます。
library(igraph)
frame()
set.seed(0)
par(mfrow=c(3, 3))
par(mar=c(2, 2, 1, 0))
par(mgp=c(1, 0, 0))
K <- 3
N <- 4
ROOT <- 3
g <- graph(c(1, N+1, N+1, 2, 4, N+2, N+2, 2, 2, N+3, N+3, 3), directed=F)
V(g)$size <- 40
V(g)$label.cex <- 1.5
V(g)$shape <- "circle"
V(g)$shape[(N+1):(N+3)] <- "square"
potential1 <- function(n, xparent, xchild) {
# 因数ノードnのポテンシャル関数
# xparent:上層変数ノード
# xchild:下層変数ノード
p <- ifelse(xparent == xchild, 0.8, 0.1)
}
potential2 <- function(n, xparent, xchild) {
p <- potential1(n, xparent, xchild)
# 変数ノード1に1を観測
OBSERVED <- 1
if (n == 5) {
p <- p * ifelse(xchild == OBSERVED, 1, 0)
}
p
}
potential3 <- function(n, xparent, xchild) {
p <- potential1(n, xparent, xchild)
# 変数ノード2に1を観測
OBSERVED <- 1
if (n == 5 || n == 6) {
p <- p * ifelse(xparent == OBSERVED, 1, 0)
} else if (n == 7) {
p <- p * ifelse(xchild == OBSERVED, 1, 0)
}
p
}
doplot.joint <- function(potential) {
# 同時分布を求める
d <- data.frame()
for (i in 0:(K^N - 1)) { # O(K^N)
total.potential <- 1
xs <- c()
xs <- c(xs, (i) %% K)
for (n in 1:(N-1)) {
xn1 <- (i %/% K ^ (n - 1)) %% K
xn2 <- (i %/% K ^ n) %% K
xs <- c(xs, xn2)
}
total.potential <- (potential(5, xs[2], xs[1]) * potential(6, xs[2], xs[4]) * potential(7, xs[3], xs[2]))
d <- rbind(d, c(xs, total.potential))
}
names(d) <- c(paste0("x", 1:N), "p")
cat("p(x)\n");print(rbind(head(d),tail(d)))
# 同時分布の和により周辺分布を求める
ps <- matrix(nrow=N, ncol=K)
rownames(ps) <- 1:N
colnames(ps) <- 0:(K-1)
z <- sum(d$p)
for (n in 1:N) {
for (xn in 0:(K-1)) {
ps[n, xn + 1] <- sum(d[d[, n] == xn, ]$p) / z
}
}
cat("p(xn)\n");print(ps)
barplot(t(ps), legend=0:(K-1), xlab="xn", ylab="p(xn)")
title("sum of joint")
}
doplot.message <- function(potential) {
mu <- as.data.frame(matrix(c(-1, -1, rep(NA, K)), ncol=2 + K))
names(mu) <- c("from", "to", 0:(K-1))
mufxUp <- function(from, to) {
# 下層からのメッセージを求め、それを基に、因子ノードfromから変数ノードtoへのメッセージを求める
# 下層からのメッセージを先に求める
children <- neighbors(g, from)
for (child in children) {
if (child != to) {
muxfUp(child, from)
}
}
# to以外からのメッセージ(の積)にポテンシャルを掛けたものの、to以外の確率変数値に関する和を求める
p <- rep(NA, K)
for (x in 0:(K-1)) { # O(K^隣接ノード数)
m <- as.matrix(mu[mu$from == children[children != to] & mu$to == from, c(-1, -2)])
p[x + 1] <- sum(potential(from, x, 0:(K-1)) * m)
}
mu <<- rbind(mu, c(from, to, p))
p
}
muxfUp <- function(from, to) {
# 下層からのメッセージを求め、それを基に、変数ノードfromから因子ノードtoへのメッセージを求める
# to以外のメッセージ(の積)を求める
p <- rep(1, K)
for (child in neighbors(g, from)) {
if (child != to) {
p <- p * mufxUp(child, from)
}
}
mu <<- rbind(mu, c(from, to, p))
p
}
mufxDown <- function(from, to) {
# 因子ノードfromから変数ノードtoへのメッセージを求め、それを基に下層へのメッセージも求める
# to以外からのメッセージ(の積)にポテンシャルを掛けたものの、to以外の確率変数値に関する和を求める
children <- neighbors(g, from)
p <- rep(NA, K)
for (x in 0:(K-1)) { # O(K^隣接ノード数)
m <- as.matrix(mu[mu$from == children[children != to] & mu$to == from, c(-1, -2)])
p[x + 1] <- sum(potential(from, 0:(K-1), x) * m)
}
mu <<- rbind(mu, c(from, to, p))
# 下層へのメッセージを求める
for (child in neighbors(g, to)) {
if (child != from) {
muxfDown(to, child)
}
}
p
}
muxfDown <- function(from, to) {
# 変数ノードfromから因子ノードtoへのメッセージを求め、それを基に下層へのメッセージも求める
# to以外のメッセージ(の積)を求める
p <- rep(1, K)
for (child in neighbors(g, from)) {
if (child != to) {
p <- p * as.matrix(mu[mu$from == child & mu$to == from, c(-1, -2)])
}
}
mu <<- rbind(mu, c(from, to, p))
# 下層へのメッセージを求める
for (child in neighbors(g, to)) {
if (child != from) {
mufxDown(to, child)
}
}
p
}
mufxUp(neighbors(g, ROOT)[1], ROOT)
muxfDown(ROOT, neighbors(g, ROOT)[1])
cat("mu\n");print(mu)
ps <- matrix(nrow=N, ncol=K)
rownames(ps) <- 1:N
colnames(ps) <- 0:(K-1)
for (n in 1:N) {
p <- apply(mu[mu$to == n, c(-1, -2)], 2, prod)
z <- sum(p)
ps[n, ] <- p / z
}
barplot(t(ps), legend=0:(K-1), xlab="xn", ylab="p(xn)")
title("message passing")
}
V(g)$color="white"
plot(g, layout=layout.reingold.tilford(g, root=ROOT))
doplot.joint(potential1)
doplot.message(potential1)
V(g)$color="white"
V(g)$color[1]="gray"
plot(g, layout=layout.reingold.tilford(g, root=ROOT))
doplot.joint(potential2)
doplot.message(potential2)
V(g)$color="white"
V(g)$color[2]="gray"
plot(g, layout=layout.reingold.tilford(g, root=ROOT))
doplot.joint(potential3)
doplot.message(potential3)