LoginSignup
2
4

More than 1 year has passed since last update.

Numpyro HMCで直線回帰

Last updated at Posted at 2022-03-01

Numpyro HMC

Numpyro HMCを用いた単純な直線回帰MCMCのコードを書きます。HMCは低棄却率のMCMCサンプリング法の一つで、特に高次元だと効率が良くなるそうです。以下の説明は自己満なので、読まずに信じてはいけません。

import jax
from jax import random
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS, HMC
import numpy as np
import arviz
import matplotlib.pyplot as plt
import pandas as pd

ここでは、単純な直線回帰を考えます。データの誤差が正規分布するとして、
$$\mu=mx+b\\
y\sim\mathcal{N}(\mu,\sigma)$$
としましょう。

#真のパラメータ
m_true = 0.3
b_true = 0.5
sigma_true = 0.2

np.random.seed(0)
N = 50
x = np.random.uniform(-np.pi,np.pi,N)
y_obs= m_true*x+b_true+np.random.normal(0,sigma_true,N)
plt.scatter(x,y_obs)
X = np.linspace(-np.pi,np.pi,N)
plt.plot(X,m_true*X+b_true,c="red")

image.png
これに対して、モデルフィットを行います。データの誤差は無知として、正規分布すると思ってlikelihoodを書いてみましょう。

# set model                                                         
def model(X, y=None):
    N = len(X)
    m = numpyro.sample("m",dist.Uniform(-10,10))
    b  = numpyro.sample("b",dist.Uniform(-10,10))
    sigma = numpyro.sample('sigma', dist.Uniform(0,5))
    mu = m*X+b
    numpyro.deterministic("chi2dof", jnp.sum((y-mu)**2/sigma**2)/N)
    #numpyro.sample("y",dist.Normal(mu,sigma),obs=y)
    loglike = -0.5*jnp.sum(((y-mu)**2)/(sigma**2)) - N*jnp.log(sigma) - 0.5*N*jnp.log(2*jnp.pi)
    numpyro.factor("loglike", loglike)

Pyro系はlog likelihoodを明示して書く必要がないようですが、numpyro.factorで指定することも可能です。
ここでは、データの各点の誤差が独立な正規分布すると仮定しているので、
$$\ln{L}=\ln\Pi^{N}_{i}\frac{1}{(\sqrt{2\pi\sigma_i})^{N}}\exp{-\frac{1}{2}{\left(\frac{y_i-\mu}{\sigma_i}\right)^2}}$$
としています。

# setting up the sampler                                                   
nuts_kernel = NUTS(model)
num_warmup, num_samples = 300, 3000
mcmc = MCMC(nuts_kernel,num_warmup=num_warmup, num_samples=num_samples,num_chains=1)

num_warmupは、HMCの各linkにおけるサンプリング長を決めるためのburn-inのようなもの(?)です。

# sampling                                                                 
mcmc.run(random.PRNGKey(0), x, y = y_obs)
mcmc.print_summary()
sample = pd.DataFrame(mcmc.get_samples(group_by_chain=False))
                mean       std    median      5.0%     95.0%     n_eff     r_hat
         b      0.47      0.03      0.47      0.42      0.51   4438.33      1.00
         m      0.29      0.02      0.29      0.26      0.32   4707.75      1.00
     sigma      0.19      0.02      0.19      0.16      0.23   3559.17      1.00

Number of divergences: 0
fig = arviz.plot_trace(mcmc,show=True,figsize=(12,6),compact=False)

image.png

import corner
fig = corner.corner(sample,truths=(b_true,1,m_true,sigma_true),show_titles=True,truth_color="red")

image.png

今回はHMCの恩恵を感じることはありませんでしたが、Numpyro HMCの使い方を少しは把握したと思います。

2
4
1

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
4