Edited at

VAE+ResBlock(個人的)決定版【Tensorflow1.8】

More than 1 year has passed since last update.

VAEは今ではそのお手頃感が売りだと思うのだけど、

github漁っても実装が古かったり読みにくかったり損失関数周りが間違ってそうなものがあったり…だったので、先輩の実装をペタペタしつつ、個人的になうい感じにまとめてみました

VAE + ResBlock + tf.layers系を使います。

いいね来たらgitリポジトリ載せますたぶん。

MNISTに適用しました。(zを二次元にしたので見た目は微妙)

参考にさせていただいたAEの実装(とてもきれいです...:bow:)

https://github.com/smurakami/alpha-GAN-tenworflow


import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm

from keras.datasets import mnist

# Load MNIST, scale pixels to [0, 1], and reshape to NHWC with one channel
# so the arrays match the conv layers and placeholders below.
(x_train, y_train), (x_test, y_test) = mnist.load_data()
# np.float32 instead of the removed alias `np.float`: the alias was deprecated
# in NumPy 1.20 and removed in 1.24, and float32 also matches the float32
# placeholders, avoiding a float64 -> float32 copy on every feed.
x_train = x_train.astype(np.float32) / 255.
x_test = x_test.astype(np.float32) / 255.
x_train = np.reshape(x_train, (-1, 28, 28, 1))
x_test = np.reshape(x_test, (-1, 28, 28, 1))

# Dimensionality of the latent code z (2-D so the latent space can be plotted).
latent_dim = 2

# Graph inputs: `inputs` feeds the encoder, `targets` is the reconstruction target.
inputs = tf.placeholder(dtype=tf.float32, shape=[None, 28, 28, 1], name='inputs')
targets = tf.placeholder(dtype=tf.float32, shape=[None, 28, 28, 1], name='targets')

def res_block(x0, filters, kernel_size, strides=(1, 1), kernel_initializer=None):
    """Residual block: relu(BN(conv(relu(BN(conv(x0))))) + x0).

    NOTE(review): `strides` is applied to BOTH convolutions, so any non-unit
    stride would make the branch and the skip connection shape-incompatible
    in the add; every call site in this file uses the default (1, 1).
    """
    def _conv_bn(t):
        # conv -> batch norm; training=True means batch statistics are always
        # used (the moving averages are never consumed at inference).
        t = tf.layers.conv2d(t, filters, kernel_size=kernel_size, strides=strides,
                             padding="same", kernel_initializer=kernel_initializer)
        return tf.layers.batch_normalization(t, axis=3, epsilon=1e-5, momentum=0.1,
                                             training=True,
                                             gamma_initializer=kernel_initializer)

    branch = tf.nn.relu(_conv_bn(x0))
    branch = _conv_bn(branch)
    # Identity skip connection, then the final activation.
    return tf.nn.relu(branch + x0)

def encoder(img):
    """Encode an image batch into a latent sample and its Gaussian parameters.

    Returns (z, z_mean, z_log_var), where z is drawn with the
    reparameterization trick: z = mean + eps * exp(0.5 * log_var),
    eps ~ N(0, I).
    """
    with tf.variable_scope("encoder", reuse=None):
        channels = 32
        init = tf.random_normal_initializer(0, 0.02)

        # Stem: 3x3 conv + ReLU at the full 28x28 resolution.
        net = tf.layers.conv2d(img, channels, kernel_size=3, strides=(1, 1),
                               padding="same", kernel_initializer=init)
        net = tf.nn.relu(net)

        # Two res-block + average-pool stages: 28 -> 14 -> 7.
        for _ in range(2):
            net = res_block(net, channels, kernel_size=3, kernel_initializer=init)
            net = tf.layers.average_pooling2d(net, 2, 2, padding='same')

        # Final res block at 7x7, then flatten to (batch, 32*7*7).
        net = res_block(net, channels, kernel_size=3, kernel_initializer=init)
        net = tf.reshape(net, shape=(-1, channels * 7 * 7))

        # Reparameterization trick: sample noise per-example, then shift/scale
        # it by the predicted mean and (log-)variance.
        epsilon = tf.random_normal(tf.stack([tf.shape(net)[0], latent_dim]))
        z_mean = tf.layers.dense(net, latent_dim, kernel_initializer=init)
        z_log_var = tf.layers.dense(net, latent_dim, kernel_initializer=init)
        z = z_mean + tf.multiply(epsilon, tf.exp(0.5 * z_log_var))

    return z, z_mean, z_log_var

def build_decoder(z):
    """Decode a latent code z into a 28x28x1 image with values in (0, 1).

    Mirrors the encoder: dense -> 7x7 feature map, two res-block +
    nearest-neighbor-upsample stages (7 -> 14 -> 28), a final res block,
    and a 1x1 conv down to a single channel.
    """
    with tf.variable_scope("decoder", reuse=None):
        h = 32
        x = z
        initializer = tf.random_normal_initializer(0, 0.02)
        # -------
        x = tf.layers.dense(x, h * 7 * 7, kernel_initializer=initializer)
        x = tf.reshape(x, shape=(-1, 7, 7, h))
        x = tf.nn.relu(x)
        # -------
        x = res_block(x, h, kernel_size=3, kernel_initializer=initializer)
        x = tf.image.resize_nearest_neighbor(x, (x.shape[1] * 2, x.shape[2] * 2))
        # -------
        x = res_block(x, h, kernel_size=3, kernel_initializer=initializer)
        x = tf.image.resize_nearest_neighbor(x, (x.shape[1] * 2, x.shape[2] * 2))
        # -------
        x = res_block(x, h, kernel_size=3, kernel_initializer=initializer)
        x = tf.layers.conv2d(x, 1, kernel_size=1, padding="same", kernel_initializer=initializer)
        # BUG FIX: the original applied tanh here, whose (-1, 1) range does not
        # match the [0, 1] targets fed to the binary cross-entropy loss
        # (negative outputs only survive because the loss clips them, which
        # kills gradients).  Sigmoid yields valid Bernoulli means in (0, 1).
        img = tf.nn.sigmoid(x)

    return img

# Build the VAE graph: encode inputs to a latent sample, decode it back.
encoded, z_mean, z_log_var = encoder(inputs)

decoded = build_decoder(encoded)

# Reconstruction term: per-pixel binary cross-entropy summed over the 28x28
# spatial axes -> one scalar per example.  (The reshape is a no-op.)
reconst_loss = tf.reshape(tf.reduce_sum(tf.keras.backend.binary_crossentropy(targets, decoded), [1,2]), (-1,)) # the reshape is optional

# KL divergence KL(N(mean, var) || N(0, I)), summed over latent dims.
latent_loss = 0.5 * tf.reduce_sum(tf.exp(z_log_var) + tf.square(z_mean) - 1 -z_log_var, 1)
# Negative ELBO, averaged over the batch.
loss = tf.reduce_mean(reconst_loss + latent_loss)

optimize = tf.train.AdamOptimizer().minimize(loss)

batch_size = 100

init = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init)

saver = tf.train.Saver()
# NOTE(review): resumes from a previous checkpoint; this raises if
# model/model-19 does not exist — comment out for a fresh training run.
saver.restore(sess, "model/model-19")

# Starting epoch counter.  NOTE(review): kept at 0 even after restoring, so
# new checkpoints are numbered from 0 again — confirm this is intended.
epoch = 0


毎エポック生成画像を出力するときに使う関数

def save(path):
    """Decode a batch of z ~ N(0, I) and save a 10x10 grid of samples to `path`."""
    z_batch = np.random.normal(0, 1, (batch_size, latent_dim))
    # Feed the latent placeholder tensor directly to run only the decoder half.
    images = sess.run(decoded, feed_dict={encoded: z_batch}).reshape(-1, 28, 28)

    plt.figure(figsize=(8, 8))
    for idx, img in enumerate(images[:100]):
        plt.subplot(10, 10, idx + 1)
        plt.axis('off')
        plt.imshow(img, cmap='gray')

    plt.savefig(path)
    plt.close()


学習

max_epoch = 20

# Train until max_epoch, checkpointing and sampling images once per epoch.
while epoch < max_epoch:
    np.random.shuffle(x_train)
    for start in tqdm(range(0, len(x_train), batch_size)):
        batch = x_train[start:start + batch_size]
        # The input is its own reconstruction target (autoencoder objective).
        sess.run(optimize, feed_dict={inputs: batch, targets: batch})
    saver.save(sess, './model/model', global_step=epoch)
    # Report the loss on the last mini-batch of the epoch.
    current_loss = sess.run(loss, feed_dict={inputs: batch, targets: batch})
    print("epoch:", epoch+1, ", loss:", current_loss)
    save('./sample/mnist_vae-'+str(epoch+1)+'.png')
    epoch += 1


エンコード&デコード結果

# Reconstruction check: encode + decode 100 training images and show the grid.
batch = x_train[:batch_size]

reconstructions = sess.run(decoded, feed_dict={inputs: batch}).reshape(-1, 28, 28)

plt.figure(figsize=(8, 8))
for idx, img in enumerate(reconstructions[:100]):
    plt.subplot(10, 10, idx + 1)
    plt.axis('off')
    plt.imshow(img, cmap='gray')

plt.show()


ランダムなzをデコードした結果

# Unconditional generation: decode random z ~ N(0, I) and show the grid.
random_z = np.random.normal(0, 1, (batch_size, latent_dim))

generated = sess.run(decoded, feed_dict={encoded: random_z}).reshape(-1, 28, 28)

plt.figure(figsize=(8, 8))
for idx, img in enumerate(generated[:100]):
    plt.subplot(10, 10, idx + 1)
    plt.axis('off')
    plt.imshow(img, cmap='gray')

plt.show()

# Encode the first 1000 training images and save their 2-D latent codes.
sample_z = []

for start in tqdm(range(0, 1000, batch_size)):
    chunk = x_train[start:start + batch_size]
    sample_z.extend(sess.run(encoded, feed_dict={inputs: chunk}))

sample_z = np.array(sample_z)

np.save("sample_z.npy", sample_z)


潜在空間の可視化

# Left: all 1000 latent codes in gray.  Right: the same codes colored by digit.
plt.figure(figsize=(10, 5))

plt.subplot(1, 2, 1)
plt.title("data")

plt.xlim(-4, 4)
plt.ylim(-4, 4)
plt.scatter(sample_z[:, 0], sample_z[:, 1], marker="x", linewidth=1, color='gray')

plt.subplot(1, 2, 2)
plt.title("ground truth")

plt.xlim(-4, 4)
plt.ylim(-4, 4)
labels = y_train[:1000]
for digit in range(10):
    points = sample_z[labels == digit]
    plt.scatter(points[:, 0], points[:, 1], marker="x", linewidth=1, label=str(digit))

plt.legend()
plt.show()


ヒトリゴト

読んでくださってありがとうございます

\def\textlarge#1{%

{\rm\Large #1}}
\def\textsmall#1{%
{\rm\scriptsize #1}}

(精一杯小声で)

  $\textsmall{い い ね く だ さ い}$