Python
Stan
PyStan

StanとRでベイズ統計モデリング(アヒル本)をPythonにしてみる - 12.7 2次元の空間構造

実行環境

インポート

import numpy as np
import pandas as pd
import pystan
import matplotlib.pyplot as plt
from matplotlib.figure import figaspect
from mpl_toolkits.mplot3d import Axes3D
%matplotlib inline

データ読み込み

mesh2D = pd.read_csv('./data/data-2Dmesh.txt', header=None)
mesh2D.index += 1
mesh2D.columns += 1
mesh2D_design = pd.read_csv('./data/data-2Dmesh-design.txt', header=None)

12.7 2次元の空間構造

12.7.1 解析の目的とデータの分布の確認

plt.matshow(mesh2D)
plt.colorbar(fraction=0.025, pad=0.05)
plt.setp(plt.gca(), xticks=(5,10,15,20), yticks=(5,10,15), xlabel='Plate Column', ylabel='Plate Row')
plt.show()

fig12-8.png

12.7.4 Stanで実装

d_melt = pd.melt(mesh2D.reset_index(), id_vars='index')
d_melt.columns = ('i', 'j', 'Y')

statsmodelsのlowessが1次元にしか対応していなかったので、Rを使用しました。平滑化するだけなら、他の手法を使ってもいいかと思います。

%load_ext rpy2.ipython
%%R
install.packages('reshape2')

d <- as.matrix(read.csv('./data/data-2Dmesh.txt', header=F))
I <- nrow(d)
J <- ncol(d)
rownames <- 1:I
colnames <- 1:J
d_melt <- reshape2::melt(d)
colnames(d_melt) <- c('i', 'j', 'Y')

d_melt$j = as.numeric(d_melt$j)

loess_res <- loess(Y ~ i + j, data=d_melt, span=0.1)
smoothed <- matrix(loess_res$fitted, nrow=I, ncol=J)
import rpy2.robjects as ro
smoothed = ro.r('smoothed')

Rの使用ここまで。

T = mesh2D_design.values.max()
data = dict(
    I=mesh2D.index.size,
    J=mesh2D.columns.size,
    Y=mesh2D,
    T=T,
    TID=mesh2D_design
)

def init():
    return dict(
        r=smoothed,
        s_r=1,
        beta=np.random.normal(0, 0.1, T),
        s_beta=1,
        s_Y=1
    )

stanmodel = pystan.StanModel('./stan/model12-13.stan')
fit = stanmodel.sampling(data=data, iter=5000, thin=5, seed=1234, init=init)

stanmodel_b = pystan.StanModel('./stan/model12-13b.stan')
fit_b = stanmodel_b.sampling(data=data, iter=5000, thin=5, seed=1234, init=init)

12.7.5 推定結果の解釈

Jupyter notebookでmatplotlibをインタラクティブに3D表示するためのおまじないです。

%matplotlib notebook
%matplotlib notebook

インタラクティブな3D表示ならPlotlyを使ったほうがヌルヌル動く(matplotlibが角度を変えたりするたびにバックエンドで描画して画像を送るのに対して、PlotlyはフロントでJSで描画しなおす)のですが、公式ドキュメントが使いづらく(重い・引数が入れ子になりすぎてて目当てのものを見つけたときに遡りづらい)、適当に'a'などの存在しない引数名を与えてエラーメッセージに出てくる引数候補から見当をつけるというバッドノウハウを蓄積した結果、スペックの高いPCに買い替えてmatplotlibで頑張ったほうが生産性が高いという結論に達しました……

ms = fit.extract()

r_median = np.median(ms['r'], axis=0)
I, J = r_median.shape
ii, jj = np.mgrid[:I, :J]

ax = Axes3D(plt.gcf())
ax.plot_wireframe(ii, jj, r_median, color='k')
ax.plot_surface(ii, jj, r_median, color='k', alpha=0.2)
plt.setp(ax, xlabel='Plate Row', ylabel='Plate Column', zlabel='r')
plt.show()

fig12-9-left.png

通常の描画に戻します。

%matplotlib inline
d = mesh2D.values
TID = mesh2D_design.values
mean_Y = [d[TID == t+1].mean() - d.mean() for t in range(T)]

d_est = pd.DataFrame(np.percentile(ms['beta'], (2.5, 50, 97.5), axis=0).T, columns=['p{}'.format(p) for p in (2.5, 50, 97.5)])
d_est['x'] = mean_Y

plt.figure(figsize=figaspect(1))
ax = plt.axes()
err = pd.DataFrame(np.abs(d_est.loc[:, ['p2.5', 'p97.5']].values - d_est['p50'].values.reshape((-1, 1))), columns=('lower', 'upper'))
ax.errorbar('x', 'p50', yerr=(err['lower'], err['upper']), fmt='o', data=d_est, color='k')
lim = (-5, 5)
ax.plot(lim, lim, color='k', linestyle='dashed')
plt.setp(ax, xlim=lim, ylim=lim, xlabel='Mean of Y[TID]', ylabel='beta[t]')

plt.show()

fig12-9-right.png