Python
Stan
PyStan

StanとRでベイズ統計モデリング(アヒル本)をPythonにしてみる - 8.3 非線形モデルの階層モデル

実行環境

インポート

import numpy as np
import pandas as pd
import pystan
import matplotlib.pyplot as plt
from matplotlib.figure import figaspect
from matplotlib.gridspec import GridSpec
import seaborn as sns
%matplotlib inline

データ読み込み

# Load the data set from CSV. The cells below rely on a 'PersonID' column and
# on the contiguous run of measurement columns from 'Time1' through 'Time24'.
conc2 = pd.read_csv('./data/data-conc-2.txt')

8.3 非線形モデルの階層モデル

8.3.1 解析の目的とデータの分布の確認

# Measurement times; the wide-format columns of conc2 are named 'Time<t>'.
Time = [1, 2, 4, 8, 12, 24]
# Lookup table from column label ('Time1', ...) to its numeric time.
Time_tbl = pd.Series({'Time{}'.format(t): t for t in Time})

# Wide -> long: one row per (PersonID, Time, Y), then replace the column
# labels in 'Time' with their numeric values.
d = conc2.melt(id_vars='PersonID', var_name='Time', value_name='Y')
d['Time'] = Time_tbl[d['Time']].values

fig = plt.figure(figsize=figaspect(4/8)*1.5)

# Left part of the figure: a 4x4 grid with one panel per person.
# Panels are numbered 1..16 in row-major order (assumes PersonID uses the
# same numbering -- TODO confirm against the data file).
gs1 = GridSpec(4, 4, figure=fig)
for person, (row, col) in enumerate(np.ndindex(4, 4), start=1):
    ax = fig.add_subplot(gs1[row, col])
    ax.plot('Time', 'Y', 'o', linestyle='solid',
            data=d.query('PersonID==@person'))
    ax_kws = {
        'title': person,
        'xlim': (0, 24),
        'ylim': (-3, 37),
        'xticks': Time,
        'yticks': np.arange(0, 40, 10),
    }
    # Only the bottom row keeps its x tick labels and gets an x axis label.
    if row != 3:
        plt.setp(ax.get_xticklabels(), visible=False)
    else:
        ax_kws['xlabel'] = 'Time (hour)'
    # Only the leftmost column keeps its y tick labels and gets a y axis label.
    if col != 0:
        plt.setp(ax.get_yticklabels(), visible=False)
    else:
        ax_kws['ylabel'] = 'Y'
    # Applied after the tick labels were hidden, same order as before.
    plt.setp(ax, **ax_kws)

# Right part of the figure: distribution of Y at Time == 24.
# NOTE(review): sns.distplot is deprecated in newer seaborn releases; it is
# kept here so that the rendered figure does not change.
gs2 = GridSpec(1, 1)
ax = fig.add_subplot(gs2[0, 0])
sns.distplot(
    d.query('Time==24')['Y'],
    bins=9,
    hist_kws={'facecolor': 'w', 'edgecolor': 'k'},
    kde_kws={'shade': True},
    axlabel='Time24',
    ax=ax,
)

# Lay the two grids out side by side: gs1 in the left half of the figure,
# gs2 in the right half, then align both to the lower of the two top edges.
gs1.tight_layout(fig, rect=[None, None, 0.5, None])
gs2.tight_layout(fig, rect=[0.5, None, None, None])
top = min(gs1.top, gs2.top)
for gs in (gs1, gs2):
    gs.update(top=top)
plt.show()

fig8-7.png

8.3.4 Stanで実装

# --- Data passed to the Stan model -----------------------------------------
N = len(conc2.index)                    # number of rows in conc2
Time = (1, 2, 4, 8, 12, 24)             # observation times
T_new = 60                              # number of prediction time points
Time_new = np.linspace(0, 24, T_new)    # evenly spaced grid over [0, 24]
data = {
    'N': N,
    'T': len(Time),
    'Time': Time,
    'Y': conc2.loc[:, 'Time1':'Time24'],
    'T_new': T_new,
    'Time_new': Time_new,
}
fit = pystan.stan(file='./stan/model8-7.stan', data=data, seed=1234)
ms = fit.extract()

# --- Posterior quantiles of y_new ------------------------------------------
# The reshape below treats ms['y_new'] as (draws, N, T_new) -- TODO confirm
# against the model file. Percentiles are taken over the draws axis and that
# axis is moved last, giving one (2.5, 50, 97.5) triple per (person, time).
probs = (2.5, 50, 97.5)
qua = np.moveaxis(np.percentile(ms['y_new'], probs, axis=0), 0, -1)
d_est = pd.DataFrame(qua.reshape((-1, 3)),
                     columns=['p{}'.format(p) for p in probs])
# Rows are person-major after the reshape: each person id is repeated T_new
# times while the time grid is tiled once per person.
d_est['PersonID'] = np.repeat(np.arange(1, N + 1), T_new)
d_est['Time'] = np.tile(Time_new, N)

# --- Observations in long format -------------------------------------------
Time_tbl = pd.Series({'Time{}'.format(t): t for t in Time})
d = conc2.melt(id_vars='PersonID', var_name='Time', value_name='Y')
d['Time'] = Time_tbl[d['Time']].values

# --- 4x4 panels: 2.5-97.5 percentile band, median line, observed points ----
_, axes = plt.subplots(4, 4, figsize=figaspect(7/8)*1.5)
for person, ax in enumerate(axes.flat, start=1):
    row, col = divmod(person - 1, 4)
    est = d_est.query('PersonID==@person')
    ax.fill_between('Time', 'p2.5', 'p97.5', data=est, color='k', alpha=1/5)
    ax.plot('Time', 'p50', data=est, color='k')
    ax.scatter('Time', 'Y', data=d.query('PersonID==@person'), color='k')
    # Tick labels and axis labels only on the bottom row / leftmost column.
    if row == 3:
        ax.set_xlabel('Time (hour)')
    else:
        plt.setp(ax.get_xticklabels(), visible=False)
    if col == 0:
        ax.set_ylabel('Y')
    else:
        plt.setp(ax.get_yticklabels(), visible=False)
    plt.setp(ax, title=person, xticks=Time, xlim=(0, 24),
             yticks=np.arange(0, 40, 10), ylim=(-3, 37))
plt.tight_layout()
plt.show()

fig8-8.png