1
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

Bayesian Neural Net with Flipout

Posted at 2018-11-13

Reference

Flipout: Efficient Pseudo-Independent Weight Perturbations on Mini-Batches (Wen, Vicol, Ba, Grosse and Tran, ICLR 2018)

Data

import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split

import tensorflow as tf
import tensorflow_probability as tfp
from tensorflow.keras.models import Sequential

tfd = tfp.distributions

# Load the breast-cancer dataset bundled with scikit-learn and report its
# size, feature names and class names.
cancer = load_breast_cancer()

for info in (len(cancer.data), cancer.feature_names, cancer.target_names):
    print(info)

# Class balance: label 0 is malignant, label 1 is benign.
labels = cancer.target
malignant_count = int(np.count_nonzero(labels == 0))
benign_count = int(np.count_nonzero(labels == 1))
print('# of 0 (malignant): ', malignant_count)
print('# of 1 (benign): ', benign_count)

(Output image: dataset size, feature names, target names and class counts)

# Features as a 2-D array; labels reshaped into a single column so that
# y has shape (n_examples, 1).
data = np.array(cancer.data)
target = cancer.target.reshape(-1, 1)

# Hold out 20% of the examples for testing; the fixed seed makes the split
# reproducible.
x_train, x_test, y_train, y_test = train_test_split(
    data, target, test_size=0.2, random_state=0)

for caption, arr in (('x_train shape: ', x_train),
                     ('y_train shape: ', y_train),
                     ('x_test shape: ', x_test),
                     ('y_test shape: ', y_test)):
    print(caption, arr.shape)

(Output image: array shapes of the training and test splits)

Sample Code

# Clear the default graph so that re-running the notebook cells starts from a
# fresh TF1 graph.
tf.reset_default_graph()

# A single DenseFlipout unit producing one logit per example, i.e. a Bayesian
# logistic-regression model. The input width is read from the training data
# instead of being hard-coded, so the model stays in sync with the features.
model = Sequential()
layer = tfp.layers.DenseFlipout(1, input_shape=(x_train.shape[1],))
model.add(layer)

model.summary()

(Output image: model summary)

# Training hyper-parameters.
n_iter = 500     # number of full-batch optimisation steps
show_step = 100  # print progress every `show_step` steps
n_samples = 50   # number of posterior draws collected after training

# Size the graph from the data rather than hard-coding the 30 features.
n_train, n_features = x_train.shape

x = tf.placeholder(shape=[None, n_features], dtype=tf.float32)
y = tf.placeholder(shape=[None, 1], dtype=tf.int32)

logits = model(x)
labels_distribution = tfd.Bernoulli(logits=logits)

# Negative ELBO per training example. The likelihood term is a mean over the
# fed examples, whereas the KL penalty in `model.losses` is a single penalty on
# the whole weight posterior. It is therefore divided by the number of
# training examples to put both terms on the same per-example scale. Left
# unscaled, the prior term is weighted n_train times too heavily.
# NOTE(review): reduce_mean is kept from the original to combine the entries
# of model.losses; if several *distinct* KL terms are ever present (e.g. a
# bias prior is configured), they should be summed instead -- confirm.
neg_log_likelihood = -tf.reduce_mean(labels_distribution.log_prob(y))
kl = tf.reduce_mean(model.losses) / float(n_train)
elbo_loss = neg_log_likelihood + kl

optimizer = tf.train.AdamOptimizer(learning_rate=0.01)
train_op = optimizer.minimize(elbo_loss)

# Predict class 1 when the logit is positive (Bernoulli probability > 0.5).
preds = tf.cast(logits > 0, dtype=tf.int32)
correct_preds = tf.equal(preds, y)
accuracy = tf.reduce_mean(tf.cast(correct_preds, tf.float32))

# Build the posterior-sampling ops up front so the graph is complete before
# the session starts; a single sess.run returns all n_samples draws stacked
# along a new leading axis.
w_draw = layer.kernel_posterior.sample(n_samples)
b_draw = layer.bias_posterior.sample(n_samples)

init = tf.global_variables_initializer()

with tf.Session() as sess:
    sess.run(init)

    history_loss_train = []
    history_loss_test = []
    history_acc_train = []
    history_acc_test = []

    train_feed = {x: x_train, y: y_train}
    test_feed = {x: x_test, y: y_test}

    print('Training')
    for step in range(n_iter):
        sess.run(train_op, feed_dict=train_feed)

        # Fetch loss and accuracy in one run: the Flipout layer's output is
        # stochastic, so separate runs would score two different draws.
        loss_train, acc_train = sess.run(
            [elbo_loss, accuracy], feed_dict=train_feed)
        history_loss_train.append(loss_train)
        history_acc_train.append(acc_train)

        if (step + 1) % show_step == 0:
            print('-' * 50)
            print('Step: {:>3d} Loss: {:.3f} Accuracy: {:.3f}'.format(
                step + 1, loss_train, acc_train))

        loss_test, acc_test = sess.run(
            [elbo_loss, accuracy], feed_dict=test_feed)
        history_loss_test.append(loss_test)
        history_acc_test.append(acc_test)

    # The posterior does not depend on the inputs, so no feed is needed.
    w_samples, b_samples = sess.run((w_draw, b_draw))
    sample_w = list(w_samples)
    sample_b = list(b_samples)

(Output image: training log)

# Learning curves per training step: ELBO loss on the left, accuracy on the
# right; training set in blue, test set in red.
steps = range(n_iter)
panels = (
    (1, history_loss_train, history_loss_test, 'Loss', None, 'upper right'),
    (2, history_acc_train, history_acc_test, 'Accuracy', (0.0, 1.0),
     'lower right'),
)

fig = plt.figure(figsize=(10, 3))
for position, train_curve, test_curve, title, ylim, legend_loc in panels:
    ax = fig.add_subplot(1, 2, position)
    ax.plot(steps, train_curve, 'b-', label='Training')
    ax.plot(steps, test_curve, 'r-', label='Test')
    if ylim is not None:
        ax.set_ylim(*ylim)
    ax.set_title(title)
    ax.legend(loc=legend_loc)

plt.show()

(Figure: loss and accuracy curves for the training and test sets)

# Trace of the posterior draws collected after training: the first kernel
# entry on the left and the bias on the right, one point per draw.
traces = (
    ('w', np.array(sample_w)[:, 0, 0]),
    ('b', np.array(sample_b)[:, 0]),
)

fig = plt.figure(figsize=(10, 3))
for position, (title, values) in enumerate(traces, start=1):
    ax = fig.add_subplot(1, 2, position)
    ax.plot(range(n_samples), values)
    ax.set_title(title)

plt.show()

(Figure: posterior samples of w and b)

1
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?