LoginSignup
11

More than 5 years have passed since last update.

TFLearnでVAEGAN

Posted at 2016-10-01

VAEGANとは

Autoencoding beyond pixels using a learned similarity metric

VAEの後ろにGANのDiscriminatorをつけたもの
VAEの再構成誤差を、Decoderで生成された画像のピクセル同士ではなく、Discriminatorの中間層から取り出した特徴マップ同士で計算する
VAEがピクセル単位で誤差を計測する関係でぼやけた画像が出てくるのに対して、特徴マップ単位で誤差を計測することで大域特徴を再現しつつ精細な画像を生成できる……かもしれない

Encoderはオリジナル画像とデコード画像の特徴マップの差を誤差として学習(VAE)
Decoderはそれに加えて、デコード画像・ランダム生成画像のDiscriminatorによる識別結果も誤差として学習(VAE+GAN)
Discriminatorはオリジナル画像・デコード画像・ランダム生成画像の識別結果から学習(GAN)

事前学習として、VAEは通常どおりピクセル単位の誤差で、GANはオリジナル画像とランダム生成画像の識別で、それぞれ独立に学習させておくとよいらしい(GANのGeneratorにはVAEのDecoderをそのまま使う)

まだ実験中なので、後で色々変更するかも

コード

DCGANのときは頑張って元の実装に寄せたけど、今回はEncoder・Decoder・Discriminatorのそれぞれが別のサンプルで訓練される簡易版で済ませた

元論文や他人の実装は参考にしたけど、多少変更を加えた

  • PretrainingでVAEとGANが同時に訓練されていたのを、VAE→GANと順番に訓練
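  • 元の実装ではMean SquareやKullback–Leibler Divergenceについて、特徴量(Feature)などの次元は合計(Sum)、サンプルの次元は平均(Mean)を取っていたが、画像サイズや潜在変数の次元数を変えると項どうしの比率が変わって調整が面倒に思えたので、すべてMeanに変更した。Sumの代わりにMeanを取ると各項がそれぞれの次元数で割られるため、再構成誤差とKL項の相対的な重みが変わり、厳密には元の定式化(変分下界)とは異なる目的関数になる。その分はkullback_leibler_ratioなどの係数で釣り合いを取る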
  • Decoder LossがEncoder Loss + GAN Loss(デコード画像・ランダム生成画像をDiscriminatorに本物と判定させるための誤差)の和だったのを、両者の平均に変更
vaegan.py
from __future__ import (
    division,
    print_function,
    absolute_import
)
from six.moves import range

import tensorflow as tf
import tflearn

import os
import numpy as np
from skimage import io

# Output locations: TensorBoard logs for the two pre-training stages, and
# checkpoint path prefixes (CustomCallback also writes its sample images
# next to CHECKPOINT_PATH).
PRE_VAE_TENSORBOARD_DIR = '/tmp/tflearn_logs/vae/'
PRE_DIS_TENSORBOARD_DIR = '/tmp/tflearn_logs/dis/'
PRE_VAE_CHECKPOINT_PATH = '/tmp/vaegan/pre-vae'
PRE_DIS_CHECKPOINT_PATH = '/tmp/vaegan/pre-dis'
CHECKPOINT_PATH = '/tmp/vaegan/model'

# Short aliases for the TFLearn layers, ops and helpers used below.
DNN = tflearn.DNN
input_data = tflearn.input_data
fc = tflearn.fully_connected
reshape = tflearn.reshape
conv = tflearn.conv_2d
conv_t = tflearn.conv_2d_transpose
max_pool = tflearn.max_pool_2d
bn = tflearn.batch_normalization
merge = tflearn.merge
sigmoid = tflearn.sigmoid
softmax = tflearn.softmax
softplus = tflearn.softplus
relu = tflearn.relu
elu = tflearn.elu
crossentropy = tflearn.categorical_crossentropy
adam = tflearn.Adam
Trainer = tflearn.Trainer
TrainOp = tflearn.TrainOp

# Create every output directory up front. The list is derived from the path
# constants above so it cannot drift out of sync with them; os.makedirs also
# creates missing parents (e.g. /tmp/tflearn_logs). The isdir() guard keeps
# this Python 2 compatible (no exist_ok there).
for _output_dir in (PRE_VAE_TENSORBOARD_DIR,
                    PRE_DIS_TENSORBOARD_DIR,
                    os.path.dirname(PRE_VAE_CHECKPOINT_PATH),
                    os.path.dirname(PRE_DIS_CHECKPOINT_PATH),
                    os.path.dirname(CHECKPOINT_PATH)):
    if not os.path.isdir(_output_dir):
        os.makedirs(_output_dir)
del _output_dir

class VAEGAN(object):
    """VAE/GAN model trained with TFLearn in three stages.

    Follows "Autoencoding beyond pixels using a learned similarity metric":
    the VAE reconstruction error is measured on a discriminator feature map
    instead of on pixels. ``train`` runs three stages, each in a fresh
    ``tf.Graph``:

    1. VAE pre-training (encoder + decoder) with a pixel-wise loss.
    2. Discriminator pre-training on real images vs. images decoded from
       random latent vectors; only discriminator variables are updated.
    3. Joint training with separate encoder, decoder and discriminator
       train ops.

    Variable values are carried from one graph to the next through
    ``self.trained_values`` (``{scope: {variable_name: value}}``).
    """

    # tflearn's batch_normalization keeps its running statistics in
    # non-trainable variables with these names. They must be transferred
    # together with the weights. Otherwise every new graph, including the
    # standalone decoder DNN built by _set_decoder (whose predict presumably
    # runs batch norm in inference mode), keeps their initial values.
    _BATCH_NORM_STATISTICS = ('moving_mean', 'moving_variance')

    def __init__(self, img_shape, n_first_channel, n_layer, latent_dim,
                 kullback_leibler_ratio, reconstruction_weight_against_detail,
                 vae_learning_rate=0.001, vae_beta1=0.5,
                 discriminator_learning_rate=0.00001, discriminator_beta1=0.5):
        """
        Args:
            img_shape: (height, width, channels) of the training images.
            n_first_channel: channel count of the first convolution; the
                encoder and discriminator double it at every layer.
            n_layer: number of convolutional layers per network (> 1).
            latent_dim: dimensionality of the latent vector z.
            kullback_leibler_ratio: weight of the KL divergence term.
            reconstruction_weight_against_detail: factor applied to the
                encoder loss (reconstruction + KL).
            vae_learning_rate, vae_beta1: Adam settings for the encoder and
                decoder.
            discriminator_learning_rate, discriminator_beta1: Adam settings
                for the discriminator.

        Raises:
            ValueError: if ``n_layer`` is not greater than 1.
        """
        # Validate before touching TensorFlow so bad arguments fail fast.
        # Use an exception rather than assert, which is stripped under -O.
        if n_layer <= 1:
            raise ValueError('n_layer must be more than 1')

        self.img_shape = list(img_shape)
        self.input_shape = [None] + self.img_shape
        self.img_size = img_shape[:2]
        self.n_first_channel = n_first_channel
        self.n_layer = n_layer
        self.kullback_leibler_ratio = kullback_leibler_ratio
        self.reconstruction_weight_against_detail = reconstruction_weight_against_detail
        self.latent_dim = latent_dim
        self.vae_learning_rate = vae_learning_rate
        self.vae_beta1 = vae_beta1
        self.discriminator_learning_rate = discriminator_learning_rate
        self.discriminator_beta1 = discriminator_beta1

        self.vae_pretrainer = None
        self.discriminator_pretrainer = None
        self.trainer = None
        # Standalone decoder (a DNN) and its graph; built by _set_decoder.
        self.decoder = None
        self.decoder_graph = tf.Graph()
        self.trained_values = {}

    def _build_vae_pretrainer(self, encoder, decoder):
        """Build the stage-1 trainer: a plain VAE with pixel-wise MSE loss.

        No ``trainable_vars`` are given, so the op updates every trainable
        variable in the graph (encoder and decoder).
        """
        inputs = input_data(shape=self.input_shape, name='input_x')
        # Build Network
        mean, log_var = encoder(inputs)
        encoded = self._encode(mean, log_var)
        decoded = decoder(encoded)
        # Loss
        element_wise_loss = self._get_mean_square(decoded, inputs)
        kullback_leibler_divergence = \
            self._get_kullback_leibler_divergence(mean, log_var)
        pretrain_vae_loss = self.reconstruction_weight_against_detail *\
            tf.reduce_mean(element_wise_loss + kullback_leibler_divergence)
        # Trainer
        pretrain_vae_op = TrainOp(loss=pretrain_vae_loss,
                                  optimizer=self._get_optimizer('vae'),
                                  batch_size=128,
                                  name='VAE_pretrainer')

        return Trainer(pretrain_vae_op, tensorboard_dir=PRE_VAE_TENSORBOARD_DIR,
                       tensorboard_verbose=0,
                       checkpoint_path=PRE_VAE_CHECKPOINT_PATH,
                       max_checkpoints=1)

    def _build_discriminator_pretrainer(self, decoder, discriminator):
        """Build the stage-2 trainer: discriminator on real vs. random images.

        Random images are decoded from standard-normal latent vectors; only
        the discriminator's variables are trained.
        """
        inputs = input_data(shape=self.input_shape, name='input_x')
        is_true = input_data(shape=(None, 2), name='is_true')
        is_false = input_data(shape=(None, 2), name='is_false')
        # Build Network
        # One latent vector per input image. _get_z reshapes a flat normal
        # sample to (-1, latent_dim), so a 1-D shape of
        # batch_size * latent_dim is enough. This avoids creating a throwaway
        # (never trained) fully connected layer just to read the batch size.
        z_shape = tf.shape(inputs)[:1] * self.latent_dim
        random_image = decoder(self._get_z(z_shape))
        prediction_origin, _ = discriminator(inputs)
        prediction_random, _ = discriminator(random_image, reuse=True)
        # Loss
        prediction_all = merge([prediction_origin, prediction_random], 'concat',
                               axis=0)
        y_all = merge([is_true, is_false], 'concat', axis=0)
        pretrain_discriminator_loss = crossentropy(prediction_all, y_all)
        # Trainer
        pretrain_discriminator_op = TrainOp(
            loss=pretrain_discriminator_loss,
            optimizer=self._get_optimizer('discriminator'), batch_size=128,
            trainable_vars=self._get_trainable_variables(discriminator.scope),
            name='Discriminator_pretrainer')

        return Trainer(pretrain_discriminator_op,
                       tensorboard_dir=PRE_DIS_TENSORBOARD_DIR,
                       tensorboard_verbose=0,
                       checkpoint_path=PRE_DIS_CHECKPOINT_PATH,
                       max_checkpoints=1)

    def _build_trainer(self, encoder, decoder, discriminator):
        """Build the stage-3 trainer with one train op per network."""
        inputs = input_data(shape=self.input_shape, name='input_x')
        is_true = input_data(shape=(None, 2), name='is_true')
        is_false = input_data(shape=(None, 2), name='is_false')
        # Build Network
        mean, log_var = encoder(inputs)
        encoded = self._encode(mean, log_var)
        decoded = decoder(encoded)
        random_image = decoder(self._get_z(tf.shape(mean)), reuse=True)
        # Discriminator outputs for real, reconstructed and random images.
        # The feature maps of the first two define the learned similarity.
        prediction_origin, feature_map_origin = discriminator(inputs)
        prediction_decoded, feature_map_decoded = discriminator(decoded, reuse=True)
        prediction_random, _ = discriminator(random_image, reuse=True)
        # Loss
        ## Encoder: feature-wise reconstruction error + KL divergence
        feature_wise_loss = \
            self._get_mean_square(feature_map_decoded, feature_map_origin)
        kullback_leibler_divergence = \
            self._get_kullback_leibler_divergence(mean, log_var)
        encoder_loss = self.reconstruction_weight_against_detail *\
            tf.reduce_mean(feature_wise_loss + kullback_leibler_divergence)

        ## Decoder: average of the encoder loss and the GAN generator loss
        ## (decoded and random images should be classified as real)
        prediction_generated = merge([prediction_decoded, prediction_random],
                                     'concat', axis=0)
        y_gan = merge([is_true, is_true], 'concat', axis=0)
        gan_generator_loss = crossentropy(prediction_generated, y_gan)
        decoder_loss = (encoder_loss + gan_generator_loss) * 0.5
        ## Discriminator: real images -> true, decoded/random images -> false
        y_fake = merge([is_false, is_false], 'concat', axis=0)
        real_loss = crossentropy(prediction_origin, is_true)
        fake_loss = crossentropy(prediction_generated, y_fake)
        discriminator_loss = (real_loss + fake_loss) * 0.5
        # Trainer
        encoder_op = TrainOp(
            loss=encoder_loss,
            optimizer=self._get_optimizer('encoder'),
            batch_size=64,
            trainable_vars=self._get_trainable_variables(encoder.scope),
            name='Encoder')
        decoder_op = TrainOp(
            loss=decoder_loss,
            optimizer=self._get_optimizer('decoder'),
            batch_size=64,
            trainable_vars=self._get_trainable_variables(decoder.scope),
            name='Decoder')
        discriminator_op = TrainOp(
            loss=discriminator_loss,
            optimizer=self._get_optimizer('discriminator'),
            batch_size=64,
            trainable_vars=self._get_trainable_variables(discriminator.scope),
            name='Discriminator')
        return Trainer([encoder_op, decoder_op, discriminator_op],
                       checkpoint_path=CHECKPOINT_PATH, max_checkpoints=1)

    def _encode(self, mean, log_var):
        """Reparameterization trick: mean + exp(log_var / 2) * N(0, 1)."""
        epsilon = tf.random_normal(tf.shape(mean), name='Epsilon')

        return mean + tf.exp(0.5 * log_var) * epsilon

    def _get_z(self, shape):
        """Sample standard-normal latent vectors, reshaped to (-1, latent_dim)."""
        z = tf.random_normal(shape, name='RandomZ')

        return reshape(z, (-1, self.latent_dim))

    def _get_kullback_leibler_divergence(self, mean, log_var):
        """Per-sample KL(q(z|x) || N(0, I)), averaged over axis 1 and scaled
        by ``kullback_leibler_ratio``."""
        square_mean = tf.pow(mean, 2)
        variance = tf.exp(log_var)

        kullback_leibler_divergence = \
            tf.reduce_mean(1 + log_var - square_mean - variance,
                           reduction_indices=1)
        kullback_leibler_divergence = \
            - 0.5 * self.kullback_leibler_ratio * kullback_leibler_divergence

        return kullback_leibler_divergence

    def _get_mean_square(self, prediction, truth):
        """Per-sample mean squared error over axes 1-3 (rank-4 inputs)."""
        return tf.reduce_mean(tf.squared_difference(prediction, truth),
                              reduction_indices=(1, 2, 3))

    def _get_optimizer(self, type_str):
        """Return an Adam optimizer tensor for 'vae'/'encoder'/'decoder' or,
        for any other value, the discriminator settings."""
        if type_str in ['vae', 'encoder', 'decoder']:
            learning_rate = self.vae_learning_rate
            beta1 = self.vae_beta1
        else:  # 'discriminator'
            learning_rate = self.discriminator_learning_rate
            beta1 = self.discriminator_beta1
        opt = adam(learning_rate=learning_rate, beta1=beta1)

        return opt.get_tensor()

    def _get_trainable_variables(self, scope):
        """Trainable variables of the default graph whose name contains
        ``scope + '/'``."""
        return [v for v in tflearn.get_all_trainable_variable()
                if scope + '/' in v.name]

    def _get_model_variables(self, scope):
        """Variables that define a network: trainable ones plus the
        batch-norm running statistics under ``scope`` (default graph)."""
        statistics = [
            v for v in tflearn.variables.get_all_variables()
            if scope + '/' in v.name
            and v.op.name.rsplit('/', 1)[-1] in self._BATCH_NORM_STATISTICS]
        return self._get_trainable_variables(scope) + statistics

    def _get_input_tensor_by_name(self, name):
        """Return the first tensor in the INPUTS collection under ``name``."""
        return tf.get_collection(tf.GraphKeys.INPUTS, scope=name)[0]

    def train(self, x, n_sample=None, pretrain_vae_epoch=1,
              pretrain_discriminator_epoch=1, train_epoch=10):
        """Run the three training stages on images ``x``.

        Args:
            x: image array fed to the ``input_x`` placeholder, i.e. shaped
                ``[n] + img_shape``.
            n_sample: number of rows in the real/fake label arrays; defaults
                to ``x.shape[0]``. NOTE(review): the labels are fed next to
                ``x`` in one feed dict, so this presumably has to equal
                ``len(x)``; confirm before passing anything else.
            pretrain_vae_epoch: epochs of stage 1 (VAE pre-training).
            pretrain_discriminator_epoch: epochs of stage 2.
            train_epoch: epochs of stage 3 (joint training).
        """
        if n_sample is None:
            n_sample = x.shape[0]
        # One-hot targets for the discriminator's 2-way softmax:
        # [0, 1] marks real images, [1, 0] generated ones.
        is_true = np.tile([0., 1.], [n_sample, 1])
        is_false = np.tile([1., 0.], [n_sample, 1])

        encoder = Encoder(self.n_first_channel, self.n_layer, self.latent_dim)
        decoder = Decoder(self.img_shape, self.n_first_channel, self.n_layer)
        discriminator = Discriminator(self.n_first_channel, self.n_layer)

        # Stage 1: VAE pre-training.
        with tf.Graph().as_default():
            self.vae_pretrainer = self._build_vae_pretrainer(encoder, decoder)
            trainer = self.vae_pretrainer

            input_tensor = self._get_input_tensor_by_name('input_x')
            feed_dict = {input_tensor: x}
            trainer.fit(feed_dict, n_epoch=pretrain_vae_epoch,
                        snapshot_epoch=True, shuffle_all=True,
                        run_id='VAE_pretrain')
            self.trained_values[encoder.scope] = \
                self._get_trained_values(trainer, encoder.scope)
            self.trained_values[decoder.scope] = \
                self._get_trained_values(trainer, decoder.scope)

        # Stage 2: discriminator pre-training with the pre-trained decoder.
        with tf.Graph().as_default():
            self.discriminator_pretrainer = \
                self._build_discriminator_pretrainer(decoder, discriminator)
            trainer = self.discriminator_pretrainer
            self._assign_values(trainer, decoder.scope)

            input_tensor = self._get_input_tensor_by_name('input_x')
            true_tensor = self._get_input_tensor_by_name('is_true')
            false_tensor = self._get_input_tensor_by_name('is_false')
            feed_dict = {input_tensor: x,
                         true_tensor: is_true,
                         false_tensor: is_false}
            trainer.fit(feed_dict, n_epoch=pretrain_discriminator_epoch,
                        snapshot_epoch=True, shuffle_all=True,
                        run_id='Discriminator_pretrain')
            self.trained_values[discriminator.scope] = \
                self._get_trained_values(trainer, discriminator.scope)

        # Stage 3: joint VAE/GAN training.
        with tf.Graph().as_default():
            self.trainer = self._build_trainer(encoder, decoder, discriminator)
            trainer = self.trainer
            self._assign_values(trainer, encoder.scope)
            self._assign_values(trainer, decoder.scope)
            self._assign_values(trainer, discriminator.scope)
            self._set_decoder(decoder)

            input_tensor = self._get_input_tensor_by_name('input_x')
            true_tensor = self._get_input_tensor_by_name('is_true')
            false_tensor = self._get_input_tensor_by_name('is_false')
            feed_dict = {input_tensor: x,
                         true_tensor: is_true,
                         false_tensor: is_false}
            # One feed dict per train op (encoder, decoder, discriminator).
            self.trainer.fit([feed_dict] * 3, n_epoch=train_epoch,
                             snapshot_step=1000, snapshot_epoch=False,
                             shuffle_all=True, run_id='VAEGAN',
                             callbacks=[CustomCallback(self)])

    def _get_trained_values(self, trainer, scope):
        """Read ``{variable_name: value}`` for ``scope``'s model variables
        from ``trainer.session``."""
        return {v.name: tflearn.variables.get_value(v, session=trainer.session)
                for v in self._get_model_variables(scope)}

    def _assign_values(self, trainer, scope):
        """Copy the stored values for ``scope`` into ``trainer.session``.

        ``trainer`` only needs a ``session`` attribute (a ``Trainer`` or a
        ``DNN``); variables are looked up in the current default graph.
        """
        for variable in self._get_model_variables(scope):
            value = self.trained_values[scope][variable.name]
            # NOTE: every call adds a new assign op to the graph.
            trainer.session.run(variable.assign(value))

    def _set_decoder(self, decoder):
        """Build a standalone decoder DNN (latent vector -> image).

        A fresh graph is used every time so that calling ``train`` again
        does not try to re-create the decoder's variables in a graph that
        already holds them.
        """
        self.decoder_graph = tf.Graph()
        with self.decoder_graph.as_default():
            inputs = input_data(shape=(None, self.latent_dim))
            net = decoder(inputs)
            self.decoder = DNN(net)

    def decode(self, z):
        """Decode latent vectors ``z`` with the standalone decoder.

        Only usable after ``train`` has built the decoder; the decoder's
        values are refreshed by ``CustomCallback`` before it saves images.
        """
        with self.decoder_graph.as_default():
            return self.decoder.predict(z)

class Encoder(object):
    """Convolutional encoder producing the parameters of q(z|x).

    Each of ``n_layer`` stages applies a 4x4 stride-2 convolution, batch
    normalization and ReLU; the channel count starts at ``n_first_channel``
    and doubles at every stage. Two fully connected heads then emit the
    latent mean and log-variance, each of size ``latent_dim``.
    """

    def __init__(self, n_first_channel, n_layer, latent_dim):
        self.n_first_channel = n_first_channel
        self.n_layer = n_layer
        self.latent_dim = latent_dim
        self.scope = 'Encoder'

    def __call__(self, x, reuse=False):
        """Encode images ``x``; pass ``reuse=True`` to share variables.

        Returns:
            (mean, log_var) tensors of the approximate posterior.
        """
        scope = self.scope
        hidden = x

        for depth in range(self.n_layer):
            width = self.n_first_channel * 2 ** depth
            hidden = conv(hidden, width, 4, strides=2, reuse=reuse,
                          scope='{s}/Conv_{n}'.format(s=scope, n=depth))
            hidden = relu(bn(hidden, reuse=reuse,
                             scope='{s}/BN_{n}'.format(s=scope, n=depth)))

        latent_mean = fc(hidden, self.latent_dim, reuse=reuse,
                         scope='{s}/Mean'.format(s=scope))
        latent_log_var = fc(hidden, self.latent_dim, reuse=reuse,
                            scope='{s}/LogVariance'.format(s=scope))
        return latent_mean, latent_log_var

class Decoder(object):
    """Transposed-convolution decoder mapping latent vectors to images.

    A fully connected layer projects z to a feature map of spatial size
    ``img_size // 2**n_layer`` with ``n_first_channel * 2**(n_layer - 1)``
    channels. Each of ``n_layer`` stages then applies batch normalization,
    ReLU and a 4x4 stride-2 transposed convolution that doubles the spatial
    size and halves the channels. The last stage emits ``img_shape[2]``
    channels, and a final sigmoid squashes the output.
    """

    def __init__(self, img_shape, n_first_channel, n_layer):
        """
        Args:
            img_shape: (height, width, channels) of the generated images.
            n_first_channel: channel count of the encoder's first layer.
            n_layer: number of upsampling stages.

        Raises:
            ValueError: if height or width is not divisible by
                ``2 ** n_layer``. The decoder could not reproduce that size,
                and the mismatch would otherwise only surface later as a
                shape error in the reconstruction loss.
        """
        self.img_size = img_shape[:2]
        self.color_channel = img_shape[2]
        # Channel count of the deepest feature map: n_first_channel doubled
        # (n_layer - 1) times, mirroring the encoder's last stage.
        self.n_first_channel = n_first_channel * 2 ** (n_layer - 1)
        self.n_layer = n_layer
        self.scope = 'Decoder'

        scale = 2 ** n_layer
        if self.img_size[0] % scale or self.img_size[1] % scale:
            raise ValueError(
                'image height and width ({0}) must be divisible by '
                '2 ** n_layer = {1}'.format(tuple(self.img_size), scale))

    def __call__(self, z, reuse=False):
        """Decode latent vectors ``z``; pass ``reuse=True`` to share
        variables with a previous call."""
        net = z

        feature_height = self.img_size[0] // 2 ** self.n_layer
        feature_width = self.img_size[1] // 2 ** self.n_layer
        feature_channel = self.n_first_channel

        n_units = feature_height * feature_width * feature_channel
        net = fc(net, n_units, reuse=reuse, scope='{s}/FC'.format(s=self.scope))
        shape = [-1, feature_height, feature_width, feature_channel]
        net = reshape(net, shape)

        for i in range(self.n_layer):
            feature_height *= 2
            feature_width *= 2
            # Halve the channels at every stage except the last, which
            # produces the image's color channels.
            if i < self.n_layer - 1:
                feature_channel //= 2
            else:
                feature_channel = self.color_channel

            net = bn(net, reuse=reuse,
                     scope='{s}/BN_{n}'.format(s=self.scope, n=i))
            net = relu(net)
            net = conv_t(net, feature_channel, 4,
                         [feature_height, feature_width], strides=2,
                         reuse=reuse,
                         scope='{s}/ConvT_{n}'.format(s=self.scope, n=i))

        net = sigmoid(net)

        return net

class Discriminator(object):
    """Convolutional 2-way classifier that also exposes a feature map.

    Each of ``n_layer`` stages applies a 4x4 convolution with
    ``n_first_channel * 2**i`` channels, max pooling (kernel size 2), batch
    normalization and ELU. A fully connected layer with softmax then gives
    a 2-class prediction.
    """

    def __init__(self, n_first_channel, n_layer):
        self.n_first_channel = n_first_channel
        self.n_layer = n_layer
        self.scope = 'Discriminator'

    def __call__(self, x, reuse=False):
        """Classify images ``x``; pass ``reuse=True`` to share variables.

        Returns:
            (prediction, feature_map): the 2-way softmax output and the
            activation of the last convolutional stage (``x`` itself when
            ``n_layer`` is 0).
        """
        net = x

        for i in range(self.n_layer):
            net = conv(net, self.n_first_channel * 2 ** i, 4, reuse=reuse,
                       scope='{s}/Conv_{n}'.format(s=self.scope, n=i))
            net = max_pool(net, 2)
            net = bn(net, reuse=reuse,
                     scope='{s}/BN_{n}'.format(s=self.scope, n=i))
            net = elu(net)

        # The last stage's activation is the "learned similarity" feature
        # map. Taking it after the loop (rather than on the final iteration
        # inside it) also avoids an UnboundLocalError when n_layer == 0.
        feature_reconstruction = net

        net = fc(net, 2, reuse=reuse, scope='{s}/FC'.format(s=self.scope))
        net = softmax(net)

        return net, feature_reconstruction

class CustomCallback(tflearn.callbacks.Callback):
    """TFLearn callback that writes grids of decoded images.

    Snapshot steps save decodings of a fixed set of random latent vectors;
    the end of training saves a grid in which each row interpolates
    linearly between two random latent vectors.
    """

    def __init__(self, model, n_side=10):
        """
        Args:
            model: the VAEGAN whose ``trainer`` and standalone decoder are
                used.
            n_side: side length of the image grid (``n_side ** 2`` tiles).
        """
        super(CustomCallback, self).__init__()
        self.model = model
        self.n_side = n_side
        # Fixed latent vectors so snapshots are comparable across steps.
        self.sample_z = np.random.normal(size=(n_side ** 2, model.latent_dim))

    def _save(self, name, z):
        """Decode ``z`` (``n_side ** 2`` latent vectors) and save the results
        as one ``n_side`` x ``n_side`` image grid to ``name``."""
        model = self.model
        n_side = self.n_side
        img_height = model.img_shape[0]
        img_width = model.img_shape[1]
        img_channel = model.img_shape[2]

        # Refresh the stored variable values from the training session and
        # copy them into the standalone decoder before decoding.
        model.trained_values = {
            scope: model._get_trained_values(model.trainer, scope)
            for scope in model.trained_values}
        with model.decoder_graph.as_default():
            for scope in model.trained_values:
                model._assign_values(model.decoder, scope)
        decoded = np.asarray(model.decode(z), dtype=np.float32)

        # Tile row-major: decoded image y * n_side + x lands at grid row y,
        # column x.
        tiles = decoded.reshape(n_side, n_side,
                                img_height, img_width, img_channel)
        image = tiles.transpose(0, 2, 1, 3, 4).reshape(
            n_side * img_height, n_side * img_width, img_channel)
        image = np.clip(image, 0, 1) * 255
        io.imsave(name, image.astype(np.uint8))

    def on_batch_end(self, training_state, snapshot=False):
        """Save a sample grid whenever TFLearn takes a snapshot."""
        if snapshot:
            step = training_state.step

            # NOTE: CHECKPOINT_PATH is a file prefix, not a directory, so the
            # file ends up named '<prefix>image-<step>.png'.
            file_name = '{path}image-{step}.png'.format(path=CHECKPOINT_PATH,
                                                        step=step)
            self._save(file_name, self.sample_z)

    def on_train_end(self, training_state):
        """Save a grid of latent-space interpolations, one pair per row."""
        latent_dim = self.model.latent_dim
        n_side = self.n_side

        sample_z = np.empty((n_side ** 2, latent_dim), dtype=np.float32)
        for row in range(n_side):
            start = np.random.normal(size=latent_dim)
            stop = np.random.normal(size=latent_dim)
            # Column i holds a linspace over latent dimension i, so each
            # row of z_rows is one step of the interpolation.
            z_rows = np.array([np.linspace(start[i], stop[i], num=n_side)
                               for i in range(latent_dim)]).T
            sample_z[row * n_side:(row + 1) * n_side, :] = z_rows

        file_name = '{path}image-final.png'.format(path=CHECKPOINT_PATH)
        self._save(file_name, sample_z)

if __name__ == '__main__':
    # Only run the (long) CIFAR-10 training when executed as a script, so
    # the classes above can be imported without side effects.
    # Pool the train and test splits and keep the images whose label == 1
    # (presumably 'automobile' in the standard CIFAR-10 label order).
    (X, Y), (testX, testY) = tflearn.datasets.cifar10.load_data()
    X = np.concatenate((X, testX), axis=0)
    Y = np.concatenate((Y, testY), axis=0)
    X = X[Y == 1]

    img_shape = X.shape[1:]

    vaegan = VAEGAN(img_shape=img_shape, n_first_channel=64, n_layer=4,
                    latent_dim=32, kullback_leibler_ratio=0.01,
                    reconstruction_weight_against_detail=50.0)
    vaegan.train(X, pretrain_vae_epoch=1, pretrain_discriminator_epoch=10,
                 train_epoch=100)

参考サイト

VAEGAN
fauxtograph

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
11