Running model12-4 from 「StanとRでベイズ統計モデリング」 (Bayesian Statistical Modeling Using Stan and R) in Python (PyStan)

Posted at 2017-04-02

「StanとRでベイズ統計モデリング」 is an exciting book, but it is written in R. I used R a long time ago, but I have since switched to Python and don't feel like going back, so I did my best in Python.

I referred to the sites below, but wrote the code in my own style. The visualization code in particular was hard to follow, so I first reshape the data with pandas into a form seaborn can use, and then plot it with seaborn in one go (a tiny sketch of the reshaping idea follows the links).
PyStan で「StanとRでベイズ統計モデリング」11.3節
Python(PyStan)で「StanとRでベイズ統計モデリング」の5.1節を実行する
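
As a tiny sketch of the reshaping idea, separate from the full script below (the DataFrame and column names here are made up for illustration):

import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

# toy "wide" table: one column per parameter, one row per MCMC draw (hypothetical values)
wide = pd.DataFrame({'mu': [0.1, 0.2, 0.3, 0.2], 's_Y': [1.0, 1.1, 0.9, 1.2]})

# stack to "long" form: one row per (parameter, draw), which is what FacetGrid expects
long_df = wide.unstack().reset_index()
long_df.columns = ['param', 'draw', 'sample']

# one density panel per parameter
g = sns.FacetGrid(long_df, col='param', sharex=False)
g.map(sns.kdeplot, 'sample')
plt.show()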

The visualization of the state-space model is not as polished as the R version. I wanted to shade the Bayesian credible interval in gray as in R, but could not get it to work (a sketch of one possible approach follows).
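
For what it's worth, matplotlib's fill_between should be able to draw that kind of gray band. A minimal, untested sketch with toy percentile arrays (the variable names here are mine, not from the script below):

import numpy as np
import matplotlib.pyplot as plt

# toy 10/50/90 percentiles of a latent state over time (hypothetical values)
t = np.arange(24)
median = np.sin(t / 4.0)
lower = median - 0.3
upper = median + 0.3

plt.figure(figsize=(10, 5))
plt.fill_between(t, lower, upper, facecolor='gray', alpha=0.4)  # shaded credible band
plt.plot(t, median, color='black')                              # posterior median
plt.show()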

Stan is great! The possibilities are exciting!
seaborn is good too!

The book's GitHub repository is here.

#!/usr/bin/env python
# -*- coding: utf-8 -*-

# -------
# import
# -------
import os
import math
import pickle
import pystan
import functools
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from scipy import stats

# ------------
# environment
# ------------
file_dir = os.getcwd()
# file_dir =os.path.abspath(os.path.dirname(__file__))
input_dir = os.path.join(file_dir, 'input')
model_dir = os.path.join(file_dir, 'model')
output_dir = os.path.join(file_dir, 'output')

data_file = os.path.join(input_dir, 'data-ss1.txt')
model_file = os.path.join(model_dir, 'model12-4.stan')
model_obj_file = os.path.join(output_dir, 'model.pkl')
fit_obj_file = os.path.join(output_dir, 'fit.pkl')
fit_file = os.path.join(output_dir, 'fit.txt')

# ---------
# compile
# ---------
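# Compile the Stan model once and cache it with pickle so later runs can reload it without recompiling.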
model = pystan.StanModel(file=model_file)

with open(model_obj_file, 'wb') as f:
  pickle.dump(model, f)

with open(model_obj_file, 'rb') as f:
  model = pickle.load(f)

# ----------------
# data processing
# ----------------
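# Read the observation series and build the dict passed to Stan:
# T = number of observed time points, T_pred = forecast horizon (3 steps ahead);
# the data columns (here Y) are added as plain lists via to_dict('list').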
data = pd.read_csv(data_file)

stan_data = {'T': len(data), 'T_pred': 3}
stan_data.update(data.to_dict('list'))

# ----------
# sampling
# ----------
fit = model.sampling(data=stan_data, n_jobs=1)

with open(fit_obj_file, 'wb') as f:
  pickle.dump(fit, f)

with open(fit_file, 'w') as f:
  f.write(str(fit))

with open(fit_obj_file, 'rb') as f:
  fit = pickle.load(f)

# --------------------------
# prepare for visualization
# --------------------------
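# extract(permuted=False) returns an ndarray of shape (iterations, chains, parameters);
# flatten it into a 2D DataFrame with one row per draw (chain-major) and one column per parameter name.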
sample_wide = fit.extract(permuted=False, inc_warmup=False)

sample_wide = pd.DataFrame(
  sample_wide.transpose(1,0,2).reshape(
    sample_wide.shape[0] * sample_wide.shape[1],
    sample_wide.shape[2]
  ),
  columns=fit.sim['fnames_oi']
)

sample_long = sample_wide.unstack().reset_index()
sample_long = sample_long.drop('level_1', axis=1)
sample_long.columns = ['param', 'sample']

chain_cnt = fit.sim['chains']
sample_cnt = int(len(sample_wide) / fit.sim['chains'])
param_cnt = len(fit.sim['fnames_oi'])
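# Label every draw with its chain: after the reshape above the rows are chain-major,
# so within each parameter the draws run chain 0 first, then chain 1, and so on.
# 'seq' is the draw index within each (parameter, chain) pair.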
sample_long['chain'] = (
  np.ones((param_cnt, sample_cnt * chain_cnt), dtype='int64') *
  np.repeat(range(0, chain_cnt), sample_cnt)
).ravel()
sample_long['seq'] = sample_long.groupby(['param', 'chain']).cumcount()

# ---------------
# sampling check
# ---------------
sns.set(font_scale=2)
sns.set_style("ticks")
sns.despine(offset=10, trim=True)
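# Trace plots: one row per parameter, one colored line per chain, to eyeball mixing.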
g = sns.FacetGrid(
  sample_long,
  row="param",
  hue="chain",
  size=5, aspect=3,
  sharex=True, sharey=False
)
g = (g.map(plt.plot, "seq", "sample").add_legend())
plt.show()

sns.set(font_scale=2)
sns.set_style("ticks")
sns.despine(offset=10, trim=True)
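# Marginal posterior densities: one panel per parameter, one KDE curve per chain.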
g = sns.FacetGrid(
  sample_long,
  col="param",
  col_wrap=9,
  hue="chain",
  size=3, aspect=1,
  sharex=False, sharey=False
)
g = (g.map(sns.kdeplot, "sample").add_legend())
plt.show()

#sample_wide = sample_long[['seq', 'param', 'sample']].pivot_table(
#  index='seq', columns='param', values='sample'
#)
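# Pair plot of a few selected parameters: scatter above the diagonal, 2-D KDE below,
# and a marginal density with a fitted normal on the diagonal.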
g = sns.PairGrid(sample_wide[[
  'mu_all[0]', 'mu_all[11]', 'mu_all[23]',
  'y_pred[0]', 'y_pred[1]', 'y_pred[2]',
  's_mu', 's_Y'
]])
g.map_upper(plt.scatter)
g.map_lower(sns.kdeplot, cmap="Blues_d")
g.map_diag(sns.distplot, hist=False, fit=stats.norm);
plt.show()

# --------------
# visualization
# --------------
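# Summarize the draws with 10/50/90 percentiles: the posterior median plus an 80% Bayesian interval
# for the latent states (mu_all) and for the predicted observations (y_pred).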
state = pd.DataFrame(
  np.percentile(fit.extract()['mu_all'], q=[10, 50, 90], axis=0).T,
  columns=['state_{0}'.format(c) for c in ['lower', 'middle', 'upper']]
)

pred = pd.DataFrame(
  np.percentile(fit.extract()['y_pred'], q=[10, 50, 90], axis=0).T,
  columns=['pred_{0}'.format(c) for c in ['lower', 'middle', 'upper']]
)

data['T'] = range(0, len(data))
state['T'] = range(0, len(state))
pred['T'] = range(21, 24)

result = functools.reduce(
  lambda left, right: pd.merge(left, right, on='T', how='left'), (
    state,
    pred,
    data[['T', 'Y']]
  )
)

sns.set(font_scale=2)
sns.set_style("ticks")
sns.despine(offset=10, trim=True)
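# Final figure: observed Y as points, posterior median of the state as a solid line,
# the 80% interval as dashed lines, and the 3-step-ahead predictions as points with 80% error bars.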
plt.figure(figsize=(15, 15))
plt.plot(result['T'], result['Y'], 'o', color='black')
plt.plot(result['T'], result['state_middle'], '-', color='black')
plt.plot(result['T'], result['state_upper'], '--', color='black')
plt.plot(result['T'], result['state_lower'], '--', color='black')
plt.plot(result['T'], result['pred_middle'], 'o', color='black')
plt.errorbar(
  result['T'], result['pred_middle'],
  yerr=[
    result['pred_middle'] - result['pred_lower'],
    result['pred_upper'] - result['pred_middle']
  ],
  color='black', ecolor='black', fmt='o'
)
plt.show()

(Output figure tmp.jpeg: observed data points, estimated state with 80% interval, and 3-step-ahead predictions with error bars)
