インポート
import numpy as np
import pandas as pd
import pystan
import matplotlib.pyplot as plt
from matplotlib.figure import figaspect
from matplotlib.gridspec import GridSpec
import seaborn as sns
%matplotlib inline
データ読み込み
# Read the wide-format table; later cells rely on its PersonID column and
# on the Time1..Time24 columns.
conc2 = pd.read_csv(filepath_or_buffer='./data/data-conc-2.txt')
8.3 非線形モデルの階層モデル
8.3.1 解析の目的とデータの分布の確認
# Numeric time points; each one matches a wide-format column 'Time<t>'.
Time = [1, 2, 4, 8, 12, 24]
# Lookup table from column label ('Time1', 'Time2', ...) to numeric time.
Time_tbl = pd.Series(data=Time, index=list(map('Time{}'.format, Time)))
# Wide -> long: one row per (PersonID, time column) pair, value kept in Y.
d = conc2.melt(id_vars='PersonID', var_name='Time', value_name='Y')
# Replace the column labels by their numeric times.
d['Time'] = Time_tbl.loc[d['Time']].values
# One figure, two GridSpecs: a 4x4 grid of per-person time courses and a
# single panel with the distribution of Y at Time == 24.
fig = plt.figure(figsize=figaspect(4/8)*1.5)
gs1 = GridSpec(4, 4, figure=fig)
for row, col in np.ndindex(4, 4):
    # Panels are numbered row by row starting at 1.
    # NOTE(review): assumes PersonID runs 1..16 -- confirm against the data.
    person = 4 * row + col + 1
    ax = fig.add_subplot(gs1[row, col])
    ax.plot('Time', 'Y', 'o', linestyle='solid', data=d.query('PersonID==@person'))
    ax_kws = {
        'title': person,
        'xlim': (0, 24),
        'ylim': (-3, 37),
        'xticks': Time,
        'yticks': np.arange(0, 40, 10),
    }
    # Only the bottom row gets an x label; other rows hide x tick labels.
    if idx_is_bottom := None:
        pass
    if row == 3:
        ax_kws['xlabel'] = 'Time (hour)'
    else:
        plt.setp(ax.get_xticklabels(), visible=False)
    # Only the first column gets a y label; other columns hide y tick labels.
    if col == 0:
        ax_kws['ylabel'] = 'Y'
    else:
        plt.setp(ax.get_yticklabels(), visible=False)
    # Apply all collected axes properties in one call.
    plt.setp(ax, **ax_kws)

# Second GridSpec holding the single distribution panel.
gs2 = GridSpec(1, 1)
ax = fig.add_subplot(gs2[0, 0])
y_at_24 = d.query('Time==24')['Y']
sns.distplot(
    y_at_24,
    bins=9,
    hist_kws={'facecolor': 'w', 'edgecolor': 'k'},
    kde_kws={'shade': True},
    axlabel='Time24',
    ax=ax,
)

# Lay out gs1 up to x = 0.5 and gs2 from x = 0.5 onwards.
gs1.tight_layout(fig, rect=[None, None, 0.5, None])
gs2.tight_layout(fig, rect=[0.5, None, None, None])
# Give both grids the same (lower) top edge so they line up.
top = min(gs1.top, gs2.top)
gs1.update(top=top)
gs2.update(top=top)
plt.show()
8.3.4 Stanで実装
# Number of rows in conc2.
# NOTE(review): presumably one row per person -- confirm against the data.
N = len(conc2.index)
Time = (1, 2, 4, 8, 12, 24)
# Grid of T_new evenly spaced time points on [0, 24] passed to the model.
T_new = 60
Time_new = np.linspace(0, 24, num=T_new)
# Data passed to the Stan program; the keys must match the names declared
# in the model's data block (model file not visible here).
data = {
    'N': N,
    'T': len(Time),
    'Time': Time,
    'Y': conc2.loc[:, 'Time1':'Time24'],
    'T_new': T_new,
    'Time_new': Time_new,
}
# Compile and sample with a fixed seed, then pull out the draws.
fit = pystan.stan(file='./stan/model8-7.stan', data=data, seed=1234)
ms = fit.extract()
# Percentiles to summarise: median plus the 2.5 / 97.5 bounds.
probs = (2.5, 50, 97.5)
# Percentiles are taken over axis 0 of y_new; the percentile axis is then
# moved last so that each flattened row holds one (p2.5, p50, p97.5) triple.
# NOTE(review): presumably y_new is (draws, N, T_new) -- confirm in the model.
qua = np.percentile(ms['y_new'], probs, axis=0).transpose(1, 2, 0)
d_est = pd.DataFrame(
    qua.reshape(-1, len(probs)),
    columns=['p{}'.format(p) for p in probs],
)
# Row labels for the flattened array: person varies slowest, time fastest.
d_est['PersonID'] = np.repeat(np.arange(1, N + 1), T_new)
d_est['Time'] = np.tile(Time_new, N)
# Lookup table from column label ('Time1', 'Time2', ...) to numeric time.
Time_tbl = pd.Series(data=Time, index=list(map('Time{}'.format, Time)))
# Wide -> long: one row per (PersonID, time column) pair, value kept in Y.
d = conc2.melt(id_vars='PersonID', var_name='Time', value_name='Y')
# Replace the column labels by their numeric times.
d['Time'] = Time_tbl.loc[d['Time']].values
# 4x4 grid: per person, the p2.5-p97.5 band, the median curve and the
# observed points.
_, axes = plt.subplots(nrows=4, ncols=4, figsize=figaspect(7/8)*1.5)
for person, ax in enumerate(axes.flat, start=1):
    # Recover the grid position from the row-major panel number.
    row, col = divmod(person - 1, 4)
    est = d_est.query('PersonID==@person')
    obs = d.query('PersonID==@person')
    ax.fill_between('Time', 'p2.5', 'p97.5', data=est, color='k', alpha=1/5)
    ax.plot('Time', 'p50', data=est, color='k')
    ax.scatter('Time', 'Y', data=obs, color='k')
    # Only the bottom row gets an x label; other rows hide x tick labels.
    if row == 3:
        plt.setp(ax, xlabel='Time (hour)')
    else:
        plt.setp(ax.get_xticklabels(), visible=False)
    # Only the first column gets a y label; other columns hide y tick labels.
    if col == 0:
        plt.setp(ax, ylabel='Y')
    else:
        plt.setp(ax.get_yticklabels(), visible=False)
    plt.setp(
        ax,
        title=person,
        xticks=Time,
        xlim=(0, 24),
        yticks=np.arange(0, 40, 10),
        ylim=(-3, 37),
    )
plt.tight_layout()
plt.show()