
Adversarial Variational Bayes: Implementation for beginners

This article is for teaching myself how to implement Adversarial Variational Bayes (AVB) as a Tensorflow beginner. The code draws heavily on this blog post. We will reproduce the synthetic dataset experiment from the original paper.

Graphical explanation

We will construct the neural network shown in Fig. 1. The input data consist of four types of 2x2 binary images (). Although they look like 2D images, we flatten each one into a 1D array before feeding it to the network, so we will use only 1D arrays as input data throughout the implementation.
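To make the flattening concrete, here is a small NumPy sketch (not part of the original code). It assumes each 2x2 image has exactly one pixel set to 1, builds the four images, and flattens them row by row into length-4 vectors. The result is the 4x4 identity matrix, which is what np.eye(input_dim) produces in the Input data step.

```python
import numpy as np

# Four 2x2 binary images, each with a single "on" pixel (assumed layout)
images = np.zeros((4, 2, 2), dtype=np.float32)
for k in range(4):
    row, col = divmod(k, 2)   # pixel k sits at (row, col) in row-major order
    images[k, row, col] = 1.0

# Flatten each 2x2 image into a 1D array of length 4 (row-major order)
flat = images.reshape(4, -1)
print(flat)
```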

The encoder extracts features from the input data. In probabilistic terms, the distribution of the extracted features is called the posterior, and running input data through the encoder is called inference. Accordingly, running points in the latent space (where the posterior lives) through the decoder is called generation.

The prior is the distribution to which you want to fit the posterior. As the whole network trains, the posterior gets closer and closer to the prior, which is what you want. The point of the AVB architecture is that the posterior not only approaches the prior but can also be more "expressive" than the posteriors of other autoencoder architectures such as the variational autoencoder (VAE). In a VAE the encoder outputs the parameters of a Gaussian, whereas in AVB the encoder receives noise as an extra input, so its posterior is not restricted to a fixed shape. As a result, the latent space can look more organized and the decoded images can be sharper.


Figure 1. Network architecture

The details of the theory will be discussed elsewhere, so here we will focus on the implementation.

Preparation for implementation

The following programs need to be installed:

- Tensorflow

- Tensorflow probability

- Anaconda

The code was tested on Python >3.6, IPython 6.5, Tensorflow >1.12 and Tensorflow Probability 0.5. It requires Tensorflow 1.x, because it relies on tf.contrib, which was removed in Tensorflow 2.

Loading packages

import tensorflow as tf
import numpy as np
import tensorflow_probability as tfp
from tqdm import tqdm
from matplotlib import pyplot as plt

# function shorthand
tfc = tf.contrib
tfd = tfp.distributions
graph_replace = tfc.graph_editor.graph_replace


batch_size = 512 

latent_dim = 2 # dimension of latent space
input_dim = 4 # dimension of input data
n_layer = 2 # number of hidden layers
n_unit = 256 # number of hidden units

Input data

# number of data points for each class; here we have 4 classes
points_per_class = batch_size // input_dim
# create labels for the 4 classes
labels = np.concatenate([[i] * points_per_class
                         for i in range(input_dim)])
# create dataset: row i is the flattened one-hot image of class labels[i]
np_data = np.eye(input_dim, dtype=np.float32)[labels]
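A quick sanity check of the dataset can catch off-by-one mistakes early. This NumPy sketch rebuilds the labels with np.repeat (an equivalent alternative to the concatenation above, not the author's code) and confirms the class counts and the one-hot rows.

```python
import numpy as np

batch_size = 512
input_dim = 4

# Integer division gives 128 points per class without a float round-trip
points_per_class = batch_size // input_dim
labels = np.repeat(np.arange(input_dim), points_per_class)
np_data = np.eye(input_dim, dtype=np.float32)[labels]

print(np.bincount(labels))        # [128 128 128 128]
print(np_data.shape)              # (512, 4)
print(np_data[0], np_data[-1])    # [1. 0. 0. 0.] [0. 0. 0. 1.]
```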

Model construction

We will now start constructing the neural network shown in Fig. 1. Along the way, you will learn some basics of how to use Tensorflow.

Making input data a constant

Tensorflow computes with tensors, so np_data, the NumPy array we created above, has to be converted into a tensor. Because the input data are very simple and artificial (we are not using real images) and do not change during training, we can make them a constant, which makes the code a little easier to understand.

x = tf.constant(np_data)

Check the shape on ipython

In [9]: np_data.shape
Out[9]: (512, 4)

In [9]: x.shape
Out[9]: TensorShape([Dimension(512), Dimension(4)])


Now we need to create "noise" from a Gaussian distribution as an additional input to the encoder (Fig. 1). We will create a distribution object with tfp.distributions, from which we can draw values (formally called "sampling") and evaluate probabilities under that distribution. Here the noise has the same dimension as the input data (4 pixels), so we sample it from a 4D Gaussian distribution.

We start by creating a multivariate Gaussian distribution with tfd.MultivariateNormalDiag. Passing 4 means (loc) and 4 standard deviations (scale_diag) to this function creates a 4D Gaussian distribution object (named Gauss4D) with those means and standard deviations. Here, we set all means to 0 and all standard deviations to 1. Then Gauss4D.sample(batch_size) draws batch_size samples from this distribution, as shown below.


Gauss4D = tfd.MultivariateNormalDiag(loc=tf.zeros(input_dim),
                                     scale_diag=tf.ones(input_dim))
noise = Gauss4D.sample(batch_size)

Note that Gauss4D is a distribution object, not a tensor, so it cannot be fed into a layer directly; only the tensors it returns, such as noise, can.
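For intuition, sampling from a diagonal Gaussian amounts to scaling and shifting standard normal draws: each sample is loc + scale_diag * eps with eps drawn from N(0, I). The NumPy sketch below mirrors Gauss4D.sample(batch_size) under that view; it does not use Tensorflow Probability, and the seed is an arbitrary choice.

```python
import numpy as np

rng = np.random.default_rng(0)   # arbitrary seed for reproducibility

input_dim = 4
batch_size = 512
loc = np.zeros(input_dim)        # means of 0
scale_diag = np.ones(input_dim)  # standard deviations of 1

# Reparameterized sampling: sample = loc + scale * eps, eps ~ N(0, I)
eps = rng.standard_normal((batch_size, input_dim))
noise = loc + scale_diag * eps

print(noise.shape)  # (512, 4), one 4D sample per row
print(noise.mean(axis=0).round(2), noise.std(axis=0).round(2))
```

The second line prints per-dimension means near 0 and standard deviations near 1; the exact values depend on the seed.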

Now we start constructing the neural network.

We start by setting up a variable_scope, which prefixes the names of all variables created inside it. Naming variables is particularly important when you train a GAN-like architecture, where different groups of variables are trained in each round: by referring to the scope name, you can select exactly the variables you want to train. Everything created inside the with block gets the scope name as a prefix. Setting reuse=tf.AUTO_REUSE lets you reuse existing variables with the same name instead of raising an error (this is not necessary in this article).


with tf.variable_scope('encoder', reuse=tf.AUTO_REUSE):
    enc = tf.concat([x, noise], 1)  # concatenate input data and noise
    enc = tfc.layers.repeat(enc, n_layer, tfc.layers.fully_connected, n_unit)
    enc = tfc.layers.fully_connected(enc, latent_dim, activation_fn=None)

We use repeat to stack several layers of the same kind. Here, it stacks n_layer = 2 tfc.layers.fully_connected layers, each with n_unit = 256 outputs and its own weights. tfc.layers.fully_connected uses ReLU as its activation function by default. A final tfc.layers.fully_connected layer reduces the output to the dimension of the latent space (latent_dim). This layer has no activation function (it is linear), so the latent codes can take any real value.
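The shapes flowing through the encoder can be checked with a framework-free forward pass. This NumPy sketch assumes hypothetical random weights (a He-style normal initialization, not the Tensorflow default) and only reproduces the layer structure: concatenate x and noise (8 inputs), two ReLU layers of 256 units, then a linear layer down to 2 dimensions.

```python
import numpy as np

rng = np.random.default_rng(0)
batch_size, input_dim, latent_dim = 512, 4, 2
n_layer, n_unit = 2, 256

def dense(inp, n_out, activation=None):
    """Fully connected layer with freshly drawn (hypothetical) random weights."""
    w = rng.normal(scale=np.sqrt(2.0 / inp.shape[1]), size=(inp.shape[1], n_out))
    b = np.zeros(n_out)
    out = inp @ w + b
    return activation(out) if activation is not None else out

def relu(a):
    return np.maximum(a, 0.0)

x = np.eye(input_dim)[rng.integers(0, input_dim, batch_size)]
noise = rng.standard_normal((batch_size, input_dim))

enc = np.concatenate([x, noise], axis=1)   # (512, 8)
for _ in range(n_layer):                   # like repeat(): same layer type, separate weights
    enc = dense(enc, n_unit, relu)         # (512, 256), ReLU activation
enc = dense(enc, latent_dim)               # (512, 2), linear output
print(enc.shape)                           # (512, 2)
```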


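Now we repeat the same steps to create the decoder: a variable_scope for the name, and repeat for a 2-layer fully connected network. This time, however, the final layer needs a sigmoid activation function, because the decoder output must match the input data, which are binary images; the sigmoid maps each output to a value between 0 and 1 that we interpret as the probability that the pixel is 1. The coefficients 1e-6 and (1 - 2e-6) squeeze this probability into [1e-6, 1 - 1e-6] so that the log probability computed next never hits log(0), which would produce -inf or NaN.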


with tf.variable_scope('decoder', reuse=tf.AUTO_REUSE):
    dec = tfc.layers.repeat(enc, n_layer, tfc.layers.fully_connected, n_unit)
    # squeeze the sigmoid output into [1e-6, 1 - 1e-6] so that log(p) and
    # log(1 - p) in the Bernoulli log probability stay finite (no -inf/NaN)
    dec = 1e-6 + (1 - 2e-6) * tfc.layers.fully_connected(dec, input_dim,
                                                         activation_fn=tf.nn.sigmoid)
    # log probability of each pixel of x under a Bernoulli distribution
    log_probs = tfd.Bernoulli(probs=dec).log_prob(x)

Because every pixel of an input image is either 0 or 1 (in fact exactly one pixel is 1 and the rest are 0), we use a Bernoulli distribution to evaluate how close the generated image is to the input image. tfd.Bernoulli(probs=dec) creates a batch of independent Bernoulli distributions, one per pixel, whose probabilities of being 1 are given by dec. .log_prob(x) then returns the log probability of each pixel of x under the corresponding distribution, so log_probs has the same shape as x, (512, 4).
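Why the clipping matters can be seen with the Bernoulli log probability written out by hand, log p(x) = x log(p) + (1 - x) log(1 - p). This NumPy sketch (hand-picked, hypothetical logits; no Tensorflow) shows that an unclipped float32 sigmoid saturates to exactly 1.0 for a large logit, which yields NaN (0 * log 0) or -inf (log 0), while the same squeezing as in the decoder keeps every term finite.

```python
import numpy as np

def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))

def bernoulli_log_prob(x, p):
    # log p(x) = x*log(p) + (1 - x)*log(1 - p), evaluated pixel by pixel
    return x * np.log(p) + (1.0 - x) * np.log(1.0 - p)

x = np.array([[1.0, 0.0, 0.0, 0.0]], dtype=np.float32)           # one input image
logits = np.array([[40.0, -3.0, -3.0, 40.0]], dtype=np.float32)  # last pixel badly wrong

p_raw = sigmoid(logits)                 # sigmoid(40) rounds to exactly 1.0 in float32
p_clip = 1e-6 + (1 - 2e-6) * p_raw      # same squeezing as in the decoder

with np.errstate(divide="ignore", invalid="ignore"):
    lp_raw = bernoulli_log_prob(x, p_raw)
lp_clip = bernoulli_log_prob(x, p_clip)

print(lp_raw)               # first entry is nan (0 * log 0), last entry is -inf (log 0)
print(lp_clip.sum(axis=1))  # finite, like one row of recon_likelihood
```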


As Fig. 1 shows, the discriminator distinguishes the posterior from the prior, which is another Gaussian distribution, but this time in 2D. So we start by creating a 2D Gaussian distribution.


Gauss2D = tfd.MultivariateNormalDiag(loc=tf.zeros(latent_dim),
                                     scale_diag=tf.ones(latent_dim))
prior = Gauss2D.sample(batch_size)

The discriminator looks at the posterior and the prior together with the input data (Fig. 1), so we first concatenate each of them with x. To center the input data at 0, we apply a small conversion, 2 * x - 1, during concatenation, which maps pixel values from {0, 1} to {-1, 1}. (This is not mandatory.)


dis_pr = tf.concat([2*x-1, prior], 1)
dis_po = tf.concat([2*x-1, enc], 1)
with tf.variable_scope('discriminator', reuse=tf.AUTO_REUSE):
    dis = tfc.layers.repeat(dis_pr, n_layer,
                            tfc.layers.fully_connected, n_unit)
    log_d_prior = tfc.layers.fully_connected(dis, 1, activation_fn=None)

# use graph_replace to reuse the same network
dis2 = graph_replace(dis, {dis_pr: dis_po})
log_d_posterior = graph_replace(log_d_prior, {dis: dis2})

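Although there are two inputs here (posterior and prior), both must pass through a single network with the same weights. graph_replace makes this possible: graph_replace(<target tensor>, {<original tensor>: <new tensor>}) copies the part of the graph that computes the target, with the original tensor swapped for the new one. The copied operations still read the same variables, so the weights are shared. Here, dis2 is the hidden layer computed from dis_po instead of dis_pr, and log_d_posterior is the output computed from dis2 instead of dis.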
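graph_replace is one way to share weights; another common Tensorflow 1.x idiom is to wrap the discriminator in a Python function and call it twice inside tf.variable_scope('discriminator', reuse=tf.AUTO_REUSE). The framework-free NumPy sketch below only illustrates the idea both approaches implement: one set of parameters applied to two inputs. The weights and the stand-in encoder output are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)
batch_size, input_dim, latent_dim, n_unit = 512, 4, 2, 256

x = np.eye(input_dim)[rng.integers(0, input_dim, batch_size)]
prior = rng.standard_normal((batch_size, latent_dim))
enc = rng.standard_normal((batch_size, latent_dim)) + 1.0  # hypothetical stand-in for the encoder

# One set of (hypothetical, randomly initialized) discriminator weights
d_in = input_dim + latent_dim
params = {
    "w1": rng.normal(scale=0.1, size=(d_in, n_unit)), "b1": np.zeros(n_unit),
    "w2": rng.normal(scale=0.1, size=(n_unit, n_unit)), "b2": np.zeros(n_unit),
    "w3": rng.normal(scale=0.1, size=(n_unit, 1)), "b3": np.zeros(1),
}

def discriminator(params, inp):
    """Two ReLU layers, then a linear logit; the same params serve every call."""
    h = np.maximum(inp @ params["w1"] + params["b1"], 0.0)
    h = np.maximum(h @ params["w2"] + params["b2"], 0.0)
    return h @ params["w3"] + params["b3"]

# 2*x - 1 maps pixel values {0, 1} to {-1, 1} before concatenation
dis_pr = np.concatenate([2 * x - 1, prior], axis=1)
dis_po = np.concatenate([2 * x - 1, enc], axis=1)

log_d_prior = discriminator(params, dis_pr)      # (512, 1)
log_d_posterior = discriminator(params, dis_po)  # (512, 1), same weights
```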

Loss functions

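Let's define the loss function for the discriminator first because it's simpler. This is a binary classification, so the labels are only 0 or 1. log_d_prior and log_d_posterior are logits (the last discriminator layer has no activation function), and we want the discriminator to assign label 0 to prior samples and label 1 to posterior samples, i.e. sigmoid(log_d_prior) close to 0 and sigmoid(log_d_posterior) close to 1, not the other way round. We will see the reason soon.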

tf.nn.sigmoid_cross_entropy_with_logits computes the cross entropy between sigmoid(logits) and labels in a numerically stable way. Adding the two sigmoid cross entropies and taking the mean gives the discriminator loss.


disc_loss = tf.reduce_mean(
    tf.nn.sigmoid_cross_entropy_with_logits(
        logits=log_d_posterior,
        labels=tf.ones_like(log_d_posterior)) +
    tf.nn.sigmoid_cross_entropy_with_logits(
        logits=log_d_prior,
        labels=tf.zeros_like(log_d_prior)))
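To see what this loss computes, here is a NumPy sketch of the stable formula documented for tf.nn.sigmoid_cross_entropy_with_logits, max(x, 0) - x * z + log(1 + exp(-|x|)) for logits x and labels z, plugged into the same discriminator loss. The input values are made up for illustration.

```python
import numpy as np

def sigmoid_cross_entropy_with_logits(logits, labels):
    # Stable form of -[z*log(sigmoid(x)) + (1 - z)*log(1 - sigmoid(x))]
    x, z = logits, labels
    return np.maximum(x, 0) - x * z + np.log1p(np.exp(-np.abs(x)))

def disc_loss(log_d_posterior, log_d_prior):
    # Posterior samples are labeled 1 and prior samples 0, as in the Tensorflow code
    return np.mean(
        sigmoid_cross_entropy_with_logits(log_d_posterior, np.ones_like(log_d_posterior))
        + sigmoid_cross_entropy_with_logits(log_d_prior, np.zeros_like(log_d_prior)))

# A discriminator that outputs logit 0 everywhere is maximally unsure: loss = 2*log(2)
print(disc_loss(np.zeros((512, 1)), np.zeros((512, 1))))  # about 1.386
# A confident, correct discriminator drives the loss toward 0
print(disc_loss(np.full((512, 1), 10.0), np.full((512, 1), -10.0)))
```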

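log_probs is the log probability of the input image under the decoder's output, so training pushes it up toward 0 (i.e. the probability toward 1); recon_likelihood sums it over the 4 pixels of each image. The generator loss subtracts the mean of recon_likelihood and adds the mean of log_d_posterior. Adding log_d_posterior means that the encoder is trained to push log_d_posterior down, making posterior samples look like prior samples, while the discriminator is trained to push it up, toward label 1. These two opposing objectives are what make the whole network adversarial.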


# decoder loss: reconstruction log-likelihood of each example (sum over its 4 pixels)
recon_likelihood = tf.reduce_sum(log_probs, axis=1)
# generator loss (for encoder and decoder)
gen_loss = tf.reduce_mean(log_d_posterior) - tf.reduce_mean(recon_likelihood)
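For the link to the original paper's objective, here is a sketch in math. Write T(x, z) for the discriminator logit (log_d_posterior when z comes from the encoder) and log p(x | z) for one entry of recon_likelihood, with N = batch_size and z_i = enc(x_i, eps_i). The AVB paper shows that, for the optimal discriminator trained with the labels used above (posterior 1, prior 0), T*(x, z) = log q(z | x) - log p(z), where q(z | x) is the encoder's posterior and p(z) the 2D Gaussian prior. Substituting T* into gen_loss shows that minimizing gen_loss approximately maximizes a one-sample estimate of the evidence lower bound (ELBO), the bracketed term below.

```latex
\begin{aligned}
T^*(x, z) &= \log q(z \mid x) - \log p(z), \\
\mathcal{L}_{\text{gen}} &= \frac{1}{N}\sum_{i=1}^{N} T(x_i, z_i) - \frac{1}{N}\sum_{i=1}^{N} \log p(x_i \mid z_i) \\
&\approx -\frac{1}{N}\sum_{i=1}^{N} \Big[\log p(x_i \mid z_i) - \big(\log q(z_i \mid x_i) - \log p(z_i)\big)\Big].
\end{aligned}
```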


Now we collect the trainable variables of each network by scope name, because we want to train them separately.


qvars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, "encoder")
pvars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, "decoder")
dvars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, "discriminator")
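The scope argument of tf.get_collection filters items whose name matches the scope with re.match, which in practice means a prefix match. This plain-Python sketch mimics that filter on hypothetical variable names to show one pitfall: a scope such as "encoder" also matches a differently named scope like "encoder_eval", while "encoder/" does not.

```python
import re

# Hypothetical variable names of the kind tf.contrib.layers creates under each scope
var_names = [
    "encoder/Repeat/fully_connected_1/weights:0",
    "encoder/fully_connected/weights:0",
    "decoder/Repeat/fully_connected_1/weights:0",
    "discriminator/fully_connected/biases:0",
    "encoder_eval/fully_connected/weights:0",   # a differently named scope
]

def filter_by_scope(names, scope):
    # Mimics the scope filter: keep names for which re.match(scope, name) succeeds
    return [n for n in names if re.match(scope, n)]

print(filter_by_scope(var_names, "encoder"))   # also picks up "encoder_eval/..."
print(filter_by_scope(var_names, "encoder/"))  # trailing slash keeps only the encoder scope
```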

Then we set up the optimizers. You can define multiple optimizers with different parameters, and the var_list argument of minimize specifies which variables each optimizer updates.


opt = tf.train.AdamOptimizer(2e-4, beta1=0.5)
opt2 = tf.train.AdamOptimizer(1e-3, beta1=0.5)
train_gen_op = opt.minimize(gen_loss, var_list=qvars + pvars)
train_disc_op = opt2.minimize(disc_loss, var_list=dvars)


Now everything is ready. Before training, you have to create a session, because Tensorflow 1.x executes the graph inside a session. All variables also have to be initialized if you want to train the model from scratch. Create a session with tf.Session() and evaluate anything with sess.run(<tensors or operations>). The values of Tensorflow tensors are not computed, and thus not visible, until you run them with sess.run().


init_g = tf.global_variables_initializer()
init_l = tf.local_variables_initializer()
sess = tf.Session()
# the initializers take effect only once they are run in the session
sess.run([init_g, init_l])

Running the optimizers trains the network. Running the loss tensors in the same call returns the losses of each round, and recording them lets you draw training curves. You can monitor any value by passing its tensor to sess.run().

Note that we run train_gen_op and train_disc_op together, but each optimizer updates only the trainable variables we designated with var_list. That's why we don't need to worry about freezing parameters, and it shows how convenient variable_scope is for selecting variables.


gen_loss_list = []
disc_loss_list = []
n_epoch = 2000
for i in tqdm(range(n_epoch)):
    gl, dl, _, _ = sess.run([gen_loss, disc_loss, train_gen_op, train_disc_op])
    # record the losses for the training curve
    gen_loss_list.append(gl)
    disc_loss_list.append(dl)

By using tqdm you can see a progress bar like this:

100%|██████████| 2000/2000 [00:08<00:00, 227.83it/s]


After training ends, you can plot the training curve and a scatter plot of the latent space.



plt.plot(np.arange(n_epoch), np.asarray(gen_loss_list))
plt.plot(np.arange(n_epoch), np.asarray(disc_loss_list))
plt.legend(['generation loss', 'discrimination loss'])
plt.title('Training curve')
plt.show()

We create the latent space scatter plot by passing the input data and freshly sampled noise through the encoder. Because x contains only 512 data points, which is a little too few, we run enc 10 times; since each run samples new noise, this yields 5120 different points in the latent space.


n_vis = 10
# each sess.run(enc) samples new noise, giving new latent codes for the same 512 inputs
enc_test = np.vstack([sess.run(enc) for _ in range(n_vis)])
enc_test_label = np.tile(labels, n_vis)
for i in range(len(np.unique(labels))):
    plt.scatter(enc_test[enc_test_label == i, 0], enc_test[enc_test_label == i, 1],
                edgecolor='none', alpha=0.5, s=2)
plt.title('Latent space')
plt.show()
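Stacking runs with np.vstack and labeling them with np.tile works because every run of enc returns its rows in the same order as x while drawing fresh noise. The NumPy sketch below checks that alignment using a hypothetical stand-in for sess.run(enc) (class-dependent centers plus noise); it is not the trained encoder.

```python
import numpy as np

rng = np.random.default_rng(0)
batch_size, latent_dim, n_vis = 512, 2, 10
labels = np.repeat(np.arange(4), batch_size // 4)

def fake_encoder_run():
    # Hypothetical stand-in for sess.run(enc): class-dependent center plus fresh noise per call
    centers = np.array([[-2, -2], [-2, 2], [2, -2], [2, 2]], dtype=float)
    return centers[labels] + 0.3 * rng.standard_normal((batch_size, latent_dim))

enc_test = np.vstack([fake_encoder_run() for _ in range(n_vis)])  # (5120, 2)
enc_test_label = np.tile(labels, n_vis)                             # (5120,)

# Row j of enc_test came from input row j % 512, so its label is labels[j % 512]
assert np.array_equal(enc_test_label, labels[np.arange(n_vis * batch_size) % batch_size])
print(enc_test.shape, enc_test_label.shape)  # (5120, 2) (5120,)
```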

Finally, here are the results!




The following web sites (in English) significantly helped me understand AVB.




The following web sites explaining the theory (in Japanese) also helped me a lot to achieve .