0
0

Delete article

Deleted articles cannot be recovered.

The draft of this article will also be deleted.

Are you sure you want to delete this article?

More than 5 years have passed since the last update.

StanとRでベイズ統計モデリング(アヒル本)をPythonにしてみる - 8.1 階層モデルの導入

Last updated at Posted at 2018-08-19

実行環境

インポート

import numpy as np
import pandas as pd
import pystan
import matplotlib.pyplot as plt
from matplotlib.figure import figaspect
from matplotlib.gridspec import GridSpec
from matplotlib.markers import MarkerStyle
import seaborn as sns
%matplotlib inline

データ読み込み

# Load the salary data used throughout this section (later cells read its
# X, Y and KID columns).
salary2 = pd.read_csv(filepath_or_buffer='./data/data-salary-2.txt')

8.1 階層モデルの導入

8.1.1 解析の目的とデータの分布の確認

# Figure 8-1. Left: all groups in one panel (one marker style per KID) with a
# single pooled regression line in gray. Right: one panel per KID showing that
# group's own regression line (dashed) over the same pooled line.
fig = plt.figure(figsize=figaspect(2/4))

gs1 = GridSpec(1, 1, figure=fig)
ax1 = plt.subplot(gs1[0, 0])
for kid in salary2['KID'].unique():
    ax1.scatter('X', 'Y', data=salary2.query('KID==@kid'), label=kid, marker=MarkerStyle.filled_markers[kid])
# Shared keyword set for the pooled (all-groups) regression line.
regkws = dict(x='X', y='Y', data=salary2, scatter=False, ci=None, color='gray')
sns.regplot(ax=ax1, **regkws)
ax1.legend(title='KID')

# Reuse the left panel's limits so every per-group panel is on the same scale.
xlim = ax1.get_xlim()
ylim = ax1.get_ylim()
gs2 = GridSpec(2, 2, figure=fig)
for row, col in np.ndindex(2, 2):
    ax = fig.add_subplot(gs2[row, col])
    # Column-major numbering: KID 1 and 2 in the left column, 3 and 4 in the right.
    kid = row + 1 + col * 2
    # x/y are passed by keyword: seaborn >= 0.12 rejects them positionally
    # (keywords work on older seaborn as well).
    sns.regplot(x='X', y='Y', data=salary2.query('KID==@kid'), ci=None, label=kid, ax=ax,
                color=sns.color_palette()[kid-1], marker=MarkerStyle.filled_markers[kid],
                line_kws={'linestyle': 'dashed'})
    plt.setp(ax, xlim=xlim, ylim=ylim)
    sns.regplot(ax=ax, **regkws)
    ax.legend(title='KID')
    # Axis labels only on the outer edges; inner tick labels are hidden.
    plt.setp(ax, xlabel='X' if row == 1 else '', ylabel='Y' if col == 0 else '')
    if row == 0:
        plt.setp(ax.get_xticklabels(), visible=False)
    if col == 1:
        plt.setp(ax.get_yticklabels(), visible=False)
# Fit the single-panel grid into the left half and the 2x2 grid into the right half.
gs1.tight_layout(fig, rect=[None, None, 0.5, None])
gs2.tight_layout(fig, rect=[0.5, None, None, None])
plt.show()

fig8-1.png

8.1.2 グループ差を考えない場合

# Model 8-1: a single regression for all rows; group membership (KID) is not
# passed to Stan, so group differences are ignored.
data = {
    'N': len(salary2),
    'X': salary2['X'],
    'Y': salary2['Y'],
}
fit1 = pystan.stan(file='./stan/model8-1.stan', data=data, seed=1234)

8.1.3 グループごとに切片と傾きを持つ場合

# Model 8-2: separate intercept and slope per group. Every column of salary2
# is passed to Stan, plus the row count N and the group count K (taken as the
# largest KID, i.e. KIDs are presumably numbered 1..K).
data = dict(salary2.items())
data.update(N=len(salary2), K=salary2['KID'].max())
fit2 = pystan.stan(file='./stan/model8-2.stan', data=data, seed=1234)

8.1.4 階層モデル

使用するデータを書籍と揃えるため、Rでデータを作成します（NumPyの乱数生成器ではRの set.seed(123) と同じ乱数系列が得られず、書籍と同じ値を再現できないためです）。

%load_ext rpy2.ipython
%%R
# Simulate the section 8.1.4 example data in R so the values match the book.
# With set.seed(123) the generated numbers depend on the exact order of the
# random draws (one sample() call, then three rnorm() calls), so keep these
# statements in their current order.
set.seed(123)
N <- 40                  # total number of observations
K <- 4                   # number of groups (KID)
N_k <- c(15, 12, 10, 3)  # observations per group; sums to N
a0 <- 350                # mean of the group intercepts
b0 <- 12                 # mean of the group slopes
s_a <- 60                # between-group SD of the intercepts
s_b <- 4                 # between-group SD of the slopes
s_Y <- 25                # residual SD of Y around each group's line
X <- sample(x=0:35, size=N, replace=TRUE)   # X drawn with replacement from 0..35
KID <- rep(1:4, times=N_k)                  # group 1 repeated 15 times, group 2 12 times, ...

# Group-level parameters: a[k] ~ Normal(a0, s_a), b[k] ~ Normal(b0, s_b).
a <- rnorm(K, mean=0, sd=s_a) + a0
b <- rnorm(K, mean=0, sd=s_b) + b0
# Attach each row's group intercept/slope, then draw Y_sim ~ Normal(a + b*X, s_Y).
d <- data.frame(X=X, KID=KID, a=a[KID], b=b[KID])
d <- transform(d, Y_sim=rnorm(N, mean=a + b*X, sd=s_Y))
# Pull the simulated R data frame `d` into a pandas DataFrame named `d`.
import rpy2.robjects as ro
# Import the pandas converter explicitly. `ro.pandas2ri` is a submodule
# attribute that only exists once something else (e.g. the rpy2.ipython
# extension) has imported rpy2.robjects.pandas2ri.
from rpy2.robjects import pandas2ri
d = pandas2ri.ri2py_dataframe(ro.r['d'])

データ作成ここまで

# Figure 8-2: the simulated data, one panel per KID, each with its own
# regression line (black, dashed). All panels share x and y axes.
_, axes = plt.subplots(2, 2, figsize=figaspect(2/2), sharex=True, sharey=True)
rows, cols = axes.shape
for (row, col), ax in np.ndenumerate(axes):
    # Row-major numbering: KID 1 and 2 on the top row, 3 and 4 below.
    kid = row * cols + col + 1
    # x/y are passed by keyword: seaborn >= 0.12 rejects them positionally.
    sns.regplot(x='X', y='Y_sim', data=d.query('KID==@kid'), ci=None, label=kid,
                marker=MarkerStyle.filled_markers[kid],
                line_kws={'color': 'k', 'alpha': 0.8, 'linestyle': 'dashed'}, ax=ax)
    ax.legend(title='KID')
    # Axis labels only on the bottom row / left column.
    plt.setp(ax, xlabel='X' if row == rows - 1 else '', ylabel='Y' if col == 0 else '')
plt.show()

fig8-2.png

# Model 8-3: the hierarchical model of section 8.1.4, fitted to salary2.
# K is the number of distinct KIDs (equal to max(KID) when KIDs are 1..K).
kid_col = salary2['KID']
data = dict(
    N=len(salary2),
    K=len(kid_col.unique()),
    X=salary2['X'],
    Y=salary2['Y'],
    KID=kid_col,
)
fit3 = pystan.stan(file='./stan/model8-3.stan', data=data, seed=1234)

8.1.5 モデルの比較

右側の各グラフに凡例をつけるとごちゃごちゃするため、右側の凡例は省略しました（本来はグラフの外にまとめて配置すればよいのでしょう）。

# Posterior draws of the three models, as dicts of arrays keyed by parameter name.
ms1 = fit1.extract()
ms2 = fit2.extract()
ms3 = fit3.extract()

# Number of groups, read from the draws (ms['a'] is indexed draw x group)
# instead of being hard-coded.
K = ms2['a'].shape[1]
probs = (2.5, 25, 50, 75, 97.5)


def _intercept_quantiles(ms, offset, model):
    """Return percentile summaries of the group intercepts a[k] of one model.

    One row per group, with columns 'p2.5', 'p25', 'p50', 'p75', 'p97.5',
    KID (the 1-based group number shifted by ``offset`` so that two models'
    intervals can be drawn side by side) and Model (the label ``model``).
    """
    a = ms['a']
    qua = pd.DataFrame(np.percentile(a, probs, axis=0).T, columns=['p{}'.format(p) for p in probs])
    qua['KID'] = np.arange(a.shape[1]) + 1 + offset
    qua['Model'] = model
    return qua


# Model 8-2 slightly left of each KID, model 8-3 slightly right.
d_qua1 = _intercept_quantiles(ms2, -0.1, '8-2')
d_qua2 = _intercept_quantiles(ms3, 0.1, '8-3')
d_qua = pd.concat([d_qua1, d_qua2], axis=0, ignore_index=True)

# Figure 8-4, left half: for each group, the 95% interval (2.5%-97.5%) of the
# intercept a[k] as a vertical line and its median as a marker, for models 8-2
# and 8-3. Model 8-1's single intercept is a faint horizontal reference line.
fig = plt.figure(figsize=figaspect(2/4))
gs1 = GridSpec(1, 1, figure=fig)
ax1 = plt.subplot(gs1[0, 0])

# First model: dashed interval with hollow marker; second: solid with filled marker.
model_styles = zip(d_qua['Model'].unique(), ['dashed', 'solid'], ['w', 'k'])
for model, linestyle, color in model_styles:
    d_sub = d_qua.query('Model==@model')
    ax1.vlines(x='KID', ymin='p2.5', ymax='p97.5', data=d_sub,
               linestyles=linestyle, label=model)
    # label=None keeps the median markers out of the legend.
    ax1.scatter(x='KID', y='p50', data=d_sub, c=color, edgecolors='k', label=None)

# Span the reference line over the current x range, then pin that range.
x_range = ax1.get_xlim()
ax1.hlines(np.median(ms1['a']), *x_range, alpha=0.3)
ax1.legend(title='Model', loc='upper left')
ax1.set(xlabel='KID', ylabel='a', xlim=x_range)

# Posterior-median regression lines of models 8-1, 8-2 and 8-3 on integer X
# values, as long-format frames with columns X, Y, KID, Model.
dx = salary2.groupby('KID').agg({'X': ['min', 'max']})
dx.columns = ['Xmin', 'Xmax']
# Number of groups, taken from the data rather than hard-coded (4 here).
K = dx.index.size

# Model 8-1 has one line for all groups: evaluate it over the overall X range
# and repeat it once per KID so that every panel can draw it.
X_new = np.arange(dx.values.min(), dx.values.max()+1)
N_X = X_new.size
y_base_mcmc1 = ms1['a'].reshape((-1, 1)) + np.outer(ms1['b'], X_new)
y_base_med1 = np.median(y_base_mcmc1, axis=0)
d1 = pd.DataFrame(dict(
    X=np.tile(X_new, K),
    Y=np.tile(y_base_med1, K),
    KID=np.repeat(dx.index.values, N_X),
    Model='8-1'
))


def _group_median_lines(ms, model):
    """Return the posterior-median line of every group of a per-group model.

    ``ms['a']`` and ``ms['b']`` are indexed (draw, group); column ``i`` is
    taken to belong to the i-th KID of ``dx`` (presumably KIDs are 1..K in
    the same order as the Stan model's groups). Each group's line spans only
    the X range observed in that group.
    """
    frames = []
    for i, kid in enumerate(dx.index):
        x_k = np.arange(dx.loc[kid, 'Xmin'], dx.loc[kid, 'Xmax']+1)
        y_mcmc = ms['a'][:, i].reshape((-1, 1)) + np.outer(ms['b'][:, i], x_k)
        frames.append(pd.DataFrame({'X': x_k, 'Y': np.median(y_mcmc, axis=0), 'KID': kid, 'Model': model}))
    # A single concat: DataFrame.append was removed in pandas 2.0, and
    # appending inside the loop copied the accumulated frame every iteration.
    return pd.concat(frames)


d2 = _group_median_lines(ms2, '8-2')
d3 = _group_median_lines(ms3, '8-3')

# Figure 8-4, right half: one panel per KID with that group's data and the
# posterior-median lines of model 8-1 (faint), 8-2 (dashed) and 8-3 (solid).
# No legend is drawn here, to keep the small panels uncluttered.
rows, cols = 2, 2
gs2 = GridSpec(rows, cols)
# (frame, alpha, linestyle, label) of each model's line, in drawing order.
line_specs = (
    (d1, 0.2, 'solid', '8-1'),
    (d2, 0.6, 'dashed', '8-2'),
    (d3, 1, 'solid', '8-3'),
)
prev_ax = None
for row, col in np.ndindex((rows, cols)):
    # Column-major numbering: KID 1 and 2 in the left column, 3 and 4 in the right.
    kid = row + 1 + col * rows
    # Each panel after the first shares x/y with the panel created just before
    # it, so all four end up on common axes.
    ax = fig.add_subplot(gs2[row, col], sharex=prev_ax, sharey=prev_ax)
    prev_ax = ax
    q = 'KID==@kid'
    ax.scatter('X', 'Y', data=salary2.query(q), color='k', marker=MarkerStyle.filled_markers[kid], alpha=0.3, label=kid)
    for frame, line_alpha, line_style, model_label in line_specs:
        ax.plot('X', 'Y', data=frame.query(q), color='k', alpha=line_alpha, linestyle=line_style, label=model_label)
    on_bottom_row = row == rows - 1
    on_left_col = col == 0
    plt.setp(ax, title=kid, xlabel='X' if on_bottom_row else '', ylabel='Y' if on_left_col else '')
    # Inner panels share their neighbours' scales, so hide those tick labels.
    if not on_bottom_row:
        plt.setp(ax.get_xticklabels(), visible=False)
    if not on_left_col:
        plt.setp(ax.get_yticklabels(), visible=False)

# Fit each grid into its half of the figure, then give both grids the same
# top/bottom so the left panel lines up with the right-hand panels.
gs1.tight_layout(fig, rect=[None, None, 0.5, None])
gs2.tight_layout(fig, rect=[0.5, None, None, None])
top = min(gs1.top, gs2.top)
bottom = max(gs1.bottom, gs2.bottom)
for grid in (gs1, gs2):
    grid.update(top=top, bottom=bottom)
plt.show()

fig8-4.png

8.1.6 階層モデルの等価な表現

# Model 8-4: the equivalent formulation of the hierarchical model (8.1.6),
# given the same data as model 8-2: every salary2 column plus N and K
# (K taken as the largest KID).
data = dict(salary2.items())
data.update(N=len(salary2), K=salary2['KID'].max())
fit = pystan.stan(file='./stan/model8-4.stan', data=data, seed=1234)
0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do by signing up
0
0

Delete article

Deleted articles cannot be recovered.

The draft of this article will also be deleted.

Are you sure you want to delete this article?