3
3

More than 5 years have passed since the last update.

Bayesian Logistic Regression by Stan

Last updated at Posted at 2018-11-13

Data

import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
import seaborn as sns

from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression

import pystan

# Breast-cancer dataset bundled with scikit-learn.
cancer = load_breast_cancer()

# Pick 100 distinct rows at random (no seed is set, so the subset changes
# between runs) and keep only the first five feature columns.
indices = np.random.choice(len(cancer.data), size=100, replace=False)
data = cancer.data[indices, :5]
target = cancer.target[indices]
print(cancer.feature_names[:5])

# Class balance of the sampled subset: label 0 is malignant, label 1 is benign.
malignant_count = int(np.count_nonzero(target == 0))
benign_count = int(np.count_nonzero(target == 1))
print('# of 0 (malignant): ', malignant_count)
print('# of 1 (benign): ', benign_count)

image.png

image.png

# Hold out 20% of the sampled rows for evaluation; random_state is fixed so
# the split itself is repeatable for a given sample.
x_train, x_test, y_train, y_test = train_test_split(
    data, target, test_size=0.2, random_state=0
)

# Report the shape of every piece of the split.
for label, arr in (
    ('x_train shape: ', x_train),
    ('y_train shape: ', y_train),
    ('x_test shape: ', x_test),
    ('y_test shape: ', y_test),
):
    print(label, arr.shape)

image.png

Logistic Regression

# Baseline: scikit-learn logistic regression with its default settings,
# fitted on the training split (fit returns the estimator itself).
clf = LogisticRegression().fit(x_train, y_train)

# Accuracy on the data it was trained on and on the held-out rows.
print('Accuracy train: ', clf.score(x_train, y_train))
print('Accuracy validation: ', clf.score(x_test, y_test))

image.png

Bayesian Logistic Regression

# Inputs for the Stan program's data block: N rows, M columns, the training
# matrix X and the training labels y.
cancer_data = dict(
    N=x_train.shape[0],
    M=x_train.shape[1],
    X=x_train,
    y=y_train,
)

# Stan program for logistic regression with an intercept (beta0) and one
# coefficient per predictor (beta). The likelihood uses bernoulli_logit, which
# works directly on the log-odds scale and is the numerically stable form of
# applying the inverse logit and then a Bernoulli distribution.
lr_code = """
  data {
    int N;                        // number of observations
    int M;                        // number of predictors
    real X[N, M];                 // predictor values, one row per observation
    int<lower=0, upper=1> y[N];   // binary outcome
  }
  parameters {
    real beta0;                   // intercept
    real beta[M];                 // one coefficient per predictor
  }
  model {
    // No priors are declared for beta0 or beta.
    // bernoulli_logit takes the linear predictor itself, so a probability
    // that would round to exactly 0 or 1 is never formed.
    for (i in 1:N)
        y[i] ~ bernoulli_logit(beta0 + dot_product(X[i], beta));
  }
"""

%%time

# Build the PyStan model object from the Stan program held in lr_code.
# NOTE(review): PyStan 2 presumably compiles the program at this point, which
# would explain why the cell is timed - confirm against the PyStan docs.
stm = pystan.StanModel(model_code=lr_code)

image.png

%%time

# Sampler settings forwarded to stm.sampling below: iteration count, warmup
# length and number of chains.
n_itr, n_warmup, chains = 2000, 500, 3

# Draw posterior samples with NUTS for the data in cancer_data.
fit = stm.sampling(
    data=cancer_data,
    iter=n_itr,
    warmup=n_warmup,
    chains=chains,
    algorithm="NUTS",
    n_jobs=-1,
    verbose=False,
)

image.png

# Bare expression: in a notebook cell this shows the fit object's printed
# representation as the cell output.
fit

image.png

# Posterior draws keyed by parameter name, and the model's parameter names.
la = fit.extract(permuted=True)
names = fit.model_pars
# Number of scalar parameters: 1 for a parameter with no dimensions, otherwise
# the length of its first dimension.
n_param = int(np.sum([1 if len(x) == 0 else x[0] for x in fit.par_dims]))

# Column 0 of the summary table, one entry per summary row; used further down
# as the point estimate of each coefficient.
mean_list = np.array(fit.summary()['summary'])[:, 0]

# Collect one (title, draws) pair per scalar parameter. The number of columns
# is read from the parameter's own draws, not from a variable left over from
# an earlier loop.
name_list = []
series = []
for name in names:
    dat = la[name]
    if dat.ndim == 2:
        for j in range(dat.shape[1]):
            name_list.append("{}{}".format(name, j + 1))
            series.append(dat[:, j])
    else:
        # Intercept
        name_list.append(name)
        series.append(dat)

# squeeze=False keeps axes two-dimensional even when there is only one row.
f, axes = plt.subplots(n_param, 2, figsize=(10, 4*n_param), squeeze=False)
for row, (title, d) in enumerate(zip(name_list, series)):
    # Left column: density estimate with a rug of the individual draws
    # (kdeplot + rugplot stand in for the deprecated sns.distplot).
    sns.kdeplot(d, ax=axes[row, 0])
    sns.rugplot(d, ax=axes[row, 1 - 1])
    # Right column: the draws in stored order (plain Matplotlib line in place
    # of sns.tsplot, which newer seaborn releases no longer provide).
    axes[row, 1].plot(d, alpha=0.8, lw=.5)
    for ax in axes[row]:
        ax.set_title(title)

plt.show()

image.png

image.png

def logistic(x, beta):
    """Return the logistic function of the linear predictor for one sample.

    A constant 1 is placed in front of the features in ``x`` so that the
    first entry of ``beta`` acts as the intercept; the result is
    ``1 / (1 + exp(-z))`` with ``z`` the dot product of the two.
    """
    augmented = [1, *x]
    z = np.dot(augmented, beta)
    return (1 + np.exp(-z)) ** (-1)

def check_accuracy(data, target, param, threshold = 0.5):
    """Return the fraction of rows in ``data`` whose predicted label equals ``target``.

    Each row is scored with ``logistic(row, param)``; a score strictly above
    ``threshold`` is predicted as 1, anything else as 0.
    """
    hits = [
        (1 if logistic(row, param) > threshold else 0) == target[i]
        for i, row in enumerate(data)
    ]
    return np.mean(hits)


# Point estimates used for prediction: the intercept followed by one
# coefficient per feature column of x_train. The length is derived from the
# data instead of being hard-coded, so it follows the number of features.
# NOTE(review): assumes the leading entries of mean_list are beta0 then beta
# in that order - confirm against the fit summary.
param = mean_list[0:1 + x_train.shape[1]]

print('Accuracy train: ', check_accuracy(x_train, y_train, param))
print('Accuracy test: ', check_accuracy(x_test, y_test, param))

image.png

3
3
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
3
3