
Porting StanとRでベイズ統計モデリング (Bayesian Statistical Modeling with Stan and R, a.k.a. the "Duck Book") to Python - 5.3 Logistic Regression

Posted at 2018-08-17

Execution environment

Imports

import pandas as pd
import numpy as np
from scipy.special import expit
import pystan
from sklearn.metrics import roc_curve
import matplotlib.pyplot as plt
from matplotlib.figure import figaspect
from matplotlib import colors
import seaborn as sns
%matplotlib inline

Loading the data

attendance3 = pd.read_csv('./data/data-attendance-3.txt')

5.3 Logistic Regression

5.3.2 Checking the distribution of the data

pd.crosstab(attendance3['Weather'], attendance3['Y'])
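Raw counts are hard to compare across weather categories of different sizes. As an aside that is not part of the original notebook, `pd.crosstab` can also return row proportions via `normalize='index'`. This gives the share of `Y=1` within each weather category. The toy frame below is hypothetical and only mimics the `Weather` and `Y` columns of `attendance3`:

```python
import pandas as pd

# Hypothetical toy data mimicking the Weather (A/B/C) and Y (0/1) columns
toy = pd.DataFrame({
    'Weather': ['A', 'A', 'A', 'A', 'B', 'B', 'C', 'C'],
    'Y':       [1,   1,   1,   0,   1,   0,   0,   0],
})

# Plain counts, as in the cell above
counts = pd.crosstab(toy['Weather'], toy['Y'])

# Row-normalized table: each row sums to 1, so column 1 is the share of Y=1 per weather
rates = pd.crosstab(toy['Weather'], toy['Y'], normalize='index')
print(rates)
```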

5.3.6 Implementation in Stan

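# Data for Stan: Score is rescaled to 0-1, and Weather A/B/C is mapped to 0/0.2/1
data = dict(
    I=attendance3.index.size,
    A=attendance3['A'],
    Score=attendance3['Score']/200,
    W=attendance3['Weather'].map(dict(A=0, B=0.2, C=1)),
    Y=attendance3['Y']
)
# PyStan 2.x interface: compiles the model file and draws MCMC samples
fit = pystan.stan('./stan/model5-5.stan', data=data, seed=1234)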
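The Stan file `model5-5.stan` is not reproduced in this post. The data passed to it, and the way later cells use `ms['b']` and `ms['q']`, are consistent with the following logistic regression, which is assumed here: q[i] = inv_logit(b[1] + b[2]*A[i] + b[3]*Score[i] + b[4]*W[i]) with Y[i] ~ bernoulli(q[i]). Under that assumption, the NumPy sketch below evaluates the same log-likelihood for a single coefficient vector. The helper name `log_lik` and all input values are hypothetical:

```python
import numpy as np

def log_lik(b, A, Score, W, Y):
    """Bernoulli log-likelihood of the *assumed* model5-5 (hypothetical helper).

    b holds 4 coefficients (intercept, A, Score, W), indexed from 0 here.
    """
    # Linear predictor on the logit scale; q = inv_logit(eta)
    eta = b[0] + b[1] * A + b[2] * Score + b[3] * W
    # log q = -log(1 + exp(-eta)) and log(1 - q) = -log(1 + exp(eta)), computed stably
    return np.sum(-Y * np.logaddexp(0, -eta) - (1 - Y) * np.logaddexp(0, eta))

# Hypothetical inputs, scaled the same way as the data dict above
A = np.array([0, 1, 0, 1])
Score = np.array([60, 120, 180, 90]) / 200
W = np.array([0.0, 0.2, 1.0, 0.0])   # Weather A/B/C mapped to 0/0.2/1
Y = np.array([1, 1, 0, 0])

# With all coefficients zero, q = 0.5 for every row, so the result is 4 * log(0.5)
print(log_lik(np.zeros(4), A, Score, W, Y))
```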

5.3.7 Checking the model with figures

ms = fit.extract()

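# Score grid; with A=0 and Weather A (W=0) only the intercept ms['b'][:, 0] and the Score coefficient ms['b'][:, 2] remain
X = np.arange(30, 201)
# Linear predictor per draw and grid point, shape (draws, len(X))
q_mcmc = ms['b'][:, 0][:, np.newaxis] + ms['b'][:, 2][:, np.newaxis]*X[np.newaxis, :]/200
# 10/50/90% quantiles taken on the logit scale, then mapped to probabilities with expit
q_qua = expit(np.percentile(q_mcmc, (10, 50, 90), axis=0)).T
d_est = pd.DataFrame(np.hstack((X.reshape((-1, 1)), q_qua)), columns=('X', 'p10', 'p50', 'p90'))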

plt.figure(figsize=figaspect(3/4.5))
ax = plt.axes()
ax.plot(X, d_est['p50'], c='k')
ax.fill_between(X, d_est['p10'], d_est['p90'], color='k', alpha=2/6)
sns.stripplot(x='Score', y='Y', data=attendance3.query('A==0 and Weather=="A"'), jitter=True, orient='h', color='k', ax=ax)
plt.setp(ax, xlabel='Score', ylabel='q', ylim=ax.get_ylim()[::-1])
plt.show()

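[Figure 5-9: Score (x) vs. q (y), posterior median and 10-90% band for A=0 and Weather A, with observed Y overlaid]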
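In the cell for Figure 5-9, `ms['b'][:, 0][:, np.newaxis]` has shape (draws, 1) and `X[np.newaxis, :]` has shape (1, 171). Broadcasting therefore produces one row of linear-predictor values per draw. The percentiles are taken on the logit scale and only then passed through `expit`. Because `expit` is strictly increasing, it maps order statistics to order statistics. With NumPy's default linear interpolation, the two orders (percentile then `expit`, or `expit` then percentile) agree exactly when the requested percentile falls on a sample point, and only approximately otherwise.

The sketch below checks this. It uses hypothetical normal draws as stand-ins for `ms['b']`. The value of 1001 draws is chosen so that the 10th, 50th and 90th percentiles land on sample points:

```python
import numpy as np
from scipy.special import expit

rng = np.random.default_rng(0)
n_draws = 1001                               # hypothetical number of MCMC draws
b0 = rng.normal(-1.0, 0.3, size=n_draws)     # hypothetical stand-in for ms['b'][:, 0]
b2 = rng.normal(2.0, 0.5, size=n_draws)      # hypothetical stand-in for ms['b'][:, 2]
X = np.arange(30, 201)

# (n_draws, 1) + (n_draws, 1) * (1, 171) broadcasts to (n_draws, 171)
lin = b0[:, np.newaxis] + b2[:, np.newaxis] * X[np.newaxis, :] / 200

# Percentiles on the logit scale, then expit (the order used in the post) ...
p_after = expit(np.percentile(lin, (10, 50, 90), axis=0))
# ... versus expit first, then percentiles
p_before = np.percentile(expit(lin), (10, 50, 90), axis=0)

print(lin.shape, np.allclose(p_after, p_before))  # (1001, 171) True
```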

d_qua = pd.DataFrame(np.percentile(ms['q'], (10, 50, 90), axis=0).T, columns=('p10', 'p50', 'p90'))
d_qua = pd.concat([attendance3, d_qua], axis=1)

plt.figure(figsize=figaspect(3/4.5))
ax = plt.axes()
ax.violinplot([d_qua.query('Y==@i')['p50'].values for i in range(2)], positions=[0, 1], vert=False)
sns.stripplot(x='p50', y='Y', hue='A', data=d_qua, jitter=True, orient='h', palette={0: colors.to_rgb('black'), 1: colors.to_rgb('gray')}, size=3, ax=ax)
plt.setp(ax, xlabel='q', ylabel='Y', yticks=(0, 1), ylim=ax.get_ylim()[::-1])
plt.show()

[Figure 5-10: posterior median of q (x) by observed Y (y), points colored by A]

ms = fit.extract()

N_mcmc = ms['lp__'].size
spec = np.linspace(0, 1, 201)
probs = [10, 50, 90]

def roc(i):
    # ROC curve for the i-th posterior draw of q, resampled onto the common FPR grid spec
    fpr, tpr, _ = roc_curve(attendance3['Y'], ms['q'][i, :], drop_intermediate=False)
    return np.interp(spec, fpr, tpr)

m_roc = np.array(list(map(roc, range(N_mcmc))))
d_est = pd.DataFrame(np.hstack((spec.reshape(-1, 1), np.percentile(m_roc, probs, axis=0).T)), columns=['X'] + ['p{}'.format(p) for p in probs])
ax = d_est.plot('X', 'p50', figsize=figaspect(1), legend=False, c='k')
ax.fill_between('X', 'p10', 'p90', data=d_est, color='k', alpha=2/6)
lim = ax.get_xlim()
ax.plot(lim, lim, c='k', alpha=0.5)
plt.setp(ax, xlim=lim, ylim=lim, xlabel='False Positive', ylabel='True Positive')
plt.show()

[Figure 5-11: ROC curve, False Positive (x) vs. True Positive (y), posterior median and 10-90% band]
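Each posterior draw of `q` yields its own ROC curve, and `roc_curve` returns breakpoints at different false-positive rates for each draw. The `roc` function above therefore resamples every curve onto the shared grid `spec` with `np.interp`. This makes the curves comparable point by point, so percentiles can be taken column by column (a vertical summary over curves).

The sketch below repeats that step on a hypothetical toy problem. The labels, the three fake draws, and the helper name `roc_on_grid` are all made up for illustration:

```python
import numpy as np
from sklearn.metrics import roc_curve

y = np.array([0, 0, 1, 1, 0, 1])            # hypothetical 0/1 outcomes
rng = np.random.default_rng(1)
q_draws = rng.uniform(size=(3, y.size))     # 3 hypothetical draws of q, one row per draw

spec = np.linspace(0, 1, 201)               # common false-positive-rate grid

def roc_on_grid(q):
    # drop_intermediate=False keeps every threshold, as in the cell above
    fpr, tpr, _ = roc_curve(y, q, drop_intermediate=False)
    # fpr comes back sorted (non-decreasing), so it can serve as np.interp's x-coordinates
    return np.interp(spec, fpr, tpr)

m_roc = np.array([roc_on_grid(q) for q in q_draws])    # shape (3, 201)
band = np.percentile(m_roc, [10, 50, 90], axis=0)      # pointwise 10/50/90% curves
print(m_roc.shape, band.shape)  # (3, 201) (3, 201)
```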
