Edited at

TFLearnでVAEGAN

More than 1 year has passed since last update.


VAEGANとは

Autoencoding beyond pixels using a learned similarity metric

VAEの後ろにGANのDiscriminatorをつけたもの

VAEの再構成誤差の計算に、Decoderで生成された画像そのもの(ピクセル)ではなく、Discriminatorの中間層から取り出した特徴マップを用いる

VAEがピクセル単位で誤差を計測する関係でぼやけた画像が出てくるのに対して、特徴マップ単位で誤差を計測することで大域特徴を再現しつつ精細な画像を生成できる……かもしれない

Encoderはオリジナル画像とデコード画像の特徴マップの差を誤差として学習(VAE)

Decoderはそれに加えて、デコード画像・ランダム生成画像のDiscriminatorによる識別結果も誤差として学習(VAE+GAN)

Discriminatorはオリジナル画像・デコード画像・ランダム生成画像の識別結果から学習(GAN)

VAEを通常のピクセル単位で、GANをオリジナル画像とランダム生成画像で事前学習するといいらしい(GANのGeneratorをVAEのDecoderに置き換えて、両方を独立に学習)

まだ実験中なので、後で色々変更するかも


コード

DCGANのときは頑張って元の実装に寄せたけど、今回はEncoder・Decoder・Discriminatorのそれぞれが別のサンプルで訓練される簡易版で済ませた

元論文や他人の実装は参考にしたけど、多少変更を加えた


  • PretrainingでVAEとGANが同時に訓練されていたのを、VAE→GANと順番に訓練

  • Mean SquareやKullback-Leibler Divergenceでは、Featureなどの次元は合計(Sum)、Sampleの次元は平均(Mean)を取っていたが、画像サイズや潜在変数の次元を変更すると両者の比率が変わって面倒な気がしたので、すべてMeanに変更(数学的には正しくないかもしれない)

  • Decoder LossがEncoder Loss + Discriminator Lossだったのを両者の平均に変更


vaegan.py

from __future__ import (

division,
print_function,
absolute_import
)
from six.moves import range

import tensorflow as tf
import tflearn

import os
import numpy as np
from skimage import io

# Output locations for TensorBoard logs, checkpoints and sample images.
PRE_VAE_TENSORBOARD_DIR = '/tmp/tflearn_logs/vae/'
PRE_DIS_TENSORBOARD_DIR = '/tmp/tflearn_logs/dis/'
PRE_VAE_CHECKPOINT_PATH = '/tmp/vaegan/pre-vae'
PRE_DIS_CHECKPOINT_PATH = '/tmp/vaegan/pre-dis'
CHECKPOINT_PATH = '/tmp/vaegan/model'

# Short aliases for the tflearn API used throughout this file.
DNN = tflearn.DNN
input_data = tflearn.input_data
fc = tflearn.fully_connected
reshape = tflearn.reshape
conv = tflearn.conv_2d
conv_t = tflearn.conv_2d_transpose
max_pool = tflearn.max_pool_2d
bn = tflearn.batch_normalization
merge = tflearn.merge
sigmoid = tflearn.sigmoid
softmax = tflearn.softmax
softplus = tflearn.softplus
relu = tflearn.relu
elu = tflearn.elu
crossentropy = tflearn.categorical_crossentropy
adam = tflearn.Adam
Trainer = tflearn.Trainer
TrainOp = tflearn.TrainOp

# Create every output directory that does not exist yet. Parents are listed
# before their children because os.mkdir does not create intermediate
# directories.
for _output_dir in ('/tmp/tflearn_logs',
                    PRE_VAE_TENSORBOARD_DIR,
                    PRE_DIS_TENSORBOARD_DIR,
                    '/tmp/vaegan/'):
    if not os.path.exists(_output_dir):
        os.mkdir(_output_dir)
del _output_dir

class VAEGAN(object):
    """VAE whose reconstruction loss is measured on discriminator features.

    ``train`` runs three stages, each in its own ``tf.Graph``:

    1. VAE pretraining with a pixel-wise mean-square reconstruction loss.
    2. Discriminator pretraining on original vs. randomly generated images;
       the decoder weights from stage 1 are loaded but not trained.
    3. Joint training of encoder, decoder and discriminator.

    Weights are carried between the graphs as numpy values in
    ``self.trained_values`` ({scope name: {variable name: value}}).
    """

    def __init__(self, img_shape, n_first_channel, n_layer, latent_dim,
                 kullback_leibler_ratio, reconstruction_weight_against_detail,
                 vae_learning_rate=0.001, vae_beta1=0.5,
                 discriminator_learning_rate=0.00001, discriminator_beta1=0.5):
        """Store hyper-parameters; no layers are built until ``train``.

        Args:
            img_shape: image shape as (height, width, channels).
            n_first_channel: channel count of the first conv stage; doubled
                in every following stage.
            n_layer: number of conv / transposed-conv stages; must be > 1.
            latent_dim: dimensionality of the latent vector z.
            kullback_leibler_ratio: weight of the KL-divergence term.
            reconstruction_weight_against_detail: weight multiplying the
                (reconstruction + KL) loss.
            vae_learning_rate, vae_beta1: Adam settings for the VAE
                pretrainer, the encoder and the decoder.
            discriminator_learning_rate, discriminator_beta1: Adam settings
                for the discriminator.
        """
        self.img_shape = list(img_shape)
        self.input_shape = [None] + self.img_shape
        self.img_size = img_shape[:2]
        self.n_first_channel = n_first_channel
        self.n_layer = n_layer
        self.kullback_leibler_ratio = kullback_leibler_ratio
        self.reconstruction_weight_against_detail = \
            reconstruction_weight_against_detail
        self.latent_dim = latent_dim
        self.vae_learning_rate = vae_learning_rate
        self.vae_beta1 = vae_beta1
        self.discriminator_learning_rate = discriminator_learning_rate
        self.discriminator_beta1 = discriminator_beta1

        assert self.n_layer > 1, 'n_layer must be more than 1'

        self.vae_pretrainer = None
        self.discriminator_pretrainer = None
        self.trainer = None
        self.decoder_graph = tf.Graph()
        # Standalone decoder DNN; built by _set_decoder during train().
        self.decoder = None
        self.trained_values = {}

    def _build_vae_pretrainer(self, encoder, decoder):
        """Build a Trainer for the plain VAE (pixel-wise loss) in the current
        default graph."""
        inputs = input_data(shape=self.input_shape, name='input_x')
        # Network
        mean, log_var = encoder(inputs)
        encoded = self._encode(mean, log_var)
        decoded = decoder(encoded)
        # Loss: per-sample pixel MSE + KL, averaged over the batch
        element_wise_loss = self._get_mean_square(decoded, inputs)
        kullback_leibler_divergence = \
            self._get_kullback_leibler_divergence(mean, log_var)
        pretrain_vae_loss = self.reconstruction_weight_against_detail * \
            tf.reduce_mean(element_wise_loss + kullback_leibler_divergence)
        # Trainer: no trainable_vars filter, so every trainable variable in
        # this graph (encoder and decoder) is updated.
        pretrain_vae_op = TrainOp(loss=pretrain_vae_loss,
                                  optimizer=self._get_optimizer('vae'),
                                  batch_size=128,
                                  name='VAE_pretrainer')

        return Trainer(pretrain_vae_op, tensorboard_dir=PRE_VAE_TENSORBOARD_DIR,
                       tensorboard_verbose=0,
                       checkpoint_path=PRE_VAE_CHECKPOINT_PATH,
                       max_checkpoints=1)

    def _build_discriminator_pretrainer(self, decoder, discriminator):
        """Build a Trainer that teaches the discriminator to separate original
        images from decoder outputs of random z.

        Only variables under ``discriminator.scope`` are trained.
        """
        inputs = input_data(shape=self.input_shape, name='input_x')
        is_true = input_data(shape=(None, 2), name='is_true')
        is_false = input_data(shape=(None, 2), name='is_false')
        # Network: one random z per input image, i.e. shape
        # [batch, latent_dim], so the real and generated halves line up.
        z_shape = tf.stack([tf.shape(inputs)[0], self.latent_dim])
        random_image = decoder(self._get_z(z_shape))
        prediction_origin, _ = discriminator(inputs)
        prediction_random, _ = discriminator(random_image, reuse=True)
        # Loss: originals should be classified is_true, generated is_false
        prediction_all = merge([prediction_origin, prediction_random], 'concat',
                               axis=0)
        y_all = merge([is_true, is_false], 'concat', axis=0)
        pretrain_discriminator_loss = crossentropy(prediction_all, y_all)
        # Trainer
        pretrain_discriminator_op = TrainOp(
            loss=pretrain_discriminator_loss,
            optimizer=self._get_optimizer('discriminator'), batch_size=128,
            trainable_vars=self._get_trainable_variables(discriminator.scope),
            name='Discriminator_pretrainer')

        return Trainer(pretrain_discriminator_op,
                       tensorboard_dir=PRE_DIS_TENSORBOARD_DIR,
                       tensorboard_verbose=0,
                       checkpoint_path=PRE_DIS_CHECKPOINT_PATH,
                       max_checkpoints=1)

    def _build_trainer(self, encoder, decoder, discriminator):
        """Build the joint VAE/GAN Trainer with one TrainOp each for the
        encoder, the decoder and the discriminator."""
        inputs = input_data(shape=self.input_shape, name='input_x')
        is_true = input_data(shape=(None, 2), name='is_true')
        is_false = input_data(shape=(None, 2), name='is_false')
        # Network
        mean, log_var = encoder(inputs)
        encoded = self._encode(mean, log_var)
        decoded = decoder(encoded)
        random_image = decoder(self._get_z(tf.shape(mean)), reuse=True)
        prediction_origin, feature_map_origin = discriminator(inputs)
        prediction_decoded, feature_map_decoded = \
            discriminator(decoded, reuse=True)
        prediction_random, _ = discriminator(random_image, reuse=True)
        # Encoder loss: reconstruction error on discriminator features + KL
        feature_wise_loss = \
            self._get_mean_square(feature_map_decoded, feature_map_origin)
        kullback_leibler_divergence = \
            self._get_kullback_leibler_divergence(mean, log_var)
        encoder_loss = self.reconstruction_weight_against_detail * \
            tf.reduce_mean(feature_wise_loss + kullback_leibler_divergence)
        # Decoder loss: mean of the encoder loss and the GAN generator loss
        # (decoded and random images should be classified as real)
        prediction_gan = merge([prediction_decoded, prediction_random],
                               'concat', axis=0)
        y_gan = merge([is_true, is_true], 'concat', axis=0)
        gan_generator_loss = crossentropy(prediction_gan, y_gan)
        decoder_loss = (encoder_loss + gan_generator_loss) * 0.5
        # Discriminator loss: mean of the real and fake classification losses
        prediction_fake = merge([prediction_decoded, prediction_random],
                                'concat', axis=0)
        y_fake = merge([is_false, is_false], 'concat', axis=0)
        real_loss = crossentropy(prediction_origin, is_true)
        fake_loss = crossentropy(prediction_fake, y_fake)
        discriminator_loss = (real_loss + fake_loss) * 0.5
        # Trainer: each op only updates the variables of its own network
        encoder_op = TrainOp(
            loss=encoder_loss,
            optimizer=self._get_optimizer('encoder'),
            batch_size=64,
            trainable_vars=self._get_trainable_variables(encoder.scope),
            name='Encoder')
        decoder_op = TrainOp(
            loss=decoder_loss,
            optimizer=self._get_optimizer('decoder'),
            batch_size=64,
            trainable_vars=self._get_trainable_variables(decoder.scope),
            name='Decoder')
        discriminator_op = TrainOp(
            loss=discriminator_loss,
            optimizer=self._get_optimizer('discriminator'),
            batch_size=64,
            trainable_vars=self._get_trainable_variables(discriminator.scope),
            name='Discriminator')
        return Trainer([encoder_op, decoder_op, discriminator_op],
                       checkpoint_path=CHECKPOINT_PATH, max_checkpoints=1)

    def _encode(self, mean, log_var):
        """Sample z = mean + exp(log_var / 2) * epsilon, epsilon ~ N(0, 1)
        (reparameterisation trick)."""
        epsilon = tf.random_normal(tf.shape(mean), name='Epsilon')

        return mean + tf.exp(0.5 * log_var) * epsilon

    def _get_z(self, shape):
        """Draw standard-normal noise of ``shape`` and reshape it to
        (-1, latent_dim)."""
        z = tf.random_normal(shape, name='RandomZ')

        return reshape(z, (-1, self.latent_dim))

    def _get_kullback_leibler_divergence(self, mean, log_var):
        """Per-sample KL(q(z|x) || N(0, I)), averaged (not summed) over the
        latent axis 1 and scaled by ``kullback_leibler_ratio``."""
        square_mean = tf.pow(mean, 2)
        variance = tf.exp(log_var)

        kullback_leibler_divergence = \
            tf.reduce_mean(1 + log_var - square_mean - variance,
                           reduction_indices=1)
        kullback_leibler_divergence = \
            - 0.5 * self.kullback_leibler_ratio * kullback_leibler_divergence

        return kullback_leibler_divergence

    def _get_mean_square(self, prediction, truth):
        """Per-sample mean squared difference over axes 1, 2 and 3 (expects
        4-D tensors such as image batches or feature maps)."""
        return tf.reduce_mean(tf.squared_difference(prediction, truth),
                              reduction_indices=(1, 2, 3))

    def _get_optimizer(self, type_str):
        """Return a built Adam optimizer: VAE settings for 'vae', 'encoder'
        and 'decoder', discriminator settings for anything else."""
        if type_str in ['vae', 'encoder', 'decoder']:
            learning_rate = self.vae_learning_rate
            beta1 = self.vae_beta1
        else:  # 'discriminator'
            learning_rate = self.discriminator_learning_rate
            beta1 = self.discriminator_beta1
        opt = adam(learning_rate=learning_rate, beta1=beta1)

        return opt.get_tensor()

    def _get_trainable_variables(self, scope):
        """Trainable variables of the default graph whose name contains
        '<scope>/'."""
        return [v for v in tflearn.get_all_trainable_variable()
                if scope + '/' in v.name]

    def _get_input_tensor_by_name(self, name):
        """First tensor in the INPUTS collection whose name matches
        ``name``."""
        return tf.get_collection(tf.GraphKeys.INPUTS, scope=name)[0]

    def _labeled_feed_dict(self, x, is_true, is_false):
        """Feed dict for the 'input_x', 'is_true' and 'is_false' placeholders
        of the current default graph."""
        return {self._get_input_tensor_by_name('input_x'): x,
                self._get_input_tensor_by_name('is_true'): is_true,
                self._get_input_tensor_by_name('is_false'): is_false}

    def train(self, x, n_sample=None, pretrain_vae_epoch=1,
              pretrain_discriminator_epoch=1, train_epoch=10):
        """Run VAE pretraining, discriminator pretraining and joint training.

        Args:
            x: training images, shape (n, height, width, channels).
            n_sample: number of real/fake label rows to generate; defaults to
                ``x.shape[0]``. The labels are fed alongside ``x``, so this
                should match the number of images.
            pretrain_vae_epoch: epochs of VAE pretraining.
            pretrain_discriminator_epoch: epochs of discriminator pretraining.
            train_epoch: epochs of joint training.
        """
        if n_sample is None:
            n_sample = x.shape[0]
        # One-hot labels: [0, 1] means "real", [1, 0] means "fake".
        is_true = np.tile([0., 1.], [n_sample, 1])
        is_false = np.tile([1., 0.], [n_sample, 1])

        encoder = Encoder(self.n_first_channel, self.n_layer, self.latent_dim)
        decoder = Decoder(self.img_shape, self.n_first_channel, self.n_layer)
        discriminator = Discriminator(self.n_first_channel, self.n_layer)

        # Stage 1: VAE pretraining (pixel-wise loss).
        with tf.Graph().as_default():
            self.vae_pretrainer = self._build_vae_pretrainer(encoder, decoder)
            trainer = self.vae_pretrainer

            feed_dict = {self._get_input_tensor_by_name('input_x'): x}
            trainer.fit(feed_dict, n_epoch=pretrain_vae_epoch,
                        snapshot_epoch=True, shuffle_all=True,
                        run_id='VAE_pretrain')
            self.trained_values[encoder.scope] = \
                self._get_trained_values(trainer, encoder.scope)
            self.trained_values[decoder.scope] = \
                self._get_trained_values(trainer, decoder.scope)

        # Stage 2: discriminator pretraining against the pretrained decoder.
        with tf.Graph().as_default():
            self.discriminator_pretrainer = \
                self._build_discriminator_pretrainer(decoder, discriminator)
            trainer = self.discriminator_pretrainer
            self._assign_values(trainer, decoder.scope)

            feed_dict = self._labeled_feed_dict(x, is_true, is_false)
            trainer.fit(feed_dict, n_epoch=pretrain_discriminator_epoch,
                        snapshot_epoch=True, shuffle_all=True,
                        run_id='Discriminator_pretrain')
            self.trained_values[discriminator.scope] = \
                self._get_trained_values(trainer, discriminator.scope)

        # Stage 3: joint VAE/GAN training.
        with tf.Graph().as_default():
            self.trainer = self._build_trainer(encoder, decoder, discriminator)
            trainer = self.trainer
            for scope in (encoder.scope, decoder.scope, discriminator.scope):
                self._assign_values(trainer, scope)
            self._set_decoder(decoder)

            feed_dict = self._labeled_feed_dict(x, is_true, is_false)
            # One feed dict per TrainOp (encoder, decoder, discriminator).
            trainer.fit([feed_dict] * 3, n_epoch=train_epoch,
                        snapshot_step=1000, snapshot_epoch=False,
                        shuffle_all=True, run_id='VAEGAN',
                        callbacks=[CustomCallback(self)])

    def _get_trained_values(self, trainer, scope):
        """Read the current values of ``scope``'s trainable variables from
        ``trainer.session``, keyed by variable name."""
        return {v.name: tflearn.variables.get_value(v, session=trainer.session)
                for v in self._get_trainable_variables(scope)}

    def _assign_values(self, trainer, scope):
        """Load the cached values for ``scope`` into the matching variables of
        the default graph, using ``trainer.session``."""
        for v in self._get_trainable_variables(scope):
            trainer.session.run(v.assign(self.trained_values[scope][v.name]))

    def _set_decoder(self, decoder):
        """Build a standalone decoder DNN (input: z of size latent_dim)."""
        # A fresh graph each time: rebuilding the decoder with reuse=False in
        # a graph that already holds its variables would fail on a second
        # train() call.
        self.decoder_graph = tf.Graph()
        with self.decoder_graph.as_default():
            inputs = input_data(shape=(None, self.latent_dim))
            net = decoder(inputs)
            self.decoder = DNN(net)

    def decode(self, z):
        """Decode latent vectors ``z`` (shape [n, latent_dim]) into images.

        Only usable after ``train`` has built the standalone decoder.
        """
        with self.decoder_graph.as_default():
            return self.decoder.predict(z)

class Encoder(object):
    """Convolutional encoder producing the parameters of q(z|x).

    Each of ``n_layer`` stages is a stride-2, 4x4 convolution followed by
    batch normalisation and ReLU. Channels start at ``n_first_channel`` and
    double per stage. Two fully connected heads then output the mean and the
    log-variance, each with ``latent_dim`` units.
    """

    def __init__(self, n_first_channel, n_layer, latent_dim):
        self.scope = 'Encoder'
        self.n_first_channel = n_first_channel
        self.n_layer = n_layer
        self.latent_dim = latent_dim

    def __call__(self, x, reuse=False):
        """Encode ``x`` into ``(mean, log_var)``.

        Args:
            x: input image batch tensor.
            reuse: reuse this scope's existing variables instead of
                creating new ones.
        """
        net = x
        for depth in range(self.n_layer):
            net = conv(net, self.n_first_channel * 2 ** depth, 4, strides=2,
                       reuse=reuse,
                       scope='{s}/Conv_{n}'.format(s=self.scope, n=depth))
            net = relu(bn(net, reuse=reuse,
                          scope='{s}/BN_{n}'.format(s=self.scope, n=depth)))

        mean = fc(net, self.latent_dim, reuse=reuse,
                  scope='{s}/Mean'.format(s=self.scope))
        log_var = fc(net, self.latent_dim, reuse=reuse,
                     scope='{s}/LogVariance'.format(s=self.scope))
        return mean, log_var

class Decoder(object):
    """Transposed-convolution decoder mapping latent vectors to images.

    A fully connected layer projects z onto a feature map of size
    (height // 2**n_layer, width // 2**n_layer) with
    ``n_first_channel * 2**(n_layer - 1)`` channels. Each of ``n_layer``
    stages then applies batch norm, ReLU and a stride-2, 4x4 transposed
    convolution that doubles the spatial size and halves the channels. The
    last stage outputs ``img_shape[2]`` channels instead, followed by a
    sigmoid.
    """

    def __init__(self, img_shape, n_first_channel, n_layer):
        self.scope = 'Decoder'
        self.n_layer = n_layer
        self.img_size = img_shape[:2]
        self.color_channel = img_shape[2]
        # Channels of the innermost feature map. This equals the encoder's
        # last conv stage when both use the same n_first_channel.
        self.n_first_channel = n_first_channel * 2 ** (n_layer - 1)

    def __call__(self, z, reuse=False):
        """Decode latent tensor ``z`` into an image batch in [0, 1].

        Args:
            z: latent tensor.
            reuse: reuse this scope's existing variables instead of
                creating new ones.
        """
        shrink = 2 ** self.n_layer
        height = self.img_size[0] // shrink
        width = self.img_size[1] // shrink
        channels = self.n_first_channel

        net = fc(z, height * width * channels, reuse=reuse,
                 scope='{s}/FC'.format(s=self.scope))
        net = reshape(net, [-1, height, width, channels])

        last_stage = self.n_layer - 1
        for stage in range(self.n_layer):
            height, width = 2 * height, 2 * width
            channels = (self.color_channel if stage == last_stage
                        else channels // 2)
            net = relu(bn(net, reuse=reuse,
                          scope='{s}/BN_{n}'.format(s=self.scope, n=stage)))
            net = conv_t(net, channels, 4, [height, width], strides=2,
                         reuse=reuse,
                         scope='{s}/ConvT_{n}'.format(s=self.scope, n=stage))

        return sigmoid(net)

class Discriminator(object):
    """Convolutional real/fake classifier that also exposes a feature map.

    Each of ``n_layer`` stages is a 4x4 convolution (no explicit stride),
    2x2 max pooling, batch norm and ELU. Channels start at
    ``n_first_channel`` and double per stage. The output of the last stage is
    returned as the feature map used for the VAE reconstruction loss, and a
    fully connected layer with softmax gives 2-class probabilities.
    """

    def __init__(self, n_first_channel, n_layer):
        self.n_first_channel = n_first_channel
        self.n_layer = n_layer
        self.scope = 'Discriminator'

    def __call__(self, x, reuse=False):
        """Classify ``x`` and return its feature map.

        Args:
            x: image batch tensor.
            reuse: reuse this scope's existing variables instead of
                creating new ones.

        Returns:
            (prediction, feature_map): the 2-way softmax output, and the
            activation after the last conv stage (``x`` itself when
            ``n_layer`` is 0).
        """
        net = x

        for i in range(self.n_layer):
            net = conv(net, self.n_first_channel * 2 ** i, 4, reuse=reuse,
                       scope='{s}/Conv_{n}'.format(s=self.scope, n=i))
            net = max_pool(net, 2)
            net = bn(net, reuse=reuse,
                     scope='{s}/BN_{n}'.format(s=self.scope, n=i))
            net = elu(net)
        # Taken after the loop so it is always bound. The original assigned
        # it inside the loop only on the last iteration, so n_layer == 0 hit
        # an UnboundLocalError.
        feature_reconstruction = net

        net = fc(net, 2, reuse=reuse, scope='{s}/FC'.format(s=self.scope))
        net = softmax(net)

        return net, feature_reconstruction

class CustomCallback(tflearn.callbacks.Callback):
    """tflearn callback that saves grids of decoded sample images.

    On every snapshot batch it decodes a fixed set of random latent vectors.
    At the end of training it decodes rows that interpolate linearly between
    two random latent vectors. File names are CHECKPOINT_PATH followed by
    'image-<step>.png' or 'image-final.png'.
    """

    def __init__(self, model, n_side=10):
        self.model = model
        self.n_side = n_side
        # Fixed latent samples so successive snapshot images are comparable.
        self.sample_z = np.random.normal(size=(n_side ** 2, model.latent_dim))

    def _save(self, name, z):
        """Decode ``z`` with the latest weights and save an n_side x n_side
        image grid to ``name``.

        Row ``r * n_side + c`` of ``z`` is drawn at grid row r, column c.
        """
        model = self.model
        side = self.n_side
        height, width, channels = model.img_shape[:3]

        # Refresh the cached weights from the live training session, then
        # copy them into the standalone decoder graph before decoding.
        model.trained_values = {
            scope: model._get_trained_values(model.trainer, scope)
            for scope in model.trained_values}
        with model.decoder_graph.as_default():
            for scope in model.trained_values:
                model._assign_values(model.decoder, scope)
            decoded = model.decode(z)

        tiles = np.asarray(decoded, dtype=np.float32)[:side * side]
        # (row, col, h, w, c) -> (row, h, col, w, c) -> one big image.
        grid = (tiles.reshape(side, side, height, width, channels)
                .transpose(0, 2, 1, 3, 4)
                .reshape(side * height, side * width, channels))
        grid = np.clip(grid, 0, 1)
        grid *= 255
        io.imsave(name, grid.astype(np.uint8))

    def on_batch_end(self, training_state, snapshot=False):
        """Save a sample grid whenever tflearn takes a snapshot."""
        if not snapshot:
            return
        file_name = '{path}image-{step}.png'.format(path=CHECKPOINT_PATH,
                                                    step=training_state.step)
        self._save(file_name, self.sample_z)

    def on_train_end(self, training_state):
        """Save a grid whose rows interpolate between random latent pairs."""
        latent_dim = self.model.latent_dim
        side = self.n_side

        sample_z = np.empty((side ** 2, latent_dim), dtype=np.float32)
        for row in range(side):
            start = np.random.normal(size=latent_dim)
            stop = np.random.normal(size=latent_dim)
            # One linspace per latent dimension; transposing puts the
            # interpolation steps along the rows of this block.
            steps = np.array([np.linspace(start[dim], stop[dim], num=side)
                              for dim in range(latent_dim)]).T
            sample_z[row * side:(row + 1) * side, :] = steps

        file_name = '{path}image-final.png'.format(path=CHECKPOINT_PATH)
        self._save(file_name, sample_z)

if __name__ == '__main__':
    # Guarded so that importing this module (e.g. to reuse VAEGAN) does not
    # download CIFAR-10 and start a 100-epoch training run.
    #
    # Train on a single CIFAR-10 class, label 1 (automobile in the standard
    # label order), using the train and test splits together.
    (X, Y), (testX, testY) = tflearn.datasets.cifar10.load_data()
    X = np.concatenate((X, testX), axis=0)
    Y = np.concatenate((Y, testY), axis=0)
    X = X[Y == 1]

    img_shape = X.shape[1:]

    vaegan = VAEGAN(img_shape=img_shape, n_first_channel=64, n_layer=4,
                    latent_dim=32, kullback_leibler_ratio=0.01,
                    reconstruction_weight_against_detail=50.0)
    vaegan.train(X, pretrain_vae_epoch=1, pretrain_discriminator_epoch=10,
                 train_epoch=100)



参考サイト

VAEGAN

fauxtograph