LoginSignup
1
0

More than 5 years have passed since last update.

Rでマルコフ連鎖の周辺分布を厳密推論

Last updated at Posted at 2013-03-08

PRML 8.4.1に記載の通り、N個のノードを持ちそれぞれのノードがK個の状態を取るようなマルコフ連鎖について、各ノードxnの周辺分布p(xn)を推論します。

まず、素朴に式8.50の通り、xが取り得る状態全ての場合について同時分布を求めて、それらの合計により周辺分布を求めます。全ての場合について同時分布を求めるので、計算量はO(K^N)となります。

一方、式8.52の通りポテンシャルの和積計算を、連鎖の前方および後方からのメッセージパッシングとして解釈することにより、O(N*K^2)の計算量で同じ結果が得られることを確認します。

なお、与えるマルコフ連鎖は、以下の条件付き分布の積として定義されるベイジアンネットワークを、式8.45に従って等価なマルコフ確率場に変換したものです。

x1 p(x1)
0 0.1
1 0.1
2 0.8
x(n-1) xn p(xn|x(n-1))
0 0 0.8
0 1 0.1
0 2 0.1
1 0 0.1
1 1 0.8
1 2 0.1
2 0 0.1
2 1 0.1
2 2 0.8

またノードx4に値1を観測したときの周辺分布も求めます。

frame()
set.seed(0)
par(mfcol=c(2, 2))
par(mar=c(2, 2, 1, 0))
par(mgp=c(1, 0, 0))
N <- 7
K <- 3
potential1 <- function(n, x1, x2) {
    # p(x1)
    if (n == 1) {
        p <- ifelse(x1 == 2, 0.8, 0.1)
    } else {
        p <- 1
    }
    # p(xn|xn-1)
    p <- p * ifelse(x1 == x2, 0.8, 0.1)
    p
}
potential2 <- function(n, x1, x2) {
    # p(x1)
    if (n == 1) {
        p <- ifelse(x1 == 2, 0.8, 0.1)
    } else {
        p <- 1
    }
    # p(xn|xn-1)
    p <- p * ifelse(x1 == x2, 0.8, 0.1)
    # x4に1を観測
    OBSERVED <- 1
    if (n == 4) {
        p <- p * ifelse(x1 == OBSERVED, 1, 0)
    } else if (n + 1 == 4) {
        p <- p * ifelse(x2 == OBSERVED, 1, 0)
    }
    p
}

doplot.joint <- function(potential) {
    # 同時分布を求める
    d <- data.frame()
    for (i in 0:(K^N - 1)) {  # O(K^N)
        total.potential <- 1
        xs <- c()
        xs <- c(xs, (i) %% K)
        for (n in 1:(N-1)) {
            xn1 <- (i %/% K ^ (n - 1)) %% K
            xn2 <- (i %/% K ^ n) %% K
            xs <- c(xs, xn2)
            psi <- potential(n, xn1, xn2)
            total.potential <- total.potential * psi
        }
        d <- rbind(d, c(xs, total.potential))
    }
    names(d) <- c(paste0("x", 1:N), "p")
    cat("p(x)\n");print(rbind(head(d),tail(d)))

    # 同時分布の和により周辺分布を求める
    ps <- matrix(nrow=N, ncol=K)
    rownames(ps) <- 1:N
    colnames(ps) <- 0:(K-1)
    z <- sum(d$p)
    for (n in 1:N) {
        for (xn in 0:(K-1)) {
            ps[n, xn + 1] <- sum(d[d[, n] == xn, ]$p) / z
        }
    }
    cat("p(xn)\n");print(ps)
    barplot(t(ps), legend=0:(K-1), xlab="xn", ylab="p(xn)")
    title("sum of joint")
}

doplot.message <- function(potential) {
    # 前方からのメッセージパッシング
    mualpha <- matrix(nrow=N, ncol=K)
    colnames(mualpha) <- 0:(K-1)
    mualpha[1, ] <- rep(1, K)
    for (n in 2:N) {  # O(N*K^2)
        mu <- c()
        for (x2 in 0:(K-1)) {  # O(K^2)
            mu <- c(mu, sum(potential(n - 1, 0:(K-1), x2) * mualpha[n - 1, ]))
        }
        mualpha[n, ] <- mu
    }
    cat("mualpha\n");print(mualpha)

    # 後方からのメッセージパッシング
    mubeta <- matrix(nrow=N, ncol=K)
    colnames(mubeta) <- 0:(K-1)
    mubeta[N, ] <- rep(1, K)
    for (n in (N-1):1) {  # O(N*K^2)
        mu <- c()
        for (x1 in 0:(K-1)) {  # O(K^2)
            mu <- c(mu, sum(potential(n, x1, 0:(K-1)) * mubeta[n + 1, ]))
        }
        mubeta[n, ] <- mu
    }
    cat("mubeta\n");print(mubeta)

    # メッセージに基づく周辺分布の計算
    ps <- matrix(nrow=N, ncol=K)
    rownames(ps) <- 1:N
    colnames(ps) <- 0:(K-1)
    for (n in 1:N) {
        p <- mualpha[n, ] * mubeta[n, ]
        z <- sum(p)
        ps[n, ] <- p / z
    }
    cat("p(xn)\n");print(ps)
    barplot(t(ps), legend=0:(K-1), xlab="xn", ylab="p(xn)")
    title("message passing")
}

doplot.joint(potential1)
doplot.message(potential1)
doplot.joint(potential2)
doplot.message(potential2)
1
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
0