はじめに
MCMCは関数f(θ)からサンプリングをする方法として広く用いられている方法です。
サンプリングの目的は色々あると思いますが、推定の枠組みではサンプリングしたデータから平均や分散を求めることが主な目的となります。確率密度関数が分かっているのであればわざわざサンプリングせずとも平均や分散を求めることができますが、確率密度関数は定義より積分値が1になっていないといけません。ところがこの条件は比較的強く、"ある関数が求めたい密度関数に比例しているのはわかっているけれども正規化されていない"という状況は多く発生します。そのようなときに威力を発揮するのがサンプリングによる推定で、MCMCはそのサンプリング方法の一種です。
"ある関数が求めたい密度関数に比例しているのはわかっているけれども正規化されていない"という状況が最も顕著なのがベイズを用いた推定です。ベイズ主義では事後分布を用いて推定をしますが、この事後分布は多くの場合正規化されていません。このような場合でもMCMCを使えば問題なく推定を遂行できるため広く使われている、というわけです。
私は生物統計の研究室に所属していますが、そこでよく使われる言語はRです。MCMCをRで動かす場合、Rstanと呼ばれるパッケージ(背景ではCが動いて計算を行います)を使いますが、かなりブラックボックス化されているため実際どのように計算をしているかはよくわかりません。今回、勉強がてらMCMCの一種であるギブスサンプリングによるサンプリングのコードを書いてみましたのでここで共有しようと覆います。なお、ベイズ的な線形回帰の推定は既知とします。
推定内容
今回は線形の関係にあるx,yのデータからその傾きをベイズ的に推定する、という設定にしました。簡単のため、切片はわかっていることにしました。データは$y=3x+4$に従うように作っていますので、$y=ax+4$のaを推定するような問題設定です。
データの生成
データは$y=3x+4$にノイズをかける形で生成しました。ノイズは平均0、分散2に従うようにしています。
################################################################################
############################ Data Generation ###################################
x <- seq(0, 10, 0.1)
y <- 3*x + 4
y <- y + rnorm(length(x), 0, 4)
plot(x, y, pch=19)
こちらが生成されたデータです。ノイズはあるものの、線形の関係ははっきりと確認できます。
ベイズの定理によるモデル化
ベイズの定理によると次が成り立つことが知られています。ここでθはパラメータ、Dはデータを表しています。
$$ p(θ|D)∝p(D|θ)p(θ)$$
今回は$θ=a$ですので、より正確には以下のようになります。
$$ p(a|D)∝p(D|a)p(a)$$
さて、$p(D|a)$は尤度を表しており、データは平均0、分散2のノイズに従うので
$$p(D|a)=\Pi_{i=1}^{N}\frac{1}{\sqrt{2\pi\sigma^2}}\exp\Big(\frac{(y_{i}-(ax_{i}+4))^2}{2\sigma^2}\Big)$$
となります(詳しくはprmlなどの参考書をご覧ください)。なお、今回はaが推定値であり、その値はまだわからないため、ここではaとおいています。つまり尤度はaの関数になります。
次にaの事前分布であるp(a)を決める必要がありますが、特に事前情報はありませんので裾の広い正規分布(分散2000000)に従うものとしました。
$$p(a)=\frac{1}{\sqrt{2\pi×2000000}}\exp\Big(\frac{a^2}{2×2000000}\Big)$$
さて、尤度はaの関数であり、事前分布もaの関数ですので事後分布もaの関数となります。つまり
$$p(a|D)∝f(a)=\Pi_{i=1}^{N}\frac{1}{\sqrt{2\pi\sigma^2}}\exp\Big(\frac{(y_{i}-(ax_{i}+4))^2}{2\sigma^2}\Big)×\frac{1}{\sqrt{2\pi×2000000}}\exp\Big(\frac{a^2}{2×2000000}\Big)$$
です。このf(a)は事後分布に比例していますので、ここからサンプリングをすればよいことになります。
尤度関数、事前分布、事後分布の関数化
さて、上で定義した尤度関数、事前分布、事後分布(正確には事後分布に比例する関数f)をコードに書き下していきます。特に説明は必要ないと思いますが、上から尤度関数、事前分布、事後分布の順番になっています。
################################################################################
########################## Function Definition #################################
# likelihood function
like <- function(a) {
sum <- 1
for(i in 1:length(x)){
sum <- sum*(1/sqrt(2*pi*4)*exp(-(y[i] - a*x[i] - 4)^2/(2*4)))
}
return(sum)
}
# prior distribution
prio <- function(b) {
1/sqrt(2*pi*2000000)*exp(-b^2/(2*2000000))
}
# posterior distribution
dist <- function(c) {
return(like(c)*prio(c))
}
# confirmation of function form
x.2 <- seq(2.5, 3.5, 0.001)
plot(x.2, dist(x.2), ylab="posterior distribution")
最後に事後分布に比例する関数f(x)の形をグラフ化しています。その結果がこちら↓です。
この関数は正規化されていないことが容易にこのグラフからも読み取れるかと思います。この関数の期待値を求めるためにMCMCでサンプリングをしていきます。
MCMC
最後はMCMCです。MCMCにはいくつか種類がありますが、ここではメトロポリスヘイスティングと呼ばれる方法を用いました。このサンプリング方法はWikiに詳しく書いてありますのでそちらをご覧ください。
https://ja.wikipedia.org/wiki/%E3%83%A1%E3%83%88%E3%83%AD%E3%83%9D%E3%83%AA%E3%82%B9%E3%83%BB%E3%83%98%E3%82%A4%E3%82%B9%E3%83%86%E3%82%A3%E3%83%B3%E3%82%B0%E3%82%B9%E6%B3%95。
なお、サンプリングのために乱数を発生させる必要がありますが、今回は正規分布を用いました。サンプリングはsampと呼ばれるベクトルに格納されています。またサンプリングは前後で互いに相関を持つためある間隔ごとにサンプリングを行うのが普通です。100ごとに抜いたサンプルはsamp.2に格納しました。
################################################################################
################################### MCMC #######################################
# number of sampling
N.samp <- 100000
# interval
N.int <- 100
# burnout
burn <- 100
# initial value
init <- 2.0
# vector to store sampling
samp < -rep(NA, N.samp)
# vector to store sampling with certain interval
samp.2 <- rep(NA, N.samp/N.int)
samp[1] <- init
for (j in 1:length(samp)) {
# previous sampling
x1 <- samp[j]
# next candidate sampling
x2 <- rnorm(1, x1, 1)
# store sampling to samp following sampling role
alpha <- dist(x2)/dist(x1)
if (alpha >= 1) {
samp[j + 1] <- x2
} else {
p <- runif(1, 0, 1)
if (p < alpha) {
samp[j + 1] <- x2
} else {
samp[j + 1] <- x1
}
}
# store sampling to samp.2 if satisfies interval condition
if (j >= burn) {
if (j %% N.int == 0) {
samp.2[j %/% 100] <-samp[j]
}
}
}
# plot all sampling
plot(samp)
# plot extracted sampling
plot(samp.2)
# histogram of the extracted sampling
mean(samp.2)
⇓こちらが全サンプリングをプロットしたものです。サンプリング初期に少しばらついていますが、これは探索をしている状態です。不確定要素を含むためこれは取りに除くのが一般的です
⇓こちらは100サンプルごとにサンプリングした結果です。先ほどより100分の1ほどデータ数が減っています。
⇓そしてこちらがサンプル結果のヒストグラムです。事後分布の関数形に似ているのが分かると思います。
最後にこのサンプリングの平均値を求めていますが、その結果は3.055712!真値が3ですのでかなりいい値を示してくれました!