概要・導入
ベイズ統計と頻度主義に基づく統計とを比較したとき、ベイズ統計の欠点として、事後分布が事前分布(prior)に依存することが挙げられる。
これは特に標本数が少ないとき(尤度関数のピークが弱いとき)に顕著になる。
この記事では、arXiv:1910.06646の論文で紹介されている、ベイズ統計における事前分布に依存しないモデル比較の方法についてまとめる。
理論
検証したいモデルを $\mathcal{M}$ ,モデルのパラメータの組を $\Theta=(\theta_1,\theta_2,\dots)$ とおく。
特に、$\Theta$を興味があるパラメータの組 $x$ とそうでないもの(nuisance パラメータ) $\psi$ に分ける:$\Theta\to (x,\psi)$
与えられたデータ $d$ に対し、モデル $\mathcal{M}$ の尤度関数は
\newcommand{\calM}{\mathcal{M}}
\newcommand{\calL}{\mathcal{L}}
{\calL}_{\calM}(\Theta) = {\calL}_{\calM}(x,\psi) = p(d|x,\psi,\calM)
と与えられているとする。
今、 $x$ に対して何かデータ $d$ を用いた信頼区間 or 制限を与えることを考える。
特に、$x=x_0$ なのかそうでないかを考えたいような場合、単純に考えると、これは $x=x_0$ と制限したモデル( $\mathcal{M}_{x_0}$ とする)と $\mathcal{M}$ とのベイズファクター $B$ (Evidence の比)を見れば良い:
\begin{align*}
B = \frac{Z}{Z_{x_0}} &= \frac{\int d\Theta\ {\calL}_{\calM}(\Theta)\ \pi(\Theta|\calM)}{\int d\psi\ \calL_{\calM_{x_0}}(\psi)\ \pi(\psi|\calM_{x_0})}
\end{align*}
ただしここで
\calL_{\calM_{x_0}}(\psi) \equiv \mathcal{L}_{\calM}(x_0,\psi)
とした。
しかし、これは知りたいパラメータである $x$ についての事前分布のとり方( $\pi(\Theta|\calM)$ に含まれている)に依存する。
そこで、 $Z$ についても特に $x$ をある値に制限することを考え( このモデルを$\mathcal{M}_x$ とする)、この場合のベイズファクターを考えることにする:
\mathcal{R}(x,x_0|d) \equiv B_{x,x_0} = \frac{Z_x}{Z_{x_0}} = \frac{\int d\psi\ \mathcal{L}(x,\psi)\ \pi(\psi|\calM_x)}{\int d\psi\ \mathcal{L}(x_0,\psi)\ \pi(\psi|\calM_{x_0})}
この時 $\mathcal{R}(x,x_0|d)$ を "relative belief updating ratio"(相対信頼度更新比?) や "shape distortion function"(形状歪曲関数?) と呼ぶ。
$\mathcal{R}(x,x_0|d)$ は $x$ に関する事前分布なしに定まっているので、この意味で事前分布依存性がない(もちろん他のパラメータに対しての事前分布依存性はある)。
実際に$\mathcal{R}(x,x_0|d)$ の値を評価する場合は、Evidence $Z_x,Z_{x_0}$ を計算する代わりに、
ベイズファクターの満たす関係式1
\frac{p(\mathcal{M_x}|d)}{p({{\calM}_{x_0}}|d)} = B_{x,x_0}\frac{\pi(\mathcal{M_{x}})}{\pi(\calM_{x_0})}
及び
\begin{align}
p({\calM}_x|d) &= p(x|d,\calM)\ p(\calM|d)\\
\pi({\calM}_x) &= \pi(x|\calM)\ \pi(\calM)
\end{align}
と $x$ によらない部分をくくり出せることを用いて導ける
\mathcal{R}(x,x_0|d) = \frac{\frac{p(x|d,M)}{\pi(x|\calM)}}{\frac{p(x_0|d,\calM)}{\pi(x_0|\calM)}}
という表式が便利。
例1:ガウス分布の混合比推定
例として、でかいガウス分布のバックグラウンドにちょこっとだけシグナル成分のガウシアンが乗ったようなデータを考える:
p(x|s) \sim s \mathcal{G}[x;\mu=2,\sigma=0.1] + (1-s) \mathcal{G}[x;\mu=0,\sigma=1]
実際にプロットしてみるとこんな感じ:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from scipy.stats import norm
from scipy.special import logsumexp
n_tot = 1000
n_sig = 50
n_bg = n_tot - n_sig
s_th = n_sig / n_tot
loc_sig = 2
scale_sig = 0.1
data_bg = norm.rvs(size=n_bg)
data_sig = norm.rvs(loc=loc_sig,scale=scale_sig,size=n_sig)
data = pd.Series(np.concatenate([data_bg,data_sig]))
data.hist(bins=64)
plt.show()
このような分布に対して、異なる事前分布(flat
, logflat
, halfcauchy
)を用いて混合比 s_th
を推定してみる。
推定に当たっては、EnsembleSampler
(単純なMetropolis-Hasting法と比べてステップ幅をいじらなくてもいい感じにサンプリングができる)が便利なemcee
を使ってみる。
s_th=0.05
くらいならばどちらの事前分布を使ってもそこまで違いはない:
from emcee import EnsembleSampler
import sys
#epsilon = sys.float_info.epsilon
epsilon = 0
#### MCMC hyperparameters ####
nwalkers = 64
dim = 1
threads = 32
n_burnin_repeat = 3
n_burnin = 500
n_mcmc = 2000
#### configration of the prior ####
#loc_halfcauchy = 0
#scale_halfcauchy = 1
def lnlike(param):
s = param
p_bg = norm.pdf(data)
p_sig = norm.pdf(data,loc=loc_sig,scale=scale_sig)
p_tot = s * p_sig + (1-s) * p_bg
return np.sum(np.log(p_tot))
def lnprior_flat(param):
s = param
if (s < epsilon) or (s > 1):
return -np.inf
else:
return 0
def lnprior_logflat(param):
s = param
if (s < epsilon) or (s > 1):
return -np.inf
else:
return -np.log(s) # prior(s) = 1/x
def lnprior_halfcauchy(param):
s = param
if (s < epsilon) or (s > 1):
return -np.inf
else:
return halfcauchy.logpdf(s)
def lnpost_flat(param):
lnp = lnprior_flat(param)
if lnp > -np.inf:
return lnlike(param) + lnp
else:
return lnp
def lnpost_logflat(param):
lnp = lnprior_logflat(param)
if lnp > -np.inf:
return lnlike(param) + lnp
else:
return lnp
def lnpost_halfcauchy(param):
lnp = lnprior_halfcauchy(param)
if lnp > -np.inf:
return lnlike(param) + lnp
else:
return lnp
sampler_flat = EnsembleSampler(
nwalkers = nwalkers,
dim = 1,
lnpostfn = lnpost_flat,
threads=threads
)
sampler_logflat = EnsembleSampler(
nwalkers = nwalkers,
dim = 1,
lnpostfn = lnpost_logflat,
threads = threads
)
sampler_halfcauchy = EnsembleSampler(
nwalkers = nwalkers,
dim = 1,
lnpostfn = lnpost_halfcauchy,
threads = threads
)
pos0_flat = norm.rvs(loc=s_th,scale=0.001,size=nwalkers)[:,np.newaxis]
for i in range(n_burnin_repeat):
pos0_flat, _, _ = sampler_flat.run_mcmc(pos0_flat,n_burnin)
sampler_flat.reset()
sampler_flat.run_mcmc(pos0_flat,n_burnin)
res_flat = pd.Series(sampler_flat.flatchain[:,0])
pos0_logflat = norm.rvs(loc=s_th,scale=0.001,size=nwalkers)[:,np.newaxis]
for i in range(n_burnin_repeat):
pos0_logflat, _, _ = sampler_logflat.run_mcmc(pos0_logflat,n_burnin)
sampler_logflat.reset()
sampler_logflat.run_mcmc(pos0_logflat,n_burnin)
res_logflat = pd.Series(sampler_logflat.flatchain[:,0])
pos0_halfcauchy = norm.rvs(loc=s_th,scale=0.001,size=nwalkers)[:,np.newaxis]
for i in range(n_burnin_repeat):
pos0_halfcauchy, _, _ = sampler_halfcauchy.run_mcmc(pos0_halfcauchy,n_burnin)
sampler_halfcauchy.reset()
sampler_halfcauchy.run_mcmc(pos0_halfcauchy,n_burnin)
res_halfcauchy = pd.Series(sampler_halfcauchy.flatchain[:,0])
出力:s_th=0.05
ress = np.array([res_flat,res_logflat,res_halfcauchy])
bins = np.logspace(np.log10(1e-2),np.log10(ress.max()),256)
res_flat.hist(bins=bins,histtype="step")
res_logflat.hist(bins=bins,histtype="step")
res_halfcauchy.hist(bins=bins,histtype="step")
plt.legend(["flat","logflat","halfcauchy"])
plt.xscale("log")
plt.yscale("log")
plt.show()
事後分布の推定結果:
このくらいだとどれもそんなに変わらない。
しかし、n_tot = 1000, n_sig = 10 (s_th = 0.01)
くらいにすると、事前分布の依存性が出てくる:
出力:s_th=0.01
拡大図
さてこの時、 $\mathcal{R}(x,x_0|d)$ を求めてみると:
hist_flat = np.histogram(res_flat,bins,density=True) # hist(n), bins_edges(n+1)
hist_logflat = np.histogram(res_logflat,bins,density=True)
hist_halfcauchy = np.histogram(res_halfcauchy,bins,density=True)
prior_flat = lambda x: 1
prior_logflat = lambda x: 1/x
prior_halfcauchy = lambda x: halfcauchy.pdf(x,loc=0,scale=scale_halfcauchy)
# relative belief updating ratio (RBUR)
def rbur(hist,prior_func):
posterior,bins = hist
r = 0.5
x = np.exp(r*np.log(bins[1:]) + (1-r)*np.log(bins[:-1]) )
post_prior = posterior/prior_func(x)
return post_prior/post_prior[0],x
rbur_flat,x_flat = rbur(hist_flat,prior_flat)
rbur_logflat,_x = rbur(hist_logflat,prior_logflat)
rbur_halfcauchy,_x = rbur(hist_halfcauchy,prior_halfcauchy)
plt.plot(x_flat,rbur_flat)
plt.plot(_x,rbur_logflat)
plt.plot(_x,rbur_halfcauchy)
plt.legend(["flat","logflat","halfcauchy"])
for i in range(1,8,2):
plt.plot(_x,np.ones_like(_x)*np.exp(-i),"--",c="gray",linewidth=1)
plt.xscale("log")
plt.yscale("log")
plt.xlabel("s")
plt.ylabel(r"$\mathcal{R}$")
こうなる:
確かに $\mathcal{R}$ には事前分布の依存性がなさそう。
もっと頑張ってs_th=0.001
にしてみるとこんな感じ:
事前分布の違いによってs==0
付近での統計が足らずに若干ふらつくが、それでも十分合っている。
例2:ガウス分布の混合比&分散推定
上の分布で、更に分布の分散も局外パラメータ (nuisance parameter) として振ってみることにする。
分散の事前分布はいずれも半コーシー分布で固定しとく。
今回もn_sig = 1
としておく。
サンプリング用コード
一次元の時とほとんど同じ。
from emcee import EnsembleSampler
import sys
#epsilon = sys.float_info.epsilon
epsilon = 0
#### MCMC hyperparameters ####
nwalkers = 64
dim = 1
threads = 32
n_burnin_repeat = 3
n_burnin = 500
n_mcmc = 2000
#### configration of the prior ####
#loc_halfcauchy = 0
scale_halfcauchy = 0.5
def lnlike_2d(param):
s,scale = param
p_bg = norm.pdf(data)
p_sig = norm.pdf(data,loc=loc_sig,scale=scale)
p_tot = s * p_sig + (1-s) * p_bg
return np.sum(np.log(p_tot))
def lnpost_flat_2d(param):
s,scale = param
lnp = lnprior_flat(s) + lnprior_halfcauchy(scale)
if lnp > -np.inf:
return lnlike_2d(param) + lnp
else:
return lnp
def lnpost_logflat_2d(param):
s,scale = param
lnp = lnprior_logflat(s) + lnprior_halfcauchy(scale)
if lnp > -np.inf:
return lnlike_2d(param) + lnp
else:
return lnp
def lnpost_halfcauchy_2d(param):
s,scale = param
lnp = lnprior_halfcauchy(s) + lnprior_halfcauchy(scale)
if lnp > -np.inf:
return lnlike_2d(param) + lnp
else:
return lnp
sampler_flat_2d = EnsembleSampler(
nwalkers = nwalkers,
dim = 2,
lnpostfn = lnpost_flat_2d,
threads=threads
)
sampler_logflat_2d = EnsembleSampler(
nwalkers = nwalkers,
dim = 2,
lnpostfn = lnpost_logflat_2d,
threads = threads
)
sampler_halfcauchy_2d = EnsembleSampler(
nwalkers = nwalkers,
dim = 2,
lnpostfn = lnpost_halfcauchy_2d,
threads = threads
)
pos0_flat_2d = norm.rvs(loc=s_th,scale=0.001,size=(nwalkers,2))
for i in range(n_burnin_repeat):
pos0_flat_2d, _, _ = sampler_flat_2d.run_mcmc(pos0_flat_2d,n_burnin)
sampler_flat_2d.reset()
sampler_flat_2d.run_mcmc(pos0_flat_2d,n_burnin)
res_flat_2d = pd.Series(sampler_flat_2d.flatchain[:,0])
pos0_logflat_2d = norm.rvs(loc=s_th,scale=0.001,size=(nwalkers,2))
for i in range(n_burnin_repeat):
pos0_logflat_2d, _, _ = sampler_logflat_2d.run_mcmc(pos0_logflat_2d,n_burnin)
sampler_logflat_2d.reset()
sampler_logflat_2d.run_mcmc(pos0_logflat_2d,n_burnin)
res_logflat_2d = pd.Series(sampler_logflat_2d.flatchain[:,0])
pos0_halfcauchy_2d = norm.rvs(loc=s_th,scale=0.001,size=(nwalkers,2))
for i in range(n_burnin_repeat):
pos0_halfcauchy_2d, _, _ = sampler_halfcauchy_2d.run_mcmc(pos0_halfcauchy_2d,n_burnin)
sampler_halfcauchy_2d.reset()
sampler_halfcauchy_2d.run_mcmc(pos0_halfcauchy_2d,n_burnin)
res_halfcauchy_2d = pd.Series(sampler_halfcauchy_2d.flatchain[:,0])
結果のプロット:
一次元の場合と同じようにプロットしてみる。ただし、局外パラメータの情報は落として、$p(s|d)$ だけ見る。
ress_2d = np.array([res_flat_2d,res_logflat_2d,res_halfcauchy_2d])
bins = np.logspace(np.log10(1e-4),np.log10(ress_2d.max()),64)
res_flat_2d.hist(bins=bins,histtype="step")
res_logflat_2d.hist(bins=bins,histtype="step")
res_halfcauchy_2d.hist(bins=bins,histtype="step")
plt.legend(["flat","logflat","halfcauchy"],loc="upper left")
plt.xscale("log")
plt.yscale("log")
plt.show()
散布図:
局外パラメータの振る舞いについても気になるので一応見ておく:
for sampler in (sampler_flat_2d,sampler_logflat_2d,sampler_halfcauchy_2d):
plt.xscale("log")
plt.yscale("log")
plt.xlabel("s")
plt.ylabel("scale")
plt.scatter(*sampler.flatchain.T,s=1,linewidths=0,c=sampler.flatlnprobability)
plt.colorbar()
plt.show()
Ralative belief updating function:
問題のやつ。局外パラメータが合っても確かに事前分布依存性が取り除けているか?
hist_flat_2d = np.histogram(res_flat_2d,bins,density=True) # hist(n), bins_edges(n+1)
hist_logflat_2d = np.histogram(res_logflat_2d,bins,density=True)
hist_halfcauchy_2d = np.histogram(res_halfcauchy_2d,bins,density=True)
prior_flat = lambda x: 1
prior_logflat = lambda x: 1/x
prior_halfcauchy = lambda x: halfcauchy.pdf(x,loc=0,scale=scale_halfcauchy)
# relative belief updating ratio (RBUR)
def rbur(hist,prior_func):
posterior,bins = hist
r = 0.5
x = np.exp(r*np.log(bins[1:]) + (1-r)*np.log(bins[:-1]) )
post_prior = posterior/prior_func(x)
return x,post_prior/post_prior[0]
rbur_flat_2d = rbur(hist_flat_2d,prior_flat)
rbur_logflat_2d = rbur(hist_logflat_2d,prior_logflat)
rbur_halfcauchy_2d = rbur(hist_halfcauchy_2d,prior_halfcauchy)
plt.plot(*rbur_flat_2d)
plt.plot(*rbur_logflat_2d)
plt.plot(*rbur_halfcauchy_2d)
plt.legend(["flat","logflat","halfcauchy"])
for i in range(1,8,2):
plt.plot(rbur_flat_2d[0],np.ones_like(rbur_flat_2d[0])*np.exp(-i),"--",c="gray",linewidth=1)
plt.xscale("log")
plt.yscale("log")
plt.xlabel("s")
plt.ylabel(r"$\mathcal{R}$")
やはり低サンプリング数に伴う誤差はあるが、確かに事前分布に依存しない様子が見て取れる。
結局なんなのよ
結局のところ $\mathcal{R}$ を使って頻度主義統計(Frequentist)における尤度比検定のベイジアン的な一般化をしてるだけ。
まとめ
尤度比もどきを使うことで事前分布によらない?ベイジアン的な解析ができる。
-
ただし、ここで $x,\,\psi$ が互いに独立だと仮定( $\pi(x,\psi|\calM)=\pi(x|\calM)\ \pi(x|\calM)$ )していることに注意。互いに独立でない場合はこの式は成り立たない。 ↩