4
1

Delete article

Deleted articles cannot be recovered.

The draft of this article will also be deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

StanとRでベイズ統計モデリング(アヒル本)をPythonにしてみる - 4.4 単回帰

Last updated at Posted at 2018-08-16

実行環境

インポート

import pickle
import numpy as np
import pandas as pd
import statsmodels.formula.api as smf
import pystan
import matplotlib.pyplot as plt
from matplotlib.figure import figaspect
from matplotlib import colors as mcolors
import seaborn as sns
%matplotlib inline

データ読み込み

# Load the salary data set; row 0 of the file supplies the column names.
salary = pd.read_csv('./data/data-salary.txt', header=0)

4.4 単回帰

# Show the first three rows together with the last row.
salary.iloc[np.r_[0:3, -1]]

4.4.2 データの分布の確認

# Scatter plot of Y against X to inspect the relationship before modelling.
salary.plot(x='X', y='Y', kind='scatter')

fig4-2.png

4.4.4 Rのlm関数で推定

# Ordinary least squares fit of Y on X (with intercept) via the formula API.
model = smf.ols(formula='Y ~ X', data=salary)
res = model.fit()
# Estimated intercept and slope.
res.params

Intercept -119.697132
X 21.904201
dtype: float64

# Prediction grid: integer X values 23..60 in a single column named 'X'.
X_new = pd.DataFrame({'X': np.arange(23, 61)})
predictions = res.get_prediction(X_new)
# Interval tables at the 95% (alpha=0.05) and 50% (alpha=0.5) levels.
int_95 = predictions.summary_frame(alpha=0.05)
int_50 = predictions.summary_frame(alpha=0.5)

# Fitted line: prediction grid on the horizontal axis, predicted mean on the vertical.
x = X_new['X']
y = int_95['mean']

# One panel per interval family from the summary frame: 'mean_ci_*' then 'obs_ci_*'.
_, axes = plt.subplots(nrows=1, ncols=2, figsize=figaspect(3/8), sharex=True, sharey=True)
for ax, category in zip(axes, ('mean', 'obs')):
    color = 'k'
    lower = category + '_ci_lower'
    upper = category + '_ci_upper'
    ax.scatter(salary['X'], salary['Y'], c=color)
    ax.plot(x, y, c=color)
    # The narrower 50% band is drawn darker than the 95% band.
    ax.fill_between(x, int_50[lower], int_50[upper], color=color, alpha=0.5)
    ax.fill_between(x, int_95[lower], int_95[upper], color=color, alpha=0.3)
    ax.set_xlabel('X')
    ax.set_ylabel('Y')
plt.show()

fig4-3.png

4.4.6 Rからの実行方法

# Path prefix for the sample files written during sampling; they are read
# back later as '<prefix>_<chain index>.csv'.
sample_file = './sample/model4-5'
# Data passed to the Stan program: number of rows plus the X and Y columns.
data = {
    'N': salary.index.size,
    'X': salary['X'],
    'Y': salary['Y'],
}
fit = pystan.stan(
    file='./stan/model4-5.stan',
    data=data,
    seed=1234,
    sample_file=sample_file,
)

4.4.7 RStanの結果の見方

# Display the summary table of the fit (notebook cell output).
fit

Inference for Stan model: anon_model_760fa0b78af0bbaaa934d1350faefc45.
4 chains, each with iter=2000; warmup=1000; thin=1;
post-warmup draws per chain=1000, total post-warmup draws=4000.

    mean se_mean     sd   2.5%    25%    50%    75%  97.5%  n_eff   Rhat

a -123.5 2.04 76.4 -272.3 -169.8 -122.9 -73.9 26.33 1404 1.0
b 22.0 0.05 1.7 18.68 20.89 21.97 23.06 25.38 1361 1.0
sigma 85.34 0.44 15.62 61.78 73.94 83.32 94.41 121.4 1256 1.0
lp__ -93.65 0.05 1.33 -97.01 -94.26 -93.31 -92.68 -92.14 828 1.0

Samples were drawn using NUTS at Thu Aug 16 11:49:54 2018.
For each parameter, n_eff is a crude measure of effective sample size,
and Rhat is the potential scale reduction factor on split chains (at
convergence, Rhat=1).

4.4.8 収束診断をファイルへ出力する

# Draw the fit's built-in diagnostic plot, then resize the current figure
# to a 3:2 (height:width) aspect scaled by 1.5.
fit.plot()
plt.gcf().set_size_inches(figaspect(3/2) * 1.5)
plt.tight_layout()
plt.show()

fig4-4a.png

# Names listed under 'fnames_oi' in the fit; also used as the CSV columns to load.
names = fit.sim['fnames_oi']
color_list = [mcolors.to_rgba(c) for c in plt.rcParams['axes.prop_cycle'].by_key()['color']]
# One DataFrame per chain, read from the sample files; '#' lines are skipped as comments.
chains = [
    pd.read_csv(
        '{name}_{i}.csv'.format(name=sample_file, i=i),
        usecols=names,
        engine='python',
        comment='#',
    )
    for i in range(fit.sim['chains'])
]

# One stacked panel per name, with one line per chain.
_, axes = plt.subplots(nrows=len(names), ncols=1, figsize=figaspect(len(names)))
for name, ax in zip(names, axes):
    ax.set_title(name)
    # Drawing the log-likelihood (lp__) trace is slow, so its panel is left empty.
    if name != 'lp__':
        for i, chain in enumerate(chains):
            ax.plot(chain.index, chain[name], label='Chain{}'.format(i+1))
        ax.legend()
plt.show()

fig4-4.png

4.4.9 MCMCの設定の変更

# Compile the Stan program once so it can be sampled repeatedly.
stanmodel = pystan.StanModel(file='./stan/model4-5.stan')
# Data passed to the Stan program: number of rows plus the X and Y columns.
data = {
    'N': salary.shape[0],
    'X': salary['X'],
    'Y': salary['Y'],
}
# Parameters to keep in the sampling output.
pars = ('b', 'sigma')
def init():
    """Return one set of initial parameter values.

    ``a`` is drawn uniformly from [-10, 10) and ``b`` from [0, 10) using
    NumPy's global random state (``a`` first, then ``b``); ``sigma`` is
    always 10.

    Returns:
        dict: keys ``'a'``, ``'b'`` and ``'sigma'``.
    """
    a_start = np.random.uniform(low=-10, high=10, size=1)[0]
    b_start = np.random.uniform(low=0, high=10, size=1)[0]
    return {'a': a_start, 'b': b_start, 'sigma': 10}
# Sample with non-default MCMC settings: 3 chains, 1000 iterations each with
# 200 warmup iterations, keeping only `pars` and using `init` for start values.
# NOTE(review): thin=200 retains very few of the 800 post-warmup draws per
# chain -- confirm this is intended (a much smaller value such as thin=2 may
# have been meant).
fit_b = stanmodel.sampling(
    data=data,
    pars=pars,
    init=init,
    seed=123,
    chains=3,
    iter=1000,
    warmup=200,
    thin=200,
)

4.4.10 並列計算の実行方法

# Use pickle for persistence; the object written here is `fit`.
# NOTE(review): PyStan's documentation indicates a pickled fit needs its
# StanModel available when it is unpickled -- confirm this file is only
# reloaded in a session where the model already exists.
with open('model.pkl', mode='wb') as f:
    pickle.dump(fit, f, protocol=pickle.HIGHEST_PROTOCOL)
# Parallel computation is requested through n_jobs (-1 here).
stanmodel.sampling(data=data, n_jobs=-1)

4.4.11 ベイズ信頼区間とベイズ予測区間の算出

# Reload the fit pickled earlier. pickle.load can execute arbitrary code, so
# only do this with a file written by this notebook.
with open('model.pkl', mode='rb') as f:
    fit = pickle.load(f)
# Extract the draws into a dict keyed by parameter name.
ms = fit.extract()
# Show the draws of the slope b.
ms['b']

array([21.648502 , 24.32315818, 22.51297332, ..., 18.02090357,
23.9904834 , 22.04242875])

# 2.5th and 97.5th percentiles of the extracted draws of b (a 95% interval).
np.percentile(ms['b'], (2.5, 97.5))

array([18.70508043, 25.37697167])

# Collect the extracted draws of a, b and sigma into one table (one row per draw).
d_mcmc = pd.DataFrame(data=dict(a=ms['a'], b=ms['b'], sigma=ms['sigma']))
d_mcmc.head()
# Joint distribution of a and b. x/y are passed by keyword because newer
# seaborn releases no longer accept them positionally; the keyword form also
# works on older releases.
sns.jointplot(x='a', y='b', data=d_mcmc)
plt.show()

fig4-7.png

# Number of extracted draws.
N_mcmc = ms['lp__'].shape[0]
# Linear predictor a + b * X evaluated at X = 50 for every draw.
y50_base = ms['a'] + ms['b'] * 50
# One normal draw per MCMC draw, centred on y50_base with that draw's sigma.
y50 = np.random.normal(loc=y50_base, scale=ms['sigma'], size=N_mcmc)
d_mcmc = pd.DataFrame(data={
    'a': ms['a'],
    'b': ms['b'],
    'sigma': ms['sigma'],
    'y50_base': y50_base,
    'y50': y50,
})
# First three rows together with the last row.
d_mcmc.iloc[np.r_[:3, -1]]
# X values at which the intervals are evaluated.
ages = X_new['X']
steps = ages.size
# Per-age results: the median of the linear predictor, plus (lower, upper)
# pairs for the 50%/95% intervals of the linear predictor (conf*) and of
# simulated observations (pred*).
median = np.empty(steps)
conf50 = np.empty((steps, 2))
conf95 = np.empty((steps, 2))
pred50 = np.empty((steps, 2))
pred95 = np.empty((steps, 2))
# Fix the seed so the simulated observations below are reproducible.
np.random.seed(1234)
for i, age in enumerate(ages):
    # Linear predictor for every draw at this age.
    base = ms['a'] + ms['b'] * age
    # One simulated observation per draw, using that draw's sigma.
    y = np.random.normal(loc=base, scale=ms['sigma'], size=N_mcmc)
    median[i] = np.median(base)
    conf95[i, 0], conf50[i, 0], conf50[i, 1], conf95[i, 1] = np.percentile(base, (2.5, 25, 75, 97.5))
    pred95[i, 0], pred50[i, 0], pred50[i, 1], pred95[i, 1] = np.percentile(y, (2.5, 25, 75, 97.5))
# Left panel uses the conf* intervals, right panel the pred* intervals.
_, axes = plt.subplots(nrows=1, ncols=2, figsize=figaspect(3/8), sharex=True, sharey=True)
for ax, interval50, interval95 in zip(axes, (conf50, pred50), (conf95, pred95)):
    color = 'k'
    ax.scatter('X', 'Y', data=salary, c=color)
    ax.plot(ages, median, c=color)
    # Column 0 holds the lower bound and column 1 the upper bound.
    ax.fill_between(ages, interval50[:, 0], interval50[:, 1], color=color, alpha=0.5)
    ax.fill_between(ages, interval95[:, 0], interval95[:, 1], color=color, alpha=0.3)
    ax.set_xlabel('X')
    ax.set_ylabel('Y')
plt.show()

fig4-8.png

4.4.12 transformed parameters ブロックと generated quantities ブロック

# Prediction grid: integer X values 23..60.
X_new = np.arange(23, 61)
# Start from every column of the data frame, then add the row count and the grid.
data = {col: salary[col] for col in salary.columns}
data['N'] = salary.index.size
data['N_new'] = len(X_new)
data['X_new'] = X_new
fit = pystan.stan('./stan/model4-4.stan', data=data, seed=1234)
# Extract the draws into a dict keyed by parameter name.
ms = fit.extract()

def quantile_mcmc(x, y_mcmc, probs=(2.5, 25, 50, 75, 97.5)):
    """Summarise MCMC draws as one row of percentiles per x value.

    Args:
        x: 1-D array-like of x values, one per column of ``y_mcmc``.
        y_mcmc: 2-D array-like of draws; percentiles are taken along axis 0,
            so each column is summarised separately.
        probs: percentiles to compute, given in percent (0-100).

    Returns:
        pandas.DataFrame with column ``'X'`` followed by one ``'p<prob>'``
        column per entry of ``probs`` (for example ``'p2.5'``, ``'p50'``).
    """
    # The default is an immutable tuple; copy into a list for column naming.
    probs = list(probs)
    # Coerce x so that lists and Series work as well as ndarrays.
    x_col = np.asarray(x).reshape((-1, 1))
    # Percentiles come back with shape (len(probs), n_x); transpose to rows per x.
    qua = np.percentile(y_mcmc, probs, axis=0).T
    columns = ['X'] + ['p{}'.format(p) for p in probs]
    return pd.DataFrame(np.hstack((x_col, qua)), columns=columns)

def plot_5quantile(ax, data):
    """Draw a median line with 95% and 50% bands on ``ax``.

    Args:
        ax: matplotlib axes to draw on.
        data: table with columns ``'X'``, ``'p2.5'``, ``'p25'``, ``'p50'``,
            ``'p75'`` and ``'p97.5'``.
    """
    color = 'k'
    ax.plot('X', 'p50', data=data, color=color)
    # Outer 95% band first (lighter), then the inner 50% band (darker).
    bands = (
        ('p2.5', 'p97.5', 1/6),
        ('p25', 'p75', 2/6),
    )
    for lower, upper, opacity in bands:
        ax.fill_between('X', lower, upper, data=data, color=color, alpha=opacity)
    # Fixed labels, ticks and limits for every panel.
    plt.setp(ax, xlabel='X', ylabel='Y', yticks=np.arange(200, 1401, 400), xlim=(22, 61), ylim=(200, 1400))

# Two side-by-side panels: percentiles of y_base_new on the left and of
# y_new on the right.
plt.figure(figsize=figaspect(3/8))
ax1 = plt.subplot(1, 2, 1)
d_est = quantile_mcmc(x=X_new, y_mcmc=ms['y_base_new'])
plot_5quantile(ax1, d_est)
ax2 = plt.subplot(1, 2, 2)
d_est = quantile_mcmc(x=X_new, y_mcmc=ms['y_new'])
plot_5quantile(ax2, d_est)
plt.show()

model4-4.png

# Top-left 6x4 corner of the y_new draws.
ms['y_new'][:6, :4]

array([[350.62655153, 380.55832701, 393.61242951, 367.42299371],
[499.54793356, 476.32663655, 456.6198796 , 373.92625043],
[512.49960072, 415.93237044, 528.00088398, 489.78449245],
[ 63.62170946, 305.60617775, 372.62753946, 466.29639311],
[296.99384618, 429.13932318, 510.96193705, 525.803673 ],
[455.40991609, 442.3794277 , 429.91309597, 333.97571465]])

4
1
1

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
4
1

Delete article

Deleted articles cannot be recovered.

The draft of this article will also be deleted.

Are you sure you want to delete this article?