LoginSignup
0
4

More than 3 years have passed since last update.

事前分布に依存しないベイズ推定&モデル比較

Last updated at Posted at 2019-10-24

概要・導入

ベイズ統計と頻度主義に基づく統計とを比較したとき、ベイズ統計の欠点として、事後分布が事前分布(prior)に依存することが挙げられる。
これは特に標本数が少ないとき(尤度関数のピークが弱いとき)に顕著になる。
この記事では、arXiv:1910.06646の論文で紹介されている、ベイズ統計における事前分布に依存しないモデル比較の方法についてまとめる。

理論

検証したいモデルを $\mathcal{M}$ ,モデルのパラメータの組を $\Theta=(\theta_1,\theta_2,\dots)$ とおく。
特に、$\Theta$を興味があるパラメータの組 $x$ とそうでないもの(nuisance パラメータ) $\psi$ に分ける:$\Theta\to (x,\psi)$
与えられたデータ $d$ に対し、モデル $\mathcal{M}$ の尤度関数は

\newcommand{\calM}{\mathcal{M}}
\newcommand{\calL}{\mathcal{L}}
{\calL}_{\calM}(\Theta) = {\calL}_{\calM}(x,\psi) = p(d|x,\psi,\calM)

と与えられているとする。

今、 $x$ に対して何かデータ $d$ を用いた信頼区間 or 制限を与えることを考える。
特に、$x=x_0$ なのかそうでないかを考えたいような場合、単純に考えると、これは $x=x_0$ と制限したモデル( $\mathcal{M}_{x_0}$ とする)と $\mathcal{M}$ とのベイズファクター $B$ (Evidence の比)を見れば良い:

\begin{align*}
B = \frac{Z}{Z_{x_0}} &= \frac{\int d\Theta\ {\calL}_{\calM}(\Theta)\ \pi(\Theta|\calM)}{\int d\psi\ \calL_{\calM_{x_0}}(\psi)\ \pi(\psi|\calM_{x_0})}
\end{align*}

ただしここで

\calL_{\calM_{x_0}}(\psi) \equiv \mathcal{L}_{\calM}(x_0,\psi)

とした。

しかし、これは知りたいパラメータである $x$ についての事前分布のとり方( $\pi(\Theta|\calM)$ に含まれている)に依存する。
そこで、 $Z$ についても特に $x$ をある値に制限することを考え( このモデルを$\mathcal{M}_x$ とする)、この場合のベイズファクターを考えることにする:

\mathcal{R}(x,x_0|d) \equiv B_{x,x_0} = \frac{Z_x}{Z_{x_0}} = \frac{\int d\psi\ \mathcal{L}(x,\psi)\ \pi(\psi|\calM_x)}{\int d\psi\ \mathcal{L}(x_0,\psi)\ \pi(\psi|\calM_{x_0})}

この時 $\mathcal{R}(x,x_0|d)$ を "relative belief updating ratio"(相対信頼度更新比?) や "shape distortion function"(形状歪曲関数?) と呼ぶ。
$\mathcal{R}(x,x_0|d)$ は $x$ に関する事前分布なしに定まっているので、この意味で事前分布依存性がない(もちろん他のパラメータに対しての事前分布依存性はある)。

実際に$\mathcal{R}(x,x_0|d)$ の値を評価する場合は、Evidence $Z_x,Z_{x_0}$ を計算する代わりに、
ベイズファクターの満たす関係式1

\frac{p(\mathcal{M_x}|d)}{p({{\calM}_{x_0}}|d)} = B_{x,x_0}\frac{\pi(\mathcal{M_{x}})}{\pi(\calM_{x_0})}

及び

\begin{align}
p({\calM}_x|d) &= p(x|d,\calM)\ p(\calM|d)\\
\pi({\calM}_x) &= \pi(x|\calM)\ \pi(\calM)
\end{align}

と $x$ によらない部分をくくり出せることを用いて導ける

\mathcal{R}(x,x_0|d) = \frac{\frac{p(x|d,M)}{\pi(x|\calM)}}{\frac{p(x_0|d,\calM)}{\pi(x_0|\calM)}}

という表式が便利。

例1:ガウス分布の混合比推定

例として、でかいガウス分布のバックグラウンドにちょこっとだけシグナル成分のガウシアンが乗ったようなデータを考える:

p(x|s) \sim s \mathcal{G}[x;\mu=2,\sigma=0.1] + (1-s) \mathcal{G}[x;\mu=0,\sigma=1] 

実際にプロットしてみるとこんな感じ:

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from scipy.stats import norm
from scipy.special import logsumexp

n_tot = 1000
n_sig = 50
n_bg = n_tot - n_sig

s_th = n_sig / n_tot

loc_sig = 2
scale_sig = 0.1

data_bg = norm.rvs(size=n_bg)
data_sig = norm.rvs(loc=loc_sig,scale=scale_sig,size=n_sig)
data = pd.Series(np.concatenate([data_bg,data_sig]))

data.hist(bins=64)
plt.show()

image.png

このような分布に対して、異なる事前分布(flat, logflat, halfcauchy)を用いて混合比 s_th を推定してみる。
推定に当たっては、EnsembleSampler(単純なMetropolis-Hasting法と比べてステップ幅をいじらなくてもいい感じにサンプリングができる)が便利なemceeを使ってみる。
s_th=0.05くらいならばどちらの事前分布を使ってもそこまで違いはない:

from emcee import EnsembleSampler
import sys

#epsilon = sys.float_info.epsilon
epsilon = 0


#### MCMC hyperparameters ####
nwalkers = 64
dim = 1
threads = 32
n_burnin_repeat = 3
n_burnin = 500
n_mcmc = 2000


#### configration of the prior ####
#loc_halfcauchy = 0
#scale_halfcauchy = 1


def lnlike(param):
    s = param
    p_bg = norm.pdf(data)
    p_sig = norm.pdf(data,loc=loc_sig,scale=scale_sig)
    p_tot = s * p_sig + (1-s) * p_bg
    return np.sum(np.log(p_tot))


def lnprior_flat(param):
    s = param
    if (s < epsilon) or (s > 1):
        return -np.inf
    else:
        return 0


def lnprior_logflat(param):
    s = param
    if (s < epsilon) or (s > 1):
        return -np.inf
    else:
        return -np.log(s)  # prior(s) = 1/x


def lnprior_halfcauchy(param):
    s = param
    if (s < epsilon) or (s > 1):
        return -np.inf
    else:
        return halfcauchy.logpdf(s)


def lnpost_flat(param):
    lnp = lnprior_flat(param)
    if lnp > -np.inf:
        return lnlike(param) + lnp
    else:
        return lnp


def lnpost_logflat(param):
    lnp = lnprior_logflat(param)
    if lnp > -np.inf:
        return lnlike(param) + lnp
    else:
        return lnp


def lnpost_halfcauchy(param):
    lnp = lnprior_halfcauchy(param)
    if lnp > -np.inf:
        return lnlike(param) + lnp
    else:
        return lnp


sampler_flat = EnsembleSampler(
    nwalkers = nwalkers,
    dim = 1,
    lnpostfn = lnpost_flat,
    threads=threads
)

sampler_logflat = EnsembleSampler(
    nwalkers = nwalkers,
    dim = 1,
    lnpostfn = lnpost_logflat,
    threads = threads
)

sampler_halfcauchy = EnsembleSampler(
    nwalkers = nwalkers,
    dim = 1,
    lnpostfn = lnpost_halfcauchy,
    threads = threads
)


pos0_flat = norm.rvs(loc=s_th,scale=0.001,size=nwalkers)[:,np.newaxis]
for i in range(n_burnin_repeat):
    pos0_flat, _, _ = sampler_flat.run_mcmc(pos0_flat,n_burnin)
    sampler_flat.reset()
sampler_flat.run_mcmc(pos0_flat,n_burnin)
res_flat = pd.Series(sampler_flat.flatchain[:,0])


pos0_logflat = norm.rvs(loc=s_th,scale=0.001,size=nwalkers)[:,np.newaxis]
for i in range(n_burnin_repeat):
    pos0_logflat, _, _ = sampler_logflat.run_mcmc(pos0_logflat,n_burnin)
    sampler_logflat.reset()
sampler_logflat.run_mcmc(pos0_logflat,n_burnin)
res_logflat = pd.Series(sampler_logflat.flatchain[:,0])


pos0_halfcauchy = norm.rvs(loc=s_th,scale=0.001,size=nwalkers)[:,np.newaxis]
for i in range(n_burnin_repeat):
    pos0_halfcauchy, _, _ = sampler_halfcauchy.run_mcmc(pos0_halfcauchy,n_burnin)
    sampler_halfcauchy.reset()
sampler_halfcauchy.run_mcmc(pos0_halfcauchy,n_burnin)
res_halfcauchy = pd.Series(sampler_halfcauchy.flatchain[:,0])

出力:s_th=0.05

ress = np.array([res_flat,res_logflat,res_halfcauchy])
bins = np.logspace(np.log10(1e-2),np.log10(ress.max()),256)
res_flat.hist(bins=bins,histtype="step")
res_logflat.hist(bins=bins,histtype="step")
res_halfcauchy.hist(bins=bins,histtype="step")

plt.legend(["flat","logflat","halfcauchy"])
plt.xscale("log")
plt.yscale("log")
plt.show()

事後分布の推定結果:

image.png

このくらいだとどれもそんなに変わらない。

しかし、n_tot = 1000, n_sig = 10 (s_th = 0.01)くらいにすると、事前分布の依存性が出てくる:
出力:s_th=0.01
image.png
拡大図
image.png

さてこの時、 $\mathcal{R}(x,x_0|d)$ を求めてみると:

hist_flat = np.histogram(res_flat,bins,density=True)  # hist(n), bins_edges(n+1)
hist_logflat = np.histogram(res_logflat,bins,density=True)
hist_halfcauchy = np.histogram(res_halfcauchy,bins,density=True)

prior_flat = lambda x: 1
prior_logflat = lambda x: 1/x
prior_halfcauchy = lambda x: halfcauchy.pdf(x,loc=0,scale=scale_halfcauchy)

# relative belief updating ratio (RBUR)
def rbur(hist,prior_func):
    posterior,bins = hist
    r = 0.5
    x = np.exp(r*np.log(bins[1:]) + (1-r)*np.log(bins[:-1]) )
    post_prior = posterior/prior_func(x)
    return post_prior/post_prior[0],x

rbur_flat,x_flat = rbur(hist_flat,prior_flat)
rbur_logflat,_x = rbur(hist_logflat,prior_logflat)
rbur_halfcauchy,_x = rbur(hist_halfcauchy,prior_halfcauchy)

plt.plot(x_flat,rbur_flat)
plt.plot(_x,rbur_logflat)
plt.plot(_x,rbur_halfcauchy)
plt.legend(["flat","logflat","halfcauchy"])
for i in range(1,8,2):
    plt.plot(_x,np.ones_like(_x)*np.exp(-i),"--",c="gray",linewidth=1)
plt.xscale("log")
plt.yscale("log")

plt.xlabel("s")
plt.ylabel(r"$\mathcal{R}$")

こうなる:

image.png

確かに $\mathcal{R}$ には事前分布の依存性がなさそう。
もっと頑張ってs_th=0.001にしてみるとこんな感じ:

image.png

image.png

事前分布の違いによってs==0付近での統計が足らずに若干ふらつくが、それでも十分合っている。

例2:ガウス分布の混合比&分散推定

上の分布で、更に分布の分散も局外パラメータ (nuisance parameter) として振ってみることにする。
分散の事前分布はいずれも半コーシー分布で固定しとく。
今回もn_sig = 1 としておく。

サンプリング用コード

一次元の時とほとんど同じ。

from emcee import EnsembleSampler
import sys

#epsilon = sys.float_info.epsilon
epsilon = 0


#### MCMC hyperparameters ####
nwalkers = 64
dim = 1
threads = 32
n_burnin_repeat = 3
n_burnin = 500
n_mcmc = 2000


#### configration of the prior ####
#loc_halfcauchy = 0
scale_halfcauchy = 0.5


def lnlike_2d(param):
    s,scale = param
    p_bg = norm.pdf(data)
    p_sig = norm.pdf(data,loc=loc_sig,scale=scale)
    p_tot = s * p_sig + (1-s) * p_bg
    return np.sum(np.log(p_tot))


def lnpost_flat_2d(param):
    s,scale = param
    lnp = lnprior_flat(s) + lnprior_halfcauchy(scale) 
    if lnp > -np.inf:
        return lnlike_2d(param) + lnp
    else:
        return lnp


def lnpost_logflat_2d(param):
    s,scale = param
    lnp = lnprior_logflat(s) + lnprior_halfcauchy(scale)
    if lnp > -np.inf:
        return lnlike_2d(param) + lnp
    else:
        return lnp


def lnpost_halfcauchy_2d(param):
    s,scale = param
    lnp = lnprior_halfcauchy(s) + lnprior_halfcauchy(scale)
    if lnp > -np.inf:
        return lnlike_2d(param) + lnp
    else:
        return lnp


sampler_flat_2d = EnsembleSampler(
    nwalkers = nwalkers,
    dim = 2,
    lnpostfn = lnpost_flat_2d,
    threads=threads
)

sampler_logflat_2d = EnsembleSampler(
    nwalkers = nwalkers,
    dim = 2,
    lnpostfn = lnpost_logflat_2d,
    threads = threads
)

sampler_halfcauchy_2d = EnsembleSampler(
    nwalkers = nwalkers,
    dim = 2,
    lnpostfn = lnpost_halfcauchy_2d,
    threads = threads
)


pos0_flat_2d = norm.rvs(loc=s_th,scale=0.001,size=(nwalkers,2))
for i in range(n_burnin_repeat):
    pos0_flat_2d, _, _ = sampler_flat_2d.run_mcmc(pos0_flat_2d,n_burnin)
    sampler_flat_2d.reset()
sampler_flat_2d.run_mcmc(pos0_flat_2d,n_burnin)
res_flat_2d = pd.Series(sampler_flat_2d.flatchain[:,0])


pos0_logflat_2d = norm.rvs(loc=s_th,scale=0.001,size=(nwalkers,2))
for i in range(n_burnin_repeat):
    pos0_logflat_2d, _, _ = sampler_logflat_2d.run_mcmc(pos0_logflat_2d,n_burnin)
    sampler_logflat_2d.reset()
sampler_logflat_2d.run_mcmc(pos0_logflat_2d,n_burnin)
res_logflat_2d = pd.Series(sampler_logflat_2d.flatchain[:,0])


pos0_halfcauchy_2d = norm.rvs(loc=s_th,scale=0.001,size=(nwalkers,2))
for i in range(n_burnin_repeat):
    pos0_halfcauchy_2d, _, _ = sampler_halfcauchy_2d.run_mcmc(pos0_halfcauchy_2d,n_burnin)
    sampler_halfcauchy_2d.reset()
sampler_halfcauchy_2d.run_mcmc(pos0_halfcauchy_2d,n_burnin)
res_halfcauchy_2d = pd.Series(sampler_halfcauchy_2d.flatchain[:,0])

結果のプロット:

一次元の場合と同じようにプロットしてみる。ただし、局外パラメータの情報は落として、$p(s|d)$ だけ見る。

ress_2d = np.array([res_flat_2d,res_logflat_2d,res_halfcauchy_2d])
bins = np.logspace(np.log10(1e-4),np.log10(ress_2d.max()),64)
res_flat_2d.hist(bins=bins,histtype="step")
res_logflat_2d.hist(bins=bins,histtype="step")
res_halfcauchy_2d.hist(bins=bins,histtype="step")

plt.legend(["flat","logflat","halfcauchy"],loc="upper left")
plt.xscale("log")
plt.yscale("log")
plt.show()

image.png

散布図:

局外パラメータの振る舞いについても気になるので一応見ておく:

for sampler in (sampler_flat_2d,sampler_logflat_2d,sampler_halfcauchy_2d):
    plt.xscale("log")
    plt.yscale("log")
    plt.xlabel("s")
    plt.ylabel("scale")
    plt.scatter(*sampler.flatchain.T,s=1,linewidths=0,c=sampler.flatlnprobability)
    plt.colorbar()
    plt.show()

image.png
image.png
image.png

Ralative belief updating function:

問題のやつ。局外パラメータが合っても確かに事前分布依存性が取り除けているか?

hist_flat_2d = np.histogram(res_flat_2d,bins,density=True)  # hist(n), bins_edges(n+1)
hist_logflat_2d = np.histogram(res_logflat_2d,bins,density=True)
hist_halfcauchy_2d = np.histogram(res_halfcauchy_2d,bins,density=True)

prior_flat = lambda x: 1
prior_logflat = lambda x: 1/x
prior_halfcauchy = lambda x: halfcauchy.pdf(x,loc=0,scale=scale_halfcauchy)

# relative belief updating ratio (RBUR)
def rbur(hist,prior_func):
    posterior,bins = hist
    r = 0.5
    x = np.exp(r*np.log(bins[1:]) + (1-r)*np.log(bins[:-1]) )
    post_prior = posterior/prior_func(x)
    return x,post_prior/post_prior[0]

rbur_flat_2d = rbur(hist_flat_2d,prior_flat)
rbur_logflat_2d = rbur(hist_logflat_2d,prior_logflat)
rbur_halfcauchy_2d = rbur(hist_halfcauchy_2d,prior_halfcauchy)

plt.plot(*rbur_flat_2d)
plt.plot(*rbur_logflat_2d)
plt.plot(*rbur_halfcauchy_2d)
plt.legend(["flat","logflat","halfcauchy"])
for i in range(1,8,2):
    plt.plot(rbur_flat_2d[0],np.ones_like(rbur_flat_2d[0])*np.exp(-i),"--",c="gray",linewidth=1)
plt.xscale("log")
plt.yscale("log")

plt.xlabel("s")
plt.ylabel(r"$\mathcal{R}$")

結果:
image.png

やはり低サンプリング数に伴う誤差はあるが、確かに事前分布に依存しない様子が見て取れる。

結局なんなのよ

結局のところ $\mathcal{R}$ を使って頻度主義統計(Frequentist)における尤度比検定のベイジアン的な一般化をしてるだけ。

まとめ

尤度比もどきを使うことで事前分布によらない?ベイジアン的な解析ができる。


  1. ただし、ここで $x,\,\psi$ が互いに独立だと仮定( $\pi(x,\psi|\calM)=\pi(x|\calM)\ \pi(x|\calM)$ )していることに注意。互いに独立でない場合はこの式は成り立たない。 

0
4
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
4