Python
Stan
PyStan

StanとRでベイズ統計モデリング(アヒル本)をPythonにしてみる - 11.3 ゼロ過剰ポアソン分布

実行環境

インポート

import numpy as np
import pandas as pd
import statsmodels.api as sm
import statsmodels.formula.api as smf
from scipy import stats
import pystan
import matplotlib.pyplot as plt
%matplotlib inline

データ読み込み

# Load the data set used in this section.
ZIP = pd.read_csv('./data/data-ZIP.txt')
categorical = ['Sex', 'Sake']
# Convert with astype() rather than assigning through `.loc[:, cols]`:
# on recent pandas a `.loc` assignment writes the values into the existing
# columns and can leave their original dtype in place, whereas pairplot()
# below relies on these columns really having the 'category' dtype.
ZIP = ZIP.astype({col: 'category' for col in categorical})

11.3 ゼロ過剰ポアソン分布

# Ordinary least squares of Y on every other column (all columns except the
# last one are joined with '+' to form the right-hand side of the formula).
fit = smf.ols('Y ~ ' + '+'.join(ZIP.columns[:-1]), data=ZIP).fit()
# Bare expression: the notebook displays the returned summary.
fit.summary()

出力は省略

11.3.1 解析の目的とデータの分布の確認

前に使ったものとコードを共有しているので、書籍にあるグラフとは少し異なります。対角(diagonal)の部分でhue=Noneとでもすれば近づきます。

def pairplot(df, **kwargs):
    """Draw a mixed-type pair plot of every column of ``df`` against every other.

    The figure is an ``n x n`` grid of axes with ``n = df.columns.size``:

    * diagonal -- a bar chart of counts when the column dtype is
      ``category``, otherwise a histogram/density plot (stacked per hue
      value when ``hue`` is given);
    * upper triangle -- an ellipse whose shape and colour are derived from
      the correlation coefficient, with the coefficient printed on top;
    * lower triangle -- a bubble cross-tabulation (both columns are
      ``category``), a box plot with an overlaid strip plot (exactly one
      column is ``category``), or a scatter plot (neither is).

    Parameters
    ----------
    df : pandas.DataFrame
        Data to plot. A column is treated as categorical only when its
        dtype name is ``'category'``.
    **kwargs
        Optional settings, all looked up by key:

        ``hue``
            Name of the column used for colouring.
        ``figsize``
            Figure size; defaults to ``figaspect(1) * 0.5 * n``.
        ``bar_kws``
            Extra keyword arguments for the diagonal bar plots.
        ``dist_kws``
            Extra keyword arguments for ``sns.distplot``; its ``hist_kws``
            entry is also applied to the stacked histogram drawn when
            ``hue`` is given.
        ``corr_kws``
            Options for the upper triangle: ``correlation``
            (``'spearman'``, anything else means Pearson) and ``cmap``.
        ``cat_kws``
            Dict that may hold ``box_kws`` and ``strip_kws`` for the
            box/strip plots.
        ``scatter_kws``
            Extra keyword arguments for ``Axes.scatter``.

    Returns
    -------
    numpy.ndarray
        The 2-D array of axes created by ``plt.subplots``.
    """
    # Imported locally so the function can be pasted into another notebook
    # without relying on that notebook's global imports.
    import numpy as np
    import pandas as pd
    import matplotlib.pyplot as plt
    from matplotlib.figure import figaspect
    from matplotlib.patches import Ellipse
    from scipy import stats
    import seaborn as sns

    def corrplot(x, y, data, ax, cmap=None, correlation='pearson', **kwargs):
        """Draw the correlation of columns ``x`` and ``y`` as a coloured ellipse on ``ax``."""
        # Compare by value: `is` on a string literal only works by accident
        # of interning and raises a SyntaxWarning on Python 3.8+.
        if correlation == 'spearman':
            r = stats.spearmanr(data[x], data[y])[0]
        else:
            r = stats.pearsonr(data[x], data[y])[0]
        if cmap is None:
            cmap = 'coolwarm'
        if isinstance(cmap, str):
            cmap = plt.get_cmap(cmap)
        # Map r from [-1, 1] onto the colormap's [0, 1] range.
        color = cmap((r+1)/2)
        ax.axis('off')
        ax.add_artist(Ellipse((0.5, 0.5), width=np.sqrt(1+r), height=np.sqrt(1-r), angle=45, color=color))
        ax.text(0.5, 0.5, '{:.2f}'.format(r), size='x-large', horizontalalignment='center', verticalalignment='center')

    def crosstabplot(x, y, data, ax, **kwargs):
        """Draw the cross-tabulation of two columns as sized bubbles labelled with counts."""
        cross = pd.crosstab(data[x], data[y]).values
        # Bubble area is proportional to the count; the largest cell gets 500.
        size = cross / cross.max() * 500
        crosstab_kws = kwargs.get('crosstab_kws', {})
        scatter_kws = dict(color=sns.color_palette()[0], alpha=0.3)
        scatter_kws.update(crosstab_kws.get('scatter_kws', {}))
        text_kws = dict(size='x-large')
        text_kws.update(crosstab_kws.get('text_kws', {}))
        for (xx, yy), count in np.ndenumerate(cross):
            ax.scatter(xx, yy, s=size[xx, yy], **scatter_kws)
            ax.text(xx, yy, count, horizontalalignment='center', verticalalignment='center', **text_kws)

    def catplot(x, y, hue, data, orient, ax, **kwargs):
        """Draw a box plot with an overlaid strip plot and hide any legend."""
        box_kws = dict(color='w')
        box_kws.update(kwargs.get('box_kws', {}))
        # Data arguments are passed by keyword; newer seaborn releases no
        # longer accept them positionally.
        sns.boxplot(x=x, y=y, data=data, orient=orient, ax=ax, **box_kws)
        strip_kws = dict(color=None if hue else sns.color_palette()[0])
        strip_kws.update(kwargs.get('strip_kws', {}))
        sns.stripplot(x=x, y=y, hue=hue, data=data, orient=orient, ax=ax, **strip_kws)
        legend = ax.get_legend()
        if legend:
            plt.setp(legend, visible=False)

    def scatterplot(x, y, hue, data, ax, **kwargs):
        """Scatter ``x`` against ``y``, one colour per distinct hue value when ``hue`` is set."""
        if hue:
            hues = data[hue].unique()
            colors = sns.color_palette(n_colors=hues.size)
            for h, c in zip(hues, colors):
                # The hue value is interpolated into the query unquoted.
                # NOTE(review): presumably hue values are numeric here -- confirm.
                ax.scatter(x, y, data=data.query('{col}=={value}'.format(col=hue, value=h)), color=c, **kwargs)
        else:
            ax.scatter(x, y, data=data, **kwargs)

    n_variables = df.columns.size
    hue = kwargs.get('hue')
    figsize = kwargs['figsize'] if 'figsize' in kwargs else figaspect(1) * 0.5 * n_variables
    # squeeze=False keeps `axes` two-dimensional even for a single column.
    _, axes = plt.subplots(n_variables, n_variables, figsize=figsize, squeeze=False)
    plt.subplots_adjust(hspace=0.1, wspace=0.1)

    for i in range(n_variables):
        # Share the x axis down column i, from the diagonal cell to the bottom row.
        axes[i, i].get_shared_x_axes().join(*axes[i:n_variables, i])
        if i > 1:
            # NOTE(review): this slice stops at column i-2, so the cell just
            # left of the diagonal does not share its y axis with the rest
            # of the row; `:i` may have been intended -- confirm before changing.
            axes[i, 0].get_shared_y_axes().join(*axes[i, :i-1])

    for (row, col), ax in np.ndenumerate(axes):
        x = df.columns[col]
        y = df.columns[row]
        x_data = df[x]
        y_data = df[y]
        x_dtype = x_data.dtype.name
        y_dtype = y_data.dtype.name
        if x_dtype == 'category':
            x_categories = x_data.cat.categories
        if y_dtype == 'category':
            y_categories = y_data.cat.categories

        if row == col: # diagonal
            hue_data = df[hue] if hue else None
            if x_dtype == 'category':
                bar_kws = dict(alpha=0.4)
                bar_kws.update(kwargs.get('bar_kws', {}))
                if hue:
                    # Counts per (category, hue value), reshaped to long
                    # form so barplot can colour the bars by hue.
                    cross = pd.crosstab(x_data, hue_data)
                    cross.index = cross.index.categories
                    cross.columns = cross.columns.categories if hue_data.dtype.name == 'category' else hue_data.unique()
                    cross.reset_index(inplace=True)
                    melt = pd.melt(cross, id_vars='index',var_name='hue')
                    sns.barplot(x='index', y='value', hue='hue', data=melt, ci=None, orient='v', dodge=False, ax=ax, **bar_kws)
                    plt.setp(ax.get_legend(), visible=False)
                else:
                    cross = pd.crosstab(x_data, []).values.ravel()
                    sns.barplot(x=x_data.cat.categories, y=cross, ci=None, orient='v', color=sns.color_palette()[0], ax=ax, **bar_kws)
            else:
                dist_kws = kwargs.get('dist_kws', {})
                if hue:
                    hist_kws = dict(color=sns.color_palette(n_colors=hue_data.unique().size), alpha=0.4)
                    hist_kws.update(dist_kws.get('hist_kws', {}))
                    hue_values = df[hue].cat.categories if hue_data.dtype.name == 'category' else df[hue].unique()
                    # One stacked histogram layer per hue value, with the
                    # overall density curve drawn on top.
                    ax.hist([df.query('{hue}=={v}'.format(hue=hue, v=v))[x] for v in hue_values], density=True, histtype='barstacked', **hist_kws)
                    sns.distplot(x_data, hist=False, ax=ax, **dist_kws)
                else:
                    sns.distplot(x_data, ax=ax, **dist_kws)
        elif row < col: # upper
            corr_kws = kwargs.get('corr_kws', {})
            corrplot(x, y, data=df, ax=ax, **corr_kws)
        else: # lower
            if x_dtype == 'category' and y_dtype == 'category':
                crosstabplot(x, y, data=df, ax=ax)
            else:
                cat_kws = kwargs.get('cat_kws', {})
                if x_dtype == 'category':
                    catplot(x, y, hue, df, 'v', ax, **cat_kws)
                elif y_dtype == 'category':
                    catplot(x, y, hue, df, 'h', ax, **cat_kws)
                else:
                    scatter_kws = kwargs.get('scatter_kws', {})
                    scatterplot(x, y, hue, df, ax, **scatter_kws)
        # Only the bottom row keeps x labels/ticks and only the first column
        # keeps y labels/ticks.
        if row < n_variables-1:
            plt.setp(ax, xlabel='')
            plt.setp(ax.get_xticklabels(), visible=False)
        else:
            plt.setp(ax, xlabel=x)
            if x_dtype == 'category':
                plt.setp(ax, xticks=np.arange(x_categories.size), xticklabels=x_data.cat.categories)
        if col > 0:
            plt.setp(ax, ylabel='')
            plt.setp(ax.get_yticklabels(), visible=False)
        else:
            plt.setp(ax, ylabel=y)
            if row > 0 and y_dtype == 'category':
                plt.setp(ax, yticks=np.arange(y_categories.size), yticklabels=y_data.cat.categories)

    return axes
# Pair plot of the whole data set, coloured by Y, with Spearman rank
# correlation shown in the upper triangle.
pairplot(ZIP, corr_kws={'correlation': 'spearman'}, hue='Y')
plt.show()

fig11-4.png

11.3.4 Stanで実装

# Design matrix: every column except the last one, plus a constant column.
X = sm.add_constant(ZIP.iloc[:, :-1])
# Rescale Age by a factor of 10 before handing it to Stan.
X['Age'] = X['Age'] / 10
# Data passed to the Stan model; X is cast to float.
data = {
    'N': ZIP.index.size,
    'D': X.columns.size,
    'Y': ZIP['Y'],
    'X': X.astype(float),
}
fit = pystan.stan('./stan/model11-7.stan', data=data, pars=('b', 'q', 'lambda'), seed=123)

# Extract the posterior draws from the fit.
ms = fit.extract()
N_mcmc = ms['lp__'].shape[0]
# Spearman rank correlation between lambda and q, computed once per draw.
r = [stats.spearmanr(lam, q)[0] for lam, q in zip(ms['lambda'], ms['q'])]
# Percentiles of the per-draw correlations (displayed by the notebook).
np.percentile(r, [2.5, 25, 50, 75, 97.5])

array([-0.80376139, -0.6971176 , -0.64964589, -0.60332529, -0.47596719])

11.3.5 推定結果の解釈

# Print the summary string produced by the fit's stansummary() method.
print(fit.stansummary())

出力は省略