
VAE + Clustering & Anomaly Detection (MNIST)


##Library

import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

from sklearn.decomposition import PCA
from sklearn.manifold import TSNE

import tensorflow as tf
from keras.datasets import mnist 

np.random.seed(10)
tf.set_random_seed(10)

##Data

(x_train, y_train), (x_test, y_test) = mnist.load_data()

x_train = x_train.astype('float32') / 255.
x_test = x_test.astype('float32') / 255.

print ('shape of x_train: ', x_train.shape)
print ('shape of y_train: ', y_train.shape)
print ('shape of x_test: ', x_test.shape)
print ('shape of y_test: ', y_test.shape)


img_zero = []
img_six = []

for i in range(len(x_train)):
  if y_train[i] == 0:
    img_zero.append(x_train[i])
  elif y_train[i] == 6:
    img_six.append(x_train[i])
  else:
    pass

img_zero = np.array(img_zero)
img_six = np.array(img_six)

print ('shape of img_zero: ', img_zero.shape)
print ('shape of img_six: ', img_six.shape)
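
Equivalently, the extraction can be written with boolean masks (a shorter NumPy variant, not in the original loop above):

# Boolean-mask alternative to the loop above; produces the same arrays.
img_zero = x_train[y_train == 0]
img_six = x_train[y_train == 6]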


label_zero = np.zeros(len(img_zero))
label_six = np.ones(len(img_six))

print ('shape of label_zero: ', label_zero.shape)
print ('shape of label_six: ', label_six.shape)


img_all = np.concatenate((np.reshape(img_zero, (-1, 28*28)), 
                          np.reshape(img_six, (-1, 28*28))), axis=0)
label_all = np.concatenate((label_zero, label_six), axis=0)
                         
print ('shape of img_all: ', img_all.shape)
print ('shape of label_all: ', label_all.shape)


x_reduced = PCA(n_components=2).fit_transform(img_all)

print ('shape of x_reduced: ', x_reduced.shape)

plt.figure(figsize = (5, 3))
plt.scatter(x_reduced[:, 0], x_reduced[:, 1], c=label_all, cmap='Reds')
plt.colorbar()
plt.show()


x_reduced = TSNE(n_components=2, random_state=10).fit_transform(img_all)

print ('shape of x_reduced: ', x_reduced.shape)

plt.figure(figsize = (5, 3))
plt.scatter(x_reduced[:, 0], x_reduced[:, 1], c=label_all, cmap='Blues')
plt.colorbar()
plt.show()


##Model

# Close the previous session if one exists (skip on the first run), then reset the graph.
if 'sess' in globals():
    sess.close()
tf.reset_default_graph()

batch_size = 64

X_in = tf.placeholder(dtype=tf.float32, shape=[None, 28, 28])
Y    = tf.placeholder(dtype=tf.float32, shape=[None, 28, 28])
keep_prob = tf.placeholder(dtype=tf.float32, shape=())

Y_flat = tf.reshape(Y, shape=[-1, 28 * 28])

dec_in_channels = 1
n_latent = 8

reshaped_dim = [-1, 7, 7, dec_in_channels]
# inputs_decoder = 49 * dec_in_channels / 2
inputs_decoder = 24  # second dense layer in the decoder outputs 24*2 + 1 = 49 = 7*7 units

def lrelu(x, alpha=0.3):
  return tf.maximum(x, tf.multiply(x, alpha))

def encoder(X_in, keep_prob):
  activation = lrelu
  with tf.variable_scope('encoder', reuse=None):
    X = tf.reshape(X_in, shape=[-1, 28, 28, 1])
    x = tf.layers.conv2d(X, filters=64, kernel_size=4, strides=2, padding='same',
                         activation=activation)
    x = tf.nn.dropout(x, keep_prob)
    x = tf.layers.conv2d(x, filters=64, kernel_size=4, strides=2, padding='same', 
                         activation=activation)
    x = tf.nn.dropout(x, keep_prob)
    x = tf.layers.conv2d(x, filters=64, kernel_size=4, strides=1, padding='same', 
                         activation=activation)
    x = tf.nn.dropout(x, keep_prob)
    x = tf.contrib.layers.flatten(x)
    mn = tf.layers.dense(x, units=n_latent)          # posterior mean
    sd = 0.5 * tf.layers.dense(x, units=n_latent)    # log standard deviation
    epsilon = tf.random_normal(tf.stack([tf.shape(x)[0], n_latent]))
    z = mn + tf.multiply(epsilon, tf.exp(sd))        # reparameterization trick

    return z, mn, sd

def decoder(sampled_z, keep_prob):
  with tf.variable_scope('decoder', reuse=None):
    x = tf.layers.dense(sampled_z, units=inputs_decoder, activation=lrelu)
    x = tf.layers.dense(x, units=inputs_decoder * 2 + 1, activation=lrelu)  # 49 units
    x = tf.reshape(x, reshaped_dim)                                         # (batch, 7, 7, 1)
    x = tf.layers.conv2d_transpose(x, filters=64, kernel_size=4, strides=2, 
                                   padding='same', activation=tf.nn.relu)
    x = tf.nn.dropout(x, keep_prob)
    x = tf.layers.conv2d_transpose(x, filters=64, kernel_size=4, strides=1, 
                                   padding='same', activation=tf.nn.relu)
    x = tf.nn.dropout(x, keep_prob)
    x = tf.layers.conv2d_transpose(x, filters=64, kernel_size=4, strides=1, 
                                   padding='same', activation=tf.nn.relu)

    x = tf.contrib.layers.flatten(x)
    x = tf.layers.dense(x, units=28*28, activation=tf.nn.sigmoid)
    img = tf.reshape(x, shape=[-1, 28, 28])
    return img  

sampled, mn, sd = encoder(X_in, keep_prob)
dec = decoder(sampled, keep_prob)

unreshaped = tf.reshape(dec, [-1, 28*28])

img_loss = tf.reduce_sum(tf.squared_difference(unreshaped, Y_flat), 1)  # per-image reconstruction error
latent_loss = -0.5 * tf.reduce_sum(1.0 + 2.0 * sd - tf.square(mn) - tf.exp(2.0 * sd), 1)  # per-image KL divergence
loss = tf.reduce_mean(img_loss + latent_loss)

optimizer = tf.train.AdamOptimizer(0.0005)
train_op = optimizer.minimize(loss)

sess = tf.Session()
sess.run(tf.global_variables_initializer())
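
As a side note (my summary, not spelled out in the original article): with `mn` $=\mu$ and `sd` $=\log \sigma$, `latent_loss` above is the closed-form KL divergence between the approximate posterior $\mathcal{N}(\mu, \sigma^2)$ and the standard normal prior,

$$
\mathrm{KL}\left(\mathcal{N}(\mu, \sigma^2)\,\|\,\mathcal{N}(0, I)\right) = -\frac{1}{2}\sum_{j=1}^{8}\left(1 + \log \sigma_j^2 - \mu_j^2 - \sigma_j^2\right),
$$

which matches the code because $2\,\mathtt{sd} = \log \sigma^2$ and $e^{2\,\mathtt{sd}} = \sigma^2$. The sampling step `z = mn + epsilon * exp(sd)` in the encoder is the reparameterization trick, $z = \mu + \sigma \odot \epsilon$ with $\epsilon \sim \mathcal{N}(0, I)$, which keeps sampling differentiable with respect to the encoder parameters.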

##Clustering

n_iter = 1000
show_step  = 200

history_loss = []
history_img_loss = []
history_latent_loss = []

for i in range(n_iter):
  
  # using img_all for training
  batch_indices = np.random.choice(len(img_all), batch_size, replace=False)
  batch = np.reshape(img_all[batch_indices], (-1, 28, 28))  # back to image shape for the placeholders
    
  #batch_indices = np.random.choice(len(img_zero), batch_size, replace=False)
  #batch = img_zero[batch_indices] 
  
  feed_dict = {X_in: batch, Y: batch, keep_prob: 0.8}
  sess.run(train_op, feed_dict=feed_dict)
  
  temp_loss, temp_img_loss, temp_latent_loss = sess.run([loss, img_loss, latent_loss], 
                                                        feed_dict=feed_dict)
  
  history_loss.append(temp_loss)
  history_img_loss.append(np.mean(temp_img_loss))
  history_latent_loss.append(np.mean(temp_latent_loss))  

  if not i % show_step:
    ls, d, i_ls, l_ls, mu, sigm = sess.run([loss, dec, img_loss, 
                                            latent_loss, mn, sd], feed_dict = 
                                           {X_in: batch, Y: batch, keep_prob: 1.0})
    print('i: ' + str(i) + ' total loss: ' + str(ls) + ' image loss: ' 
          + str(np.mean(i_ls)) + ' latent loss: ' + str(np.mean(l_ls)))
    
    fig = plt.figure(figsize = (5, 3))
    ax1 = fig.add_subplot(1, 2, 1)
    ax1.imshow(np.reshape(batch[0], [28, 28]), cmap='gray')
    ax1.set_title('Input image')
    ax1.set_axis_off()

    ax2 = fig.add_subplot(1, 2, 2)
    ax2.imshow(np.reshape(d[0], [28, 28]), cmap='gray')
    ax2.set_title('Reconstructed image')
    ax2.set_axis_off()
    
    plt.show()


plt.figure(figsize = (5, 3))
plt.plot(range(n_iter), history_loss, 'b-', label='Total')
plt.title('Total loss')
plt.legend(loc='best')
plt.show()

plt.figure(figsize = (5, 3))
plt.plot(range(n_iter), history_img_loss, 'r-', label='Image')
plt.title('Image (reconstruction) loss')
plt.legend(loc='best')
plt.show()

plt.figure(figsize = (5, 3))
plt.plot(range(n_iter), history_latent_loss, 'g-', label='Latent')
plt.title('Latent (KL divergence) loss')
plt.legend(loc='best')
plt.show()


mu_zero = sess.run(mn, feed_dict={X_in: img_zero, keep_prob: 1.0})
mu_six = sess.run(mn, feed_dict={X_in: img_six, keep_prob: 1.0})
mu_all = np.concatenate((mu_zero, mu_six), axis=0)

print ('shape of mu_zero: ', mu_zero.shape)
print ('shape of mu_six: ', mu_six.shape)
print ('shape of mu_all: ', mu_all.shape)


x_reduced = PCA(n_components=2).fit_transform(mu_all)

print ('shape of x_reduced: ', x_reduced.shape)

plt.figure(figsize = (5, 3))
plt.scatter(x_reduced[:, 0], x_reduced[:, 1], c=label_all, cmap='Reds')
plt.colorbar()
plt.title('PCA')
plt.show()


x_reduced = TSNE(n_components=2, random_state=10).fit_transform(mu_all)

print ('shape of x_reduced: ', x_reduced.shape)

plt.figure(figsize = (5, 3))
plt.scatter(x_reduced[:, 0], x_reduced[:, 1], c=label_all, cmap='Blues')
plt.colorbar()
plt.title('t-SNE')
plt.show()
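
To make the clustering explicit rather than only visual, k-means can be run directly on the latent means. A minimal sketch (my addition; it assumes scikit-learn's KMeans, which the original does not use):

from sklearn.cluster import KMeans

# Cluster the 8-dimensional latent means into two groups.
kmeans = KMeans(n_clusters=2, random_state=10).fit(mu_all)

# Cluster ids are arbitrary, so score against both possible matchings.
acc = np.mean(kmeans.labels_ == label_all)
print ('clustering accuracy: ', max(acc, 1.0 - acc))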


##Anomaly detection

n_iter = 10000
show_step  = 2000

history_loss = []
history_img_loss = []
history_latent_loss = []

for i in range(n_iter):
  #batch_indices = np.random.choice(len(img_all), batch_size, replace=False)
  #batch = np.reshape(img_all[batch_indices], (-1, 28, 28))
  
  # using only img_zero for training
  batch_indices = np.random.choice(len(img_zero), batch_size, replace=False)
  batch = img_zero[batch_indices] 
  
  feed_dict = {X_in: batch, Y: batch, keep_prob: 0.8}
  sess.run(train_op, feed_dict=feed_dict)
  
  temp_loss, temp_img_loss, temp_latent_loss = sess.run([loss, img_loss, latent_loss], 
                                                        feed_dict=feed_dict)
  
  history_loss.append(temp_loss)
  history_img_loss.append(np.mean(temp_img_loss))
  history_latent_loss.append(np.mean(temp_latent_loss))  

  if not i % show_step:
    ls, d, i_ls, l_ls, mu, sigm = sess.run([loss, dec, img_loss, 
                                            latent_loss, mn, sd], feed_dict = 
                                           {X_in: batch, Y: batch, keep_prob: 1.0})
    print('i: ' + str(i) + ' total loss: ' + str(ls) + ' image loss: ' 
          + str(np.mean(i_ls)) + ' latent loss: ' + str(np.mean(l_ls)))
    
    fig = plt.figure(figsize = (5, 3))
    ax1 = fig.add_subplot(1, 2, 1)
    ax1.imshow(np.reshape(batch[0], [28, 28]), cmap='gray')
    ax1.set_title('Input image')
    ax1.set_axis_off()

    ax2 = fig.add_subplot(1, 2, 2)
    ax2.imshow(np.reshape(d[0], [28, 28]), cmap='gray')
    ax2.set_title('Reconstructed image')
    ax2.set_axis_off()
    
    plt.show()


plt.figure(figsize = (5, 3))
plt.plot(range(n_iter), history_loss, 'b-', label='Total')
plt.title('Total loss')
plt.legend(loc='best')
plt.show()

plt.figure(figsize = (5, 3))
plt.plot(range(n_iter), history_img_loss, 'r-', label='Image')
plt.title('Image (reconstruction) loss')
plt.legend(loc='best')
plt.show()

plt.figure(figsize = (5, 3))
plt.plot(range(n_iter), history_latent_loss, 'g-', label='Latent')
plt.title('Latent (KL divergence) loss')
plt.legend(loc='best')
plt.show()


# Re-encode the images with the retrained model before visualizing the latent space.
mu_zero = sess.run(mn, feed_dict={X_in: img_zero, keep_prob: 1.0})
mu_six = sess.run(mn, feed_dict={X_in: img_six, keep_prob: 1.0})
mu_all = np.concatenate((mu_zero, mu_six), axis=0)

x_reduced = PCA(n_components=2).fit_transform(mu_all)

print ('shape of x_reduced: ', x_reduced.shape)

plt.figure(figsize = (5, 3))
plt.scatter(x_reduced[:, 0], x_reduced[:, 1], c=label_all, cmap='Reds')
plt.colorbar()
plt.title('PCA')
plt.show()


x_reduced = TSNE(n_components=2, random_state=10).fit_transform(mu_all)

print ('shape of x_reduced: ', x_reduced.shape)

plt.figure(figsize = (5, 3))
plt.scatter(x_reduced[:, 0], x_reduced[:, 1], c=label_all, cmap='Blues')
plt.colorbar()
plt.title('t-SNE')
plt.show()


img_loss_zero = sess.run(img_loss, feed_dict={X_in: img_zero, 
                                              Y: img_zero,
                                              keep_prob: 1.0})
img_loss_six = sess.run(img_loss, feed_dict={X_in: img_six, 
                                             Y: img_six,
                                             keep_prob: 1.0})

print ('shape of img_loss_zero: ', img_loss_zero.shape)
print ('mean of img_loss_zero: ', np.mean(img_loss_zero))
print ('std of img_loss_zero: ', np.std(img_loss_zero))
print ()
print ('shape of img_loss_six: ', img_loss_six.shape)
print ('mean of img_loss_six: ', np.mean(img_loss_six))
print ('std of img_loss_six: ', np.std(img_loss_six))


n_samples = 300

plt.figure(figsize=(5, 3))
plt.scatter(range(n_samples), img_loss_zero[:n_samples], label='zero')
plt.scatter(range(n_samples), img_loss_six[:n_samples], label='six', alpha=0.5)
plt.title('Loss')
plt.legend(loc='best')
plt.show()
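
The gap between the two loss distributions can be turned into a decision rule. A minimal sketch (my addition), assuming a simple mean-plus-3-sigma threshold fitted on the normal class only:

# Fit a threshold on the reconstruction losses of the normal class ("zero").
threshold = np.mean(img_loss_zero) + 3.0 * np.std(img_loss_zero)
print ('threshold: ', threshold)

# Flag an image as anomalous when its loss exceeds the threshold.
print ('zeros flagged as anomalous: ', np.mean(img_loss_zero > threshold))
print ('sixes flagged as anomalous: ', np.mean(img_loss_six > threshold))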

