LoginSignup
7
6

More than 3 years have passed since last update.

動くConvolutional VAEコード

Last updated at Posted at 2020-10-11

はじめに

TensorFlow 2.0 で Keras が TensorFlow 本体に統合されました（`tf.keras`）。

参考記事 Tensorflow 2.0 with Keras

その結果、これまで Keras で書かれていた畳み込み変分オートエンコーダー（Convolutional Variational AutoEncoder）のコードが動かなくなるケースが出てきました。そこで、いくつかの最新情報を集めて、とりあえず動くコードを作成したので公開しておきます。

環境

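tensorflow==2.1（注: `Model.train_step` をオーバーライドして `fit()` から呼び出させる仕組みは TensorFlow 2.2 で導入されたため、下記コードの実行には実際には tensorflow>=2.2 が必要です）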

コード

VAE_202010_tf21.py
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras.datasets import mnist
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Dense, Lambda, Conv2D, Flatten, Conv2DTranspose, Reshape
from tensorflow.keras import backend as K
from tensorflow.keras import losses
from tensorflow.keras.optimizers import Adam

#01. Datasets
# Load MNIST and pool the train and test images together; the labels are
# discarded because the VAE is trained without supervision.
(x_train, _), (x_test, _) = mnist.load_data()
mnist_digits = np.concatenate([x_train, x_test], axis=0)
# Append a trailing channel axis -> (N, 28, 28, 1) and scale pixels to [0, 1].
mnist_digits = mnist_digits[..., np.newaxis].astype("float32") / 255
print('mnist_train', mnist_digits.shape)
print('x_train', x_train.shape)


# 0-1 normalization.
# NOTE(review): these rescaled arrays are not used for training below
# (training uses mnist_digits, and x_train is reloaded before plotting).
x_train, x_test = x_train / 255.0, x_test / 255.0


#2 Setting Autoencoder
# Training hyperparameters.
epochs, batch_size = 100, 256
# Number of latent variables (dimensionality of z); 2 allows a 2-D scatter plot.
n_z = 2


#3 Function for sampling the latent variables (reparameterization trick)
def func_z_sample(args):
    """Sample ``z = z_mean + eps * exp(z_log_var / 2)`` with ``eps ~ N(0, 1)``.

    Args:
        args: pair ``[z_mean, z_log_var]`` of tensors (expected to have the
            same shape).

    Returns:
        A tensor combining the mean with noise drawn in the shape of
        ``z_log_var`` and scaled by the standard deviation.
    """
    mu, log_var = args
    # Noise is drawn first, then scaled by the standard deviation exp(log_var / 2).
    eps = K.random_normal(shape=K.shape(log_var), mean=0, stddev=1)
    sigma = K.exp(log_var / 2)
    return mu + eps * sigma

# VAE Network
# Building the Encoder: (28, 28, 1) image -> [z_mean, z_log_var, z].
encoder_inputs = Input(shape=(28,28,1))
x = Conv2D(32,3,activation='relu',strides=2,padding='same')(encoder_inputs)  # -> 14x14x32
x = Conv2D(64,3,activation='relu',strides=2, padding='same')(x)  # -> 7x7x64
x = Flatten()(x)
x = Dense(16, activation='relu')(x)
z_mean = Dense(n_z, name='z_mean')(x)
z_log_var = Dense(n_z, name='z_log_var')(x)
# output_shape must be a shape tuple excluding the batch axis; `(n_z)` is just an int.
z = Lambda(func_z_sample, output_shape=(n_z,))([z_mean, z_log_var])
encoder = Model(encoder_inputs, [z_mean, z_log_var, z], name='encoder')
encoder.summary()

# Building the Decoder: latent vector (n_z,) -> (28, 28, 1) image with sigmoid outputs.
latent_inputs = Input(shape=(n_z,))
x = Dense(7*7*64, activation='relu')(latent_inputs)
x = Reshape((7,7,64))(x)
x = Conv2DTranspose(64,3, activation='relu', strides=2, padding='same')(x)  # -> 14x14x64
x = Conv2DTranspose(32,3, activation='relu', strides=2, padding='same')(x)  # -> 28x28x32
decoder_outputs = Conv2DTranspose(1,3, activation='sigmoid', padding='same')(x)  # -> 28x28x1
decoder = Model(latent_inputs, decoder_outputs, name='decoder')
decoder.summary()

class VAE(Model):
    """Variational autoencoder trained with a custom ``train_step``.

    Wraps an encoder returning ``(z_mean, z_log_var, z)`` and a decoder that maps
    ``z`` back to an image, and minimizes the negative ELBO:
    reconstruction loss (binary cross-entropy summed over pixels) plus the KL
    divergence between q(z|x) and a standard normal prior.
    """

    def __init__(self, encoder, decoder, **kwargs):
        super().__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder

    def train_step(self, data):
        """Run one optimization step on a batch and return the loss values.

        Args:
            data: a batch of images, or a tuple whose first element is the batch
                (``fit`` may pass ``(x,)`` or ``(x, y)``).

        Returns:
            dict with keys ``"loss"``, ``"reconstruction loss"`` and ``"kl_loss"``.
        """
        if isinstance(data, tuple):
            data = data[0]
        with tf.GradientTape() as tape:
            # Use the wrapped sub-models (not module-level globals) so the class
            # trains whatever encoder/decoder it was constructed with.
            z_mean, z_log_var, z = self.encoder(data)
            reconstruction = self.decoder(z)
            # binary_crossentropy averages over the last (channel) axis, giving
            # per-pixel losses; sum them per image, then average over the batch.
            reconstruction_loss = tf.reduce_mean(
                tf.reduce_sum(
                    losses.binary_crossentropy(data, reconstruction), axis=(1, 2)
                )
            )
            # KL(q(z|x) || N(0, I)) = -0.5 * sum_j (1 + log_var - mu^2 - exp(log_var)).
            # Sum over latent dimensions per sample (averaging here would shrink the
            # KL term by a factor of n_z), then average over the batch.
            kl_loss = 1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var)
            kl_loss = -0.5 * tf.reduce_mean(tf.reduce_sum(kl_loss, axis=1))
            total_loss = reconstruction_loss + kl_loss

        grads = tape.gradient(total_loss, self.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
        return {
            "loss": total_loss,
            "reconstruction loss": reconstruction_loss,
            "kl_loss": kl_loss,
        }

# Build, compile and train the VAE on the pooled MNIST images (train + test).
# No loss is passed to compile(): the loss is computed inside VAE.train_step.
vae = VAE(encoder=encoder, decoder=decoder)
vae.compile(optimizer=Adam())
vae.fit(x=mnist_digits, epochs=epochs, batch_size=batch_size)

#Display how the latent space clusters different digit classes
def plot_label_clusters(encoder, decoder, data, labels):
    """Scatter-plot the encoder's latent means, coloured by label.

    Args:
        encoder: model whose ``predict`` returns ``(z_mean, z_log_var, z)``.
        decoder: not used by this function.
        data: images to encode.
        labels: one value per image, used as the point colour.
    """
    z_mean, _z_log_var, _z = encoder.predict(data)
    # Only the first two latent dimensions are plotted.
    first, second = z_mean[:, 0], z_mean[:, 1]
    plt.figure(figsize=(12, 10))
    plt.scatter(first, second, c=labels)
    plt.colorbar()
    plt.xlabel("z[0]")
    plt.ylabel("z[1]")
    plt.show()


# Reload the MNIST training images together with their labels to colour the plot.
(x_train, y_train), _ = mnist.load_data()
# Add a channel axis -> (N, 28, 28, 1) and scale pixels to [0, 1].
x_train = x_train[..., np.newaxis].astype("float32") / 255

plot_label_clusters(encoder, decoder, data=x_train, labels=y_train)

結果

2次元の潜在空間での0~9の分布
121010596_10220811797558378_9181159504384895825_o.jpg

参考資料

  1. Variational AutoEncoder by F. Chollet
  2. AIパーフェクトマスター講座
  3. Generative Deep Learning
7
6
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
7
6