Variatinal AutoEoncoderを使う機会がありまして、tensorflowで書いてるコードをいろいろなサイトで見ていたのですが、だいたいのサイトでは、Sequential APIかFunctional APIで書いている気がします。なぜだ。。。
そこでSubclassing APIで書いてみようとしました。
Variatinal AutoEoncoderを理解したい方は下のサイトへGO
Variational Autoencoder徹底解説
この記事では、ほぼコードだけ
tensorflow公式のVAEコード
上記のコードでは、Sequentialを使ってるのかな。これを一部変えてSubclassingにします。
Reparameterization Trick、エンコーダー、デコーダーに分けて書きます。
Reparameterization Trick
class Reparameterize(tf.keras.layers.Layer):
def call(self,inputs):
mean, logvar = inputs
batch = tf.shape(mean)[0]
dim = tf.shape(mean)[1]
eps = tf.keras.backend.random_normal(shape=(batch, dim),mean=0., stddev=1)
return mean + tf.exp(0.5*logvar) * eps
エンコーダー
class Encode(tf.keras.layers.Layer):
def __init__(self , latent_dim):
super(Encode, self).__init__()
self.conv_layer_1 = tf.keras.layers.Conv2D(filters=32, kernel_size=3, strides=(2, 2),padding='same', activation='relu')
self.conv_layer_2 = tf.keras.layers.Conv2D(filters=64, kernel_size=3, strides=(2, 2),padding='same', activation='relu')
self.flatten_layer = tf.keras.layers.Flatten()
self.dense_1 = tf.keras.layers.Dense(32,activation='relu')
self.dense_mean = tf.keras.layers.Dense(latent_dim)
self.dense_logvar = tf.keras.layers.Dense(latent_dim)
self.sampling = Reparameterize()
def call(self,inputs):
#pdb.set_trace()
cx = self.conv_layer_1(inputs)
cx = self.conv_layer_2(cx)
x = self.flatten_layer(cx)
x = self.dense_1(x)
mean = self.dense_mean(x)
logvar = self.dense_logvar(x)
z = self.sampling((mean, logvar))
return mean,logvar,z
デコーダー
class Decode(tf.keras.layers.Layer):
def __init__(self):
super(Decode,self).__init__()
#元の次元に戻している
self.dense_z_input = tf.keras.layers.Dense(units = 7*7*64, activation='relu')
self.reshape_layer = tf.keras.layers.Reshape((7,7,64))
self.conv_transpose_layer_1 = tf.keras.layers.Conv2DTranspose(filters=32, kernel_size=(3,3), strides=(2,2),padding='same',activation = 'relu')
self.conv_transpose_layer_2 = tf.keras.layers.Conv2DTranspose(filters=16, kernel_size=(3,3), strides=(2,2),padding='same',activation = 'relu')
self.conv_transpose_layer_3 = tf.keras.layers.Conv2DTranspose(filters=1, kernel_size=(3,3), padding='same',activation = 'sigmoid')
def call(self,z):
#pdb.set_trace()
x_output = self.dense_z_input(z)
x_output = self.reshape_layer(x_output)
x_output = self.conv_transpose_layer_1(x_output)
x_output = self.conv_transpose_layer_2(x_output)
x_output = self.conv_transpose_layer_3(x_output)
return x_output
それぞれ書きました。これを最後に一つにします。
統合
class VAE(tf.keras.Model):
def __init__(self,latent_dim):
super(VAE,self).__init__()
self.encoder = Encode(latent_dim)
self.decoder = Decode()
これでモデルは一応動きます。あとは損失関数とか加えたら動きます。今度、そこらへんしっかり書こうかな。
Subclassingでどう書くかはいろいろなサイトで見てやりました。__init__関数に使うレイヤーを書いて、call関数でそれを呼び出す。これだけなのかな?そもそもこれはSubclassingなのかどうか自分ではわかりませんが、とりあえずこんな感じなのではないでしょうか。もし違ってたらタイトル詐欺になる。。。