LoginSignup
1
1

More than 5 years have passed since last update.

StanとRでベイズ統計モデリング(アヒル本)をPythonにしてみる - 5.2 二項ロジスティック回帰

Posted at 2018-08-17

実行環境

インポート

import numpy as np
import pandas as pd
import pystan
import matplotlib.pyplot as plt
from matplotlib.figure import figaspect
from matplotlib import colors
from matplotlib.markers import MarkerStyle
%matplotlib inline

データ読み込み

# Load the attendance data and mark column A as categorical; pairplot below
# chooses its panel type by checking for the 'category' dtype.
attendance2 = (
    pd.read_csv('./data/data-attendance-2.txt')
    .astype({'A': 'category'})
)

5.2 二項ロジスティック回帰

5.2.2 データの分布の確認

def pairplot(df, **kwargs):
    """Draw a pair plot that also handles categorical (dtype 'category') columns.

    The n x n grid (n = number of columns of ``df``) contains:

    * diagonal: a bar chart of counts for categorical columns, otherwise a
      histogram (stacked per hue if ``hue`` is given) plus ``sns.distplot``;
    * upper triangle: an ellipse whose shape and colour encode the correlation
      coefficient, with the value printed in its centre;
    * lower triangle: a count "bubble" chart for category x category, a
      box + strip plot for category x numeric, a scatter plot otherwise.

    Parameters
    ----------
    df : pandas.DataFrame
        Data to plot. Columns with dtype ``category`` are treated as categorical.
    **kwargs
        hue : str, optional
            Column used to colour points / bars.
        figsize : tuple, optional
            Figure size; defaults to ``figaspect(1) * 0.5 * n``.
        bar_kws, dist_kws, corr_kws, cat_kws, scatter_kws, crosstab_kws : dict, optional
            Extra options for the respective panel types. ``corr_kws`` accepts
            ``correlation`` ('pearson' or 'spearman') and ``cmap``;
            ``crosstab_kws`` accepts ``scatter_kws`` and ``text_kws``.

    Returns
    -------
    numpy.ndarray
        The n x n array of Axes.
    """
    import numpy as np
    import pandas as pd
    import matplotlib.pyplot as plt
    from matplotlib.figure import figaspect
    import seaborn as sns

    def share_axis(which, base, others):
        """Link the `which` ('x' or 'y') axis of every Axes in `others` to `base`."""
        for other in others:
            if other is base:
                continue
            if hasattr(other, 'share' + which):  # Axes.sharex/sharey (Matplotlib >= 3.3)
                getattr(other, 'share' + which)(base)
            else:  # older Matplotlib: join the shared-axes Grouper directly
                getattr(base, 'get_shared_{}_axes'.format(which))().join(base, other)

    def corrplot(x, y, data, ax, cmap=None, correlation='pearson', **kwargs):
        """Draw an ellipse coloured and shaped by the correlation of x and y on ax."""
        from scipy import stats
        from matplotlib.patches import Ellipse

        # Compare by value: `is` on strings tests object identity, not equality.
        if correlation == 'spearman':
            r = stats.spearmanr(data[x], data[y])[0]
        else:
            r = stats.pearsonr(data[x], data[y])[0]
        if cmap is None:
            cmap = 'coolwarm'
        if isinstance(cmap, str):
            cmap = plt.get_cmap(cmap)
        # Map r in [-1, 1] onto the colormap's [0, 1] range.
        color = cmap((r + 1) / 2)
        ax.axis('off')
        # Axes sqrt(1 +/- r): a circle for r = 0, collapsing to a diagonal line as |r| -> 1.
        ax.add_artist(Ellipse((0.5, 0.5), width=np.sqrt(1 + r), height=np.sqrt(1 - r), angle=45, color=color))
        ax.text(0.5, 0.5, '{:.2f}'.format(r), size='x-large', horizontalalignment='center', verticalalignment='center')

    def crosstabplot(x, y, data, ax, **kwargs):
        """Draw one bubble per (x, y) category pair, sized and labelled by its count."""
        cross = pd.crosstab(data[x], data[y]).values
        # The most frequent cell gets marker area 500.
        size = cross / cross.max() * 500
        crosstab_kws = kwargs.get('crosstab_kws', {})
        scatter_kws = dict(color=sns.color_palette()[0], alpha=0.3)
        scatter_kws.update(crosstab_kws.get('scatter_kws', {}))
        text_kws = dict(size='x-large')
        text_kws.update(crosstab_kws.get('text_kws', {}))
        for (xx, yy), count in np.ndenumerate(cross):
            ax.scatter(xx, yy, s=size[xx, yy], **scatter_kws)
            ax.text(xx, yy, count, horizontalalignment='center', verticalalignment='center', **text_kws)

    def catplot(x, y, hue, data, orient, ax, **kwargs):
        """Overlay a strip plot on a white box plot (orient 'v' or 'h')."""
        box_kws = dict(color='w')
        box_kws.update(kwargs.get('box_kws', {}))
        # Data variables by keyword: accepted by both old and current seaborn.
        sns.boxplot(x=x, y=y, data=data, orient=orient, ax=ax, **box_kws)
        strip_kws = dict(color=None if hue else sns.color_palette()[0])
        strip_kws.update(kwargs.get('strip_kws', {}))
        sns.stripplot(x=x, y=y, hue=hue, data=data, orient=orient, ax=ax, **strip_kws)
        legend = ax.get_legend()
        if legend is not None:
            plt.setp(legend, visible=False)

    def scatterplot(x, y, hue, data, ax, **kwargs):
        """Scatter y against x, one colour per hue value when hue is given."""
        if hue:
            hues = data[hue].unique()
            palette = sns.color_palette(n_colors=len(hues))
            for h, c in zip(hues, palette):
                # Boolean mask instead of DataFrame.query: works for any hue dtype (e.g. strings).
                ax.scatter(x, y, data=data[data[hue] == h], color=c, **kwargs)
        else:
            ax.scatter(x, y, data=data, **kwargs)

    n_variables = df.columns.size
    hue = kwargs.get('hue')
    figsize = kwargs.get('figsize', figaspect(1) * 0.5 * n_variables)
    _, axes = plt.subplots(n_variables, n_variables, figsize=figsize)
    plt.subplots_adjust(hspace=0.1, wspace=0.1)

    for i in range(n_variables):
        # Column i shows variable i on x from the diagonal downwards.
        share_axis('x', axes[i, i], axes[i + 1:, i])
        # The lower-triangle panels of row i (columns 0..i-1) all show variable i on y.
        share_axis('y', axes[i, 0], axes[i, 1:i])

    for (row, col), ax in np.ndenumerate(axes):
        x = df.columns[col]
        y = df.columns[row]
        x_data = df[x]
        y_data = df[y]
        x_is_cat = x_data.dtype.name == 'category'
        y_is_cat = y_data.dtype.name == 'category'

        if row == col:  # diagonal
            hue_data = df[hue] if hue else None
            if x_is_cat:
                bar_kws = dict(alpha=0.4)
                bar_kws.update(kwargs.get('bar_kws', {}))
                if hue:
                    cross = pd.crosstab(x_data, hue_data)
                    # Replace (Categorical)Index objects with plain labels in the
                    # crosstab's own order, so labels always line up with the counts.
                    cross.index = list(cross.index)
                    cross.columns = list(cross.columns)
                    cross.reset_index(inplace=True)
                    melt = pd.melt(cross, id_vars='index', var_name='hue')
                    sns.barplot(x='index', y='value', hue='hue', data=melt, ci=None, orient='v', dodge=False, ax=ax, **bar_kws)
                    plt.setp(ax.get_legend(), visible=False)
                else:
                    categories = x_data.cat.categories
                    counts = [(x_data == cat).sum() for cat in categories]
                    sns.barplot(x=categories, y=counts, ci=None, orient='v', color=sns.color_palette()[0], ax=ax, **bar_kws)
            else:
                dist_kws = kwargs.get('dist_kws', {})
                if hue:
                    hue_values = hue_data.cat.categories if hue_data.dtype.name == 'category' else hue_data.unique()
                    # One colour per plotted group, so colour and group counts always agree.
                    hist_kws = dict(color=sns.color_palette(n_colors=len(hue_values)), alpha=0.4)
                    hist_kws.update(dist_kws.get('hist_kws', {}))
                    groups = [x_data[hue_data == v] for v in hue_values]
                    ax.hist(groups, density=True, histtype='barstacked', **hist_kws)
                    # NOTE: sns.distplot is deprecated since seaborn 0.11.
                    sns.distplot(x_data, hist=False, ax=ax, **dist_kws)
                else:
                    sns.distplot(x_data, ax=ax, **dist_kws)
        elif row < col:  # upper triangle
            corrplot(x, y, data=df, ax=ax, **kwargs.get('corr_kws', {}))
        elif x_is_cat and y_is_cat:  # lower triangle, category x category
            crosstabplot(x, y, data=df, ax=ax, crosstab_kws=kwargs.get('crosstab_kws', {}))
        elif x_is_cat:
            catplot(x, y, hue, df, 'v', ax, **kwargs.get('cat_kws', {}))
        elif y_is_cat:
            catplot(x, y, hue, df, 'h', ax, **kwargs.get('cat_kws', {}))
        else:
            scatterplot(x, y, hue, df, ax, **kwargs.get('scatter_kws', {}))

        # Only the bottom row carries x labels and only the first column y labels.
        if row < n_variables - 1:
            plt.setp(ax, xlabel='')
            plt.setp(ax.get_xticklabels(), visible=False)
        else:
            plt.setp(ax, xlabel=x)
            if x_is_cat:
                plt.setp(ax, xticks=np.arange(x_data.cat.categories.size), xticklabels=x_data.cat.categories)
        if col > 0:
            plt.setp(ax, ylabel='')
            plt.setp(ax.get_yticklabels(), visible=False)
        else:
            plt.setp(ax, ylabel=y)
            if row > 0 and y_is_cat:
                plt.setp(ax, yticks=np.arange(y_data.cat.categories.size), yticklabels=y_data.cat.categories)

    return axes
# Observed ratio Y / M; appended as the last column so the 'A':'ratio'
# label slice below includes it.
attendance2['ratio'] = attendance2['Y'].div(attendance2['M'])
pairplot(
    attendance2.loc[:, 'A':'ratio'],
    hue='A',
    corr_kws={'correlation': 'spearman'},
)
plt.show()

fig5-6.png

5.2.5 Stanで実装

# Stan input data. Score is divided by 200 before being passed in
# (presumably to keep its coefficient on a convenient scale — confirm
# against model5-4.stan).
data = {
    'N': len(attendance2),
    'A': attendance2['A'],
    'Score': attendance2['Score'] / 200,
    'M': attendance2['M'],
    'Y': attendance2['Y'],
}
fit = pystan.stan('./stan/model5-4.stan', data=data, seed=1234)

5.2.6 推定結果の解釈

# Bare expression: in the notebook this displays the fitted model object
# (the article omits that output).
fit

出力は省略

5.2.7 図によるモデルのチェック

ms = fit.extract()

# Posterior-predictive 10/50/90 percentiles of y_pred for every observation.
# Assumes ms['y_pred'] is (draws, N); after .T the array is (N, len(prob)).
prob = [10, 50, 90]
quantiles = np.percentile(ms['y_pred'], prob, axis=0).T
# Add the quantiles as columns p10/p50/p90. DataFrame.assign keeps the
# original column dtypes (e.g. categorical A); np.hstack on the raw values
# would turn every column into object dtype.
d_qua = attendance2.assign(**{'p{}'.format(p): quantiles[:, k] for k, p in enumerate(prob)})

_, ax = plt.subplots(figsize=figaspect(1))
c = ['black', 'gray']
for v in sorted(int(a) for a in d_qua['A'].unique()):
    d_part = d_qua[d_qua['A'] == v]
    # Asymmetric error bars spanning the 10th to 90th percentile around the median.
    err_lower = d_part['p50'] - d_part['p10']
    err_upper = d_part['p90'] - d_part['p50']
    # linestyle='none': markers only, without connecting lines between points.
    ax.errorbar(d_part['Y'], d_part['p50'], yerr=[err_lower, err_upper], linestyle='none',
                color=c[v], marker=MarkerStyle.filled_markers[v], label='A={}'.format(v))

# Square axes starting at 0 with a y = x reference line (predicted median == observed).
_, xmax = ax.get_xlim()
_, ymax = ax.get_ylim()
lim = (0, max(xmax, ymax))
ax.plot(lim, lim, c='k', alpha=0.6)
ax.legend()
plt.setp(ax, xlim=lim, ylim=lim, xlabel='Observed', ylabel='Predicted')
plt.show()

fig5-8.png

1
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do by signing up
1
1