2
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

Numpyroを用いた回帰モデル

Posted at

0,前提知識

本記事ではPythonではじめるベイズ機械学習入門(以下[1])の第2章までの知識を前提としています。第3章の回帰モデルをNumpyroで実行したい方向けになっています。ですので理論等は割愛していきます。

使用ライブラリは以下、環境としてはColaboを想定している。

import pandas as pd
import numpy as np
from scipy import stats
import seaborn as sns
import matplotlib.pyplot as plt
import jax
import jax.numpy as jnp
import numpyro
import arviz as az

1,線形回帰モデル

1,サンプルデータ

true_w1 = 1.5
true_w2 = 0.8

N = 4
x_data = np.random.uniform(-5,5,size=N)
y_data = true_w1 * x_data + true_w2 + np.random.normal(0.,1.,size=N)

[1]ではプロットまでしているが割愛した。

2,構築

#モデルの定義
def model(x, y):
    w1 = numpyro.sample("w1", numpyro.distributions.Normal(0, 1))  # weight
    w2 = numpyro.sample("w2", numpyro.distributions.Normal(0, 1))  # bias
    sigma = numpyro.sample("sigma", numpyro.distributions.HalfNormal(scale=1.0))  # observation noise
    
    y_hat = w1 * x + w2  # linear regression relationship
    with numpyro.plate("data", len(x)):
        numpyro.sample("obs", numpyro.distributions.Normal(y_hat, sigma), obs=y)  # observation model

#NUTSを指定
kernel = numpyro.infer.NUTS(model)
#MCMCによる推論
mcmc = numpyro.infer.MCMC(kernel,num_warmup=500,num_samples=2000,num_chains=2)

3,実行

mcmc.run(jax.random.PRNGKey(1),x=x_data, y=y_data)

以上で事後分布が得られたので、サンプルを可視化していく。

mcmc.print_summary()

以下のように表形式で出力してくれる。見づらいがまとまってるので良し。

mean std median 5.0% 95.0% n_eff r_hat
b 1.25 0.36 1.30 0.70 1.84 2049.80 1.00
sigma 0.69 0.38 0.59 0.21 1.23 1221.64 1.00
w 1.61 0.09 1.62 1.48 1.75 2177.36 1.00

予測分布のサンプル

samples = mcmc.get_samples()
w1_samples = samples['w1']
w2_samples = samples['w2']

x_new = np.linspace(-5,5,100)

for i in range(0,4000,10):
    y_pred = w1_samples[i] * x_plot_data + w2_samples[i]
    plt.plot(x_new,y_pred,alpha=0.01,color="g")

plt.scatter(x_data, y_data, color='blue', label='observed')
plt.legend()
plt.show()

観測データ次第だが以下のように出力される。
スクリーンショット (39).png

2,重回帰モデル

1,サンプルデータ

from mpl_toolkits.mplot3d import Axes3D

#次元
dim = 2
#データ数
N = 100
#真のパラメータ
true_w = np.array([-1.5,0.8,1.2]).reshape([3,1])

#サンプルデータ
x_data = np.random.uniform(-5,5,[N,dim])
#バイアスの次元
bias = np.ones(N).reshape([N,1])
x_data_add_bias = np.concatenate([x_data,bias],axis=1)
y_data = np.dot(x_data_add_bias,true_w) + np.random.normal(0.0,1.0,size=[N,1])

fig = plt.figure()
ax = fig.add_subplot(111,projection="3d")

ax.scatter(x_data[:,0],x_data[:,1],y_data)
ax.set_xlabel("x1")
ax.set_ylabel("x2")
ax.set_zlabel("Y")

plt.show()

2,構築

def model(x, y):
    sigma = numpyro.sample("sigma", numpyro.distributions.HalfNormal(1.0))  # observation noise

    with numpyro.plate("para", x.shape[1]):
        w = numpyro.sample("w", numpyro.distributions.Normal(0,1).expand([x.shape[1]]))  # weight

    y_hat = jnp.dot(w,x.T)
    with numpyro.plate("data", x.shape[0]):
        numpyro.sample("obs", numpyro.distributions.Normal(y_hat, sigma), obs=y.ravel())  # observation model

#NUTSを指定
kernel = numpyro.infer.NUTS(model)
mcmc = numpyro.infer.MCMC(kernel,num_warmup=500,num_samples=2000,num_chains=2)

3,実行

mcmc.run(jax.random.PRNGKey(0),x=x_data_add_bias,y=y_data)

結果を可視化すると

az.plot_trace(mcmc)
az.plot_posterior(mcmc,hdi_prob=0.9)

download.png
download.png

3,参考

Pythonではじめるベイズ機械学習入門
上のサポートページ

2
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?