LoginSignup
3
2

More than 5 years have passed since last update.

StanとRでベイズ統計モデリング(アヒル本)をPythonにしてみる - 11.3 ゼロ過剰ポアソン分布

Last updated at Posted at 2018-08-20

実行環境

インポート

import numpy as np
import pandas as pd
import statsmodels.api as sm
import statsmodels.formula.api as smf
from scipy import stats
import pystan
import matplotlib.pyplot as plt
%matplotlib inline

データ読み込み

# Load the data set used throughout section 11.3.
ZIP = pd.read_csv('./data/data-ZIP.txt')
categorical = ['Sex', 'Sake']
# Replace the columns via [] so they really take the 'category' dtype.
# `ZIP.loc[:, cols] = ...` writes into the existing columns on pandas >= 2.0
# and therefore keeps their original dtype, which would silently disable
# every categorical branch of the plotting code below.
ZIP[categorical] = ZIP[categorical].astype('category')

11.3 ゼロ過剰ポアソン分布

# Ordinary least squares of Y on every column except the last one.
predictors = '+'.join(ZIP.columns[:-1])
ols_model = smf.ols('Y ~ {}'.format(predictors), data=ZIP)
fit = ols_model.fit()
fit.summary()

出力は省略

11.3.1 解析の目的とデータの分布の確認

前に使ったものとコードを共有しているので、書籍にあるグラフとは少し異なります。対角(diagonal)の部分でhue=Noneとでもすれば近づきます。

def pairplot(df, **kwargs):
    """Draw an n x n pair plot of the columns of ``df`` (n = number of columns).

    Grid layout:

    * diagonal -- one column on its own: a count bar chart when the column
      has the ``category`` dtype, otherwise a histogram / density plot;
    * upper triangle -- the correlation of the two columns, drawn as a
      coloured ellipse with the coefficient printed on top;
    * lower triangle -- a bubble cross-tabulation (category vs category),
      a box plot overlaid with a strip plot (exactly one category column),
      or a scatter plot (no category column).

    Parameters
    ----------
    df : pandas.DataFrame
        Data to plot; every column becomes one row and one column of the grid.
    **kwargs
        hue : str, optional
            Name of a column of ``df`` used to colour strip plots, scatter
            plots and the diagonal panels.  Its values are interpolated
            unquoted into ``DataFrame.query`` expressions, so they must be
            valid there as-is (e.g. numbers).
        figsize : optional
            Passed to ``plt.subplots``; defaults to
            ``figaspect(1) * 0.5 * n``.
        bar_kws : dict, optional
            Extra keyword arguments for ``sns.barplot`` on categorical
            diagonal panels (``alpha=0.4`` by default).
        dist_kws : dict, optional
            Keyword arguments for ``sns.distplot`` on non-categorical
            diagonal panels.  When ``hue`` is given, its ``'hist_kws'`` entry
            is also merged into the arguments of the stacked ``ax.hist``.
        corr_kws : dict, optional
            ``cmap`` (colormap or colormap name, default ``'coolwarm'``) and
            ``correlation`` (``'spearman'`` for Spearman; anything else means
            Pearson) for the upper triangle.
        cat_kws : dict, optional
            May contain ``'box_kws'`` and ``'strip_kws'`` dicts forwarded to
            ``sns.boxplot`` and ``sns.stripplot`` respectively.
        scatter_kws : dict, optional
            Keyword arguments for ``ax.scatter`` in numeric-vs-numeric panels.
        crosstab_kws : dict, optional
            May contain ``'scatter_kws'`` and ``'text_kws'`` dicts for the
            category-vs-category bubble panels.

    Returns
    -------
    numpy.ndarray
        The ``n x n`` array of axes created by ``plt.subplots``.
    """
    import numpy as np
    import pandas as pd
    import matplotlib.pyplot as plt
    from matplotlib.figure import figaspect
    import seaborn as sns

    def corrplot(x, y, data, ax, cmap=None, correlation='pearson', **kwargs):
        """Draw the correlation of ``data[x]`` and ``data[y]`` as an ellipse on ``ax``."""
        from scipy import stats
        from matplotlib.patches import Ellipse

        # Compare by value: ``is`` against a str literal only works when the
        # interpreter happens to intern both strings.
        if correlation == 'spearman':
            r = stats.spearmanr(data[x], data[y])[0]
        else:
            r = stats.pearsonr(data[x], data[y])[0]
        if cmap is None:
            cmap = 'coolwarm'
        if isinstance(cmap, str):
            cmap = plt.get_cmap(cmap)
        # Rescale r from [-1, 1] to the [0, 1] range a colormap expects.
        color = cmap((r+1)/2)
        ax.axis('off')
        ax.add_artist(Ellipse((0.5, 0.5), width=np.sqrt(1+r), height=np.sqrt(1-r), angle=45, color=color))
        ax.text(0.5, 0.5, '{:.2f}'.format(r), size='x-large', horizontalalignment='center', verticalalignment='center')

    def crosstabplot(x, y, data, ax, **kwargs):
        """Draw the contingency table of two columns as count-sized bubbles on ``ax``."""
        cross = pd.crosstab(data[x], data[y]).values
        # Marker area is proportional to the count; the largest cell gets 500.
        size = cross / cross.max() * 500
        crosstab_kws = kwargs.get('crosstab_kws', {})
        scatter_kws = dict(color=sns.color_palette()[0], alpha=0.3)
        scatter_kws.update(crosstab_kws.get('scatter_kws', {}))
        text_kws = dict(size='x-large')
        text_kws.update(crosstab_kws.get('text_kws', {}))
        for (xx, yy), count in np.ndenumerate(cross):
            ax.scatter(xx, yy, s=size[xx, yy], **scatter_kws)
            ax.text(xx, yy, count, horizontalalignment='center', verticalalignment='center', **text_kws)

    def catplot(x, y, hue, data, orient, ax, **kwargs):
        """Draw a white box plot with a strip plot on top of it on ``ax``."""
        box_kws = dict(color='w')
        box_kws.update(kwargs.get('box_kws', {}))
        # Keyword arguments: seaborn >= 0.12 no longer accepts x/y/hue/data
        # positionally, and older versions accept the keywords as well.
        sns.boxplot(x=x, y=y, data=data, orient=orient, ax=ax, **box_kws)
        strip_kws = dict(color=None if hue else sns.color_palette()[0])
        strip_kws.update(kwargs.get('strip_kws', {}))
        sns.stripplot(x=x, y=y, hue=hue, data=data, orient=orient, ax=ax, **strip_kws)
        legend = ax.get_legend()
        if legend:
            plt.setp(legend, visible=False)

    def scatterplot(x, y, hue, data, ax, **kwargs):
        """Scatter ``y`` against ``x`` on ``ax``, one colour per ``hue`` value."""
        if hue:
            hues = data[hue].unique()
            colors = sns.color_palette(n_colors=hues.size)
            for h, c in zip(hues, colors):
                ax.scatter(x, y, data=data.query('{col}=={value}'.format(col=hue, value=h)), color=c, **kwargs)
        else:
            ax.scatter(x, y, data=data, **kwargs)

    n_variables = df.columns.size
    hue = kwargs.get('hue')
    figsize = kwargs.get('figsize', figaspect(1) * 0.5 * n_variables)
    _, axes = plt.subplots(n_variables, n_variables, figsize=figsize)
    plt.subplots_adjust(hspace=0.1, wspace=0.1)

    for i in range(n_variables):
        # Column i: the diagonal panel and everything below it share one x axis.
        axes[i, i].get_shared_x_axes().join(*axes[i:n_variables, i])
        if i > 1:
            # NOTE(review): the slice stops at column i-2, so the lower-triangle
            # panel at column i-1 does not join the shared y axis -- confirm
            # whether ``:i`` was intended.
            axes[i, 0].get_shared_y_axes().join(*axes[i, :i-1])

    for (row, col), ax in np.ndenumerate(axes):
        x = df.columns[col]
        y = df.columns[row]
        x_data = df[x]
        y_data = df[y]
        x_dtype = x_data.dtype.name
        y_dtype = y_data.dtype.name
        if x_dtype == 'category':
            x_categories = x_data.cat.categories
        if y_dtype == 'category':
            y_categories = y_data.cat.categories

        if row == col: # diagonal
            hue_data = df[hue] if hue else None
            if x_dtype == 'category':
                bar_kws = dict(alpha=0.4)
                bar_kws.update(kwargs.get('bar_kws', {}))
                if hue:
                    # Counts per (category, hue value), reshaped to long form
                    # so that barplot can colour the bars by hue.
                    cross = pd.crosstab(x_data, hue_data)
                    cross.index = cross.index.categories
                    cross.columns = cross.columns.categories if hue_data.dtype.name == 'category' else hue_data.unique()
                    cross.reset_index(inplace=True)
                    melt = pd.melt(cross, id_vars='index',var_name='hue')
                    sns.barplot(x='index', y='value', hue='hue', data=melt, ci=None, orient='v', dodge=False, ax=ax, **bar_kws)
                    legend = ax.get_legend()
                    if legend is not None:
                        plt.setp(legend, visible=False)
                else:
                    cross = pd.crosstab(x_data, []).values.ravel()
                    sns.barplot(x=x_data.cat.categories, y=cross, ci=None, orient='v', color=sns.color_palette()[0], ax=ax, **bar_kws)
            else:
                dist_kws = kwargs.get('dist_kws', {})
                if hue:
                    hist_kws = dict(color=sns.color_palette(n_colors=hue_data.unique().size), alpha=0.4)
                    hist_kws.update(dist_kws.get('hist_kws', {}))
                    hue_values = df[hue].cat.categories if hue_data.dtype.name == 'category' else df[hue].unique()
                    # Stacked histogram with one layer per hue value, with the
                    # density curve of the whole column drawn over it.
                    ax.hist([df.query('{hue}=={v}'.format(hue=hue, v=v))[x] for v in hue_values], density=True, histtype='barstacked', **hist_kws)
                    sns.distplot(x_data, hist=False, ax=ax, **dist_kws)
                else:
                    sns.distplot(x_data, ax=ax, **dist_kws)
        elif row < col: # upper
            corr_kws = kwargs.get('corr_kws', {})
            corrplot(x, y, data=df, ax=ax, **corr_kws)
        else: # lower
            if x_dtype == 'category' and y_dtype == 'category':
                crosstabplot(x, y, data=df, ax=ax, crosstab_kws=kwargs.get('crosstab_kws', {}))
            else:
                cat_kws = kwargs.get('cat_kws', {})
                if x_dtype == 'category':
                    catplot(x, y, hue, df, 'v', ax, **cat_kws)
                elif y_dtype == 'category':
                    catplot(x, y, hue, df, 'h', ax, **cat_kws)
                else:
                    scatter_kws = kwargs.get('scatter_kws', {})
                    scatterplot(x, y, hue, df, ax, **scatter_kws)
        # Only the bottom row keeps x labels and only the first column keeps
        # y labels; every other panel has them hidden.
        if row < n_variables-1:
            plt.setp(ax, xlabel='')
            plt.setp(ax.get_xticklabels(), visible=False)
        else:
            plt.setp(ax, xlabel=x)
            if x_dtype == 'category':
                plt.setp(ax, xticks=np.arange(x_categories.size), xticklabels=x_data.cat.categories)
        if col > 0:
            plt.setp(ax, ylabel='')
            plt.setp(ax.get_yticklabels(), visible=False)
        else:
            plt.setp(ax, ylabel=y)
            if row > 0 and y_dtype == 'category':
                plt.setp(ax, yticks=np.arange(y_categories.size), yticklabels=y_data.cat.categories)

    return axes
# Pair plot of the whole data set, coloured by Y, with Spearman correlation
# in the upper triangle.
pairplot_options = {'hue': 'Y', 'corr_kws': {'correlation': 'spearman'}}
pairplot(ZIP, **pairplot_options)
plt.show()

fig11-4.png

11.3.4 Stanで実装

# Design matrix: a constant column plus every column of ZIP except the last.
predictor_columns = ZIP.iloc[:, :-1]
X = sm.add_constant(predictor_columns)
X['Age'] /= 10  # rescale Age by a factor of 10 before it goes into the model
data = {
    'N': ZIP.index.size,
    'D': X.columns.size,
    'Y': ZIP['Y'],
    'X': X.astype(float),
}
# Compile and sample the model, keeping only b, q and lambda.
fit = pystan.stan(
    './stan/model11-7.stan',
    data=data,
    pars=('b', 'q', 'lambda'),
    seed=123,
)

ms = fit.extract()
N_mcmc = len(ms['lp__'])
# Spearman correlation between lambda and q, computed separately for each
# entry along the first axis (presumably one MCMC draw per entry -- confirm).
r = []
for lambda_draw, q_draw in zip(ms['lambda'], ms['q']):
    rho = stats.spearmanr(lambda_draw, q_draw)[0]
    r.append(rho)
# Median, 50% and 95% intervals of those correlations.
np.percentile(r, (2.5, 25, 50, 75, 97.5))

array([-0.80376139, -0.6971176 , -0.64964589, -0.60332529, -0.47596719])

11.3.5 推定結果の解釈

# Print the text summary of the Stan fit.
stan_summary = fit.stansummary()
print(stan_summary)

出力は省略

3
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
3
2