Numpyro HMC
Numpyro HMCを用いた単純な直線回帰MCMCのコードを書きます。HMCは低棄却率のMCMCサンプリング法の一つで、特に高次元だと効率が良くなるそうです。以下の説明は自己満なので、読まずに信じてはいけません。
import jax
from jax import random
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS, HMC
import numpy as np
import arviz
import matplotlib.pyplot as plt
import pandas as pd
ここでは、単純な直線回帰を考えます。データの誤差が正規分布するとして、
$$\mu=mx+b\\
y\sim\mathcal{N}(\mu,\sigma)$$
としましょう。
#真のパラメータ
m_true = 0.3
b_true = 0.5
sigma_true = 0.2
np.random.seed(0)
N = 50
x = np.random.uniform(-np.pi,np.pi,N)
y_obs= m_true*x+b_true+np.random.normal(0,sigma_true,N)
plt.scatter(x,y_obs)
X = np.linspace(-np.pi,np.pi,N)
plt.plot(X,m_true*X+b_true,c="red")
これに対して、モデルフィットを行います。データの誤差は無知として、正規分布すると思ってlikelihoodを書いてみましょう。
# set model
def model(X, y=None):
N = len(X)
m = numpyro.sample("m",dist.Uniform(-10,10))
b = numpyro.sample("b",dist.Uniform(-10,10))
sigma = numpyro.sample('sigma', dist.Uniform(0,5))
mu = m*X+b
numpyro.deterministic("chi2dof", jnp.sum((y-mu)**2/sigma**2)/N)
#numpyro.sample("y",dist.Normal(mu,sigma),obs=y)
loglike = -0.5*jnp.sum(((y-mu)**2)/(sigma**2)) - N*jnp.log(sigma) - 0.5*N*jnp.log(2*jnp.pi)
numpyro.factor("loglike", loglike)
Pyro系はlog likelihoodを明示して書く必要がないようですが、numpyro.factorで指定することも可能です。
ここでは、データの各点の誤差が独立な正規分布すると仮定しているので、
$$\ln{L}=\ln\Pi^{N}_{i}\frac{1}{(\sqrt{2\pi\sigma_i})^{N}}\exp{-\frac{1}{2}{\left(\frac{y_i-\mu}{\sigma_i}\right)^2}}$$
としています。
# setting up the sampler
nuts_kernel = NUTS(model)
num_warmup, num_samples = 300, 3000
mcmc = MCMC(nuts_kernel,num_warmup=num_warmup, num_samples=num_samples,num_chains=1)
num_warmupは、HMCの各linkにおけるサンプリング長を決めるためのburn-inのようなもの(?)です。
# sampling
mcmc.run(random.PRNGKey(0), x, y = y_obs)
mcmc.print_summary()
sample = pd.DataFrame(mcmc.get_samples(group_by_chain=False))
mean std median 5.0% 95.0% n_eff r_hat
b 0.47 0.03 0.47 0.42 0.51 4438.33 1.00
m 0.29 0.02 0.29 0.26 0.32 4707.75 1.00
sigma 0.19 0.02 0.19 0.16 0.23 3559.17 1.00
Number of divergences: 0
fig = arviz.plot_trace(mcmc,show=True,figsize=(12,6),compact=False)
import corner
fig = corner.corner(sample,truths=(b_true,1,m_true,sigma_true),show_titles=True,truth_color="red")
今回はHMCの恩恵を感じることはありませんでしたが、Numpyro HMCの使い方を少しは把握したと思います。