
More than 5 years have passed since last update.

Notes on Implementing Inverse Autoregressive Flow


Key points

  • Implement Inverse Autoregressive Flow and check its behavior with concrete numbers.

References

1. Improved Variational Inference with Inverse Autoregressive Flow

[Figures quoted from Reference 1]
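For reference, here is the computation that the sample code follows. It is written out as I read Algorithm 1 of Reference 1, with a single IAF step (T = 1). The notation is chosen to match the code, so this is a hedged transcription rather than a verbatim copy of the paper:

- the encoder outputs logvar = log σ₀²;
- the gate uses the code's +1 offset;
- m_i and s_i may depend on z_0 only through z_{0,1}, ..., z_{0,i-1} (and on h).

```latex
\begin{aligned}
[\mu,\ \log\sigma_0^2,\ h] &= \mathrm{EncoderNN}(x) \\
z_0 &= \mu + \sigma_0 \odot \epsilon, \qquad \epsilon \sim \mathcal{N}(0, I) \\
l_0 &= -\sum_i \Big( \log\sigma_{0,i} + \tfrac{1}{2}\epsilon_i^2 + \tfrac{1}{2}\log 2\pi \Big) \\
[m,\ s] &= \mathrm{AutoregressiveNN}(z_0, h) \\
\sigma_1 &= \mathrm{sigmoid}(s + 1) \\
z_1 &= \sigma_1 \odot z_0 + (1 - \sigma_1) \odot m \\
l_1 &= l_0 - \sum_i \log \sigma_{1,i}
\end{aligned}
```

Because neither m_i nor σ_{1,i} depends on z_{0,i}, we have ∂z_{1,i}/∂z_{0,i} = σ_{1,i} and the Jacobian ∂z_1/∂z_0 is triangular. Hence log|det ∂z_1/∂z_0| = Σ_i log σ_{1,i}, which is exactly the term subtracted in the last line.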

Sample code

import numpy as np
import tensorflow as tf  # TensorFlow 1.x API (tf.Session, tf.get_variable, ...)

tf.reset_default_graph()

class InverseAutoregressiveFlow():
    
  def __init__(self):
    pass
  
  def weight_variable(self, name, shape):
    initializer = tf.truncated_normal_initializer(mean = 0.0, stddev = 0.01, dtype = tf.float32)
    return tf.get_variable(name, shape, initializer = initializer)
  
  def bias_variable(self, name, shape):
    initializer = tf.constant_initializer(value = 0.0, dtype = tf.float32)
    return tf.get_variable(name, shape, initializer = initializer)
  
  def log_q_z_x(self, x, n_in, dim_z, n_units, batch_size):
    pi = tf.constant(value = np.pi, dtype = tf.float32)

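    # Strictly upper-triangular mask (k = 1): output j may use only z_i with i < j,
    # which keeps the Jacobian of the flow step triangular with diagonal sig
    mask = np.zeros([dim_z, dim_z], dtype = np.float32)
    mask[np.triu_indices(dim_z, k = 1)] = 1.0
    mask = tf.constant(value = mask, dtype = tf.float32)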

    with tf.variable_scope('encoder'):
      w = self.weight_variable('w', [n_in, dim_z * 2 + n_units])
      b = self.bias_variable('b', [dim_z * 2 + n_units])
      y = tf.add(tf.matmul(x, w), b)
      
      mu = y[:, :dim_z]
      logvar = y[:, dim_z : dim_z * 2]
      h = y[:, dim_z * 2:]
      
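      # Reparameterization: z = mu + sigma * e with sigma = exp(0.5 * logvar), e ~ N(0, I)
      e = tf.random_normal(shape = [batch_size, dim_z])
      z = mu + e * tf.exp(0.5 * logvar)
      # Gaussian log density of this sample (log sigma = 0.5 * logvar); the sum also runs over the batch
      l = - tf.reduce_sum(0.5 * logvar + 0.5 * e * e + 0.5 * tf.log(2 * pi))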
      
    with tf.variable_scope('autoregressive'):
      w_m = self.weight_variable('w_m', [dim_z, dim_z])
      b_m = self.bias_variable('b_m', [dim_z])
      w_s = self.weight_variable('w_s', [dim_z, dim_z])
      b_s = self.bias_variable('b_s', [dim_z])
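      # Separate context weights h -> m and h -> s (b_m and b_s act as the biases)
      w_h_m = self.weight_variable('w_h_m', [n_units, dim_z])
      w_h_s = self.weight_variable('w_h_s', [n_units, dim_z])

      w_m_masked = mask * w_m
      w_s_masked = mask * w_s

      # Stack [masked z-weights; h-weights] so that [z, h] @ w_m_h gives m (likewise for s)
      w_m_h = tf.concat([w_m_masked, w_h_m], axis = 0)
      w_s_h = tf.concat([w_s_masked, w_h_s], axis = 0)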
      
      z_h = tf.concat([z, h], axis = 1)
      
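      # m and s computed from [z, h] through the masked weights
      m = tf.add(tf.matmul(z_h, w_m_h), b_m)
      s = tf.add(tf.matmul(z_h, w_s_h), b_s)
      sig = tf.nn.sigmoid(s + 1)   # +1 <- 'forget gate bias'
      # LSTM-style update from Reference 1: z <- sig * z + (1 - sig) * m
      z = sig * z + (1 - sig) * m
      # Change of variables: subtract the log|det| of this step, taken as sum(log sig)
      l -= tf.reduce_sum(tf.log(sig))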
      
      return l, w_m_h  
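The term `l -= tf.reduce_sum(tf.log(sig))` is the exact log-determinant only if m_j and s_j depend on z_i for i < j alone. In this row-vector convention (`z @ W`), that means a strictly upper-triangular mask. The NumPy sketch below checks this numerically for one sample: it builds a finite-difference Jacobian of a single IAF step and compares it with sig. The helper names (`iaf_step`, `numerical_jacobian`) and the random weights are hypothetical illustrations, not part of the original code.

```python
import numpy as np


def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))


def iaf_step(z, h, w_m, w_s, w_h_m, w_h_s, mask):
    """One IAF step for a single sample: z -> sig * z + (1 - sig) * m."""
    m = z @ (mask * w_m) + h @ w_h_m
    s = z @ (mask * w_s) + h @ w_h_s
    sig = sigmoid(s + 1.0)  # same +1 offset as the TF code
    return sig * z + (1.0 - sig) * m, sig


def numerical_jacobian(f, z, eps=1e-6):
    """Central differences: jac[j, i] = d f_j / d z_i."""
    jac = np.zeros((z.size, z.size))
    for i in range(z.size):
        dz = np.zeros(z.size)
        dz[i] = eps
        jac[:, i] = (f(z + dz) - f(z - dz)) / (2.0 * eps)
    return jac


rng = np.random.default_rng(0)
dim_z, n_units = 4, 3
w_m = rng.normal(size=(dim_z, dim_z))
w_s = rng.normal(size=(dim_z, dim_z))
w_h_m = rng.normal(size=(n_units, dim_z))
w_h_s = rng.normal(size=(n_units, dim_z))
z = rng.normal(size=dim_z)
h = rng.normal(size=n_units)

# Strictly upper triangular: column j of the weights only reads z_i with i < j
mask = np.triu(np.ones((dim_z, dim_z)), k=1)

jac = numerical_jacobian(lambda v: iaf_step(v, h, w_m, w_s, w_h_m, w_h_s, mask)[0], z)
_, sig = iaf_step(z, h, w_m, w_s, w_h_m, w_h_s, mask)
sign, log_abs_det = np.linalg.slogdet(jac)

print(np.allclose(np.triu(jac, k=1), 0.0))           # Jacobian is lower triangular
print(np.allclose(np.diag(jac), sig))                # its diagonal equals sig
print(np.isclose(log_abs_det, np.sum(np.log(sig))))  # log|det| = sum(log sig)
```

If the mask is built with k=0 (diagonal included), m_j and s_j can read z_j itself. The diagonal of the Jacobian then generally differs from sig, and the last two checks fail.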

Output

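# Assumption: MNIST loaded with the TensorFlow 1.x tutorial helper (not shown in the original)
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('MNIST_data/')

iaf = InverseAutoregressiveFlow()

x = mnist.test.images[0 : 5]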

n_in = 28*28
dim_z = 2
n_units = 3
batch_size = 5

log_q_z_x, w_m_h = iaf.log_q_z_x(x, n_in, dim_z, n_units, batch_size)

init = tf.global_variables_initializer()

with tf.Session() as sess:
  sess.run(init)
  
  print ('log_q_z_x: ')
  print (sess.run(log_q_z_x))
  print ()
  print ('masked weights for m: ')
  print (sess.run(w_m_h))
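Another way to check the returned log density with concrete numbers is to confirm normalization for dim_z = 2. Viewed as a density over z_1, exp(l) should integrate to about 1. The NumPy sketch below does this on a grid:

- it inverts the IAF step one coordinate at a time, which the autoregressive structure makes possible;
- it then evaluates the same formula for l as the TF code, but per point rather than summed over a batch.

All parameter values are hypothetical hand-picked numbers, not the trained TF variables. The context weights `w_h_m` and `w_h_s` are hypothetical names, and the check works the same if a single shared matrix is used.

```python
import numpy as np


def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))


# Hypothetical fixed parameters for dim_z = 2, n_units = 3 (not trained values)
mu = np.array([0.3, -0.5])
log_sigma = np.array([-0.2, 0.1])            # log sigma_0 = 0.5 * logvar
h = np.array([0.5, -1.0, 0.2])
w_m = np.array([[0.0, 0.8], [0.0, 0.0]])     # already masked: strictly upper triangular
w_s = np.array([[0.0, -0.6], [0.0, 0.0]])
w_h_m = np.array([[0.1, -0.2], [0.3, 0.05], [-0.4, 0.2]])
w_h_s = np.array([[0.2, 0.1], [-0.1, 0.3], [0.05, -0.2]])


def invert_step(z1):
    """Recover z0 from z1 one coordinate at a time; also return log(sig)."""
    z0 = np.zeros_like(z1)
    log_sig = np.zeros_like(z1)
    for j in range(z1.shape[1]):
        # Column j of the masked weights only reads z0[:, :j], which is already known
        m_j = z0 @ w_m[:, j] + h @ w_h_m[:, j]
        s_j = z0 @ w_s[:, j] + h @ w_h_s[:, j]
        sig_j = sigmoid(s_j + 1.0)
        z0[:, j] = (z1[:, j] - (1.0 - sig_j) * m_j) / sig_j
        log_sig[:, j] = np.log(sig_j)
    return z0, log_sig


def log_q_z1(z1):
    """l = log q(z0 | x) - sum(log sig), evaluated at the z0 that maps to z1."""
    z0, log_sig = invert_step(z1)
    e = (z0 - mu) / np.exp(log_sigma)
    l0 = -np.sum(log_sigma + 0.5 * e * e + 0.5 * np.log(2 * np.pi), axis=1)
    return l0 - np.sum(log_sig, axis=1)


grid = np.linspace(-8.0, 8.0, 801)
step = grid[1] - grid[0]
g0, g1 = np.meshgrid(grid, grid, indexing="ij")
z1 = np.stack([g0.ravel(), g1.ravel()], axis=1)

total_mass = np.sum(np.exp(log_q_z1(z1))) * step ** 2
print(total_mass)  # should be close to 1 if l is a normalized log density
```

The grid bounds and resolution are arbitrary choices. They only need to cover essentially all of the probability mass and resolve its narrowest features.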
