0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

Variatinal AutoEoncoderをSubclassing APIで書いてみた

Posted at

Variatinal AutoEoncoderを使う機会がありまして、tensorflowで書いてるコードをいろいろなサイトで見ていたのですが、だいたいのサイトでは、Sequential APIかFunctional APIで書いている気がします。なぜだ。。。
そこでSubclassing APIで書いてみようとしました。

Variatinal AutoEoncoderを理解したい方は下のサイトへGO
Variational Autoencoder徹底解説

この記事では、ほぼコードだけ

tensorflow公式のVAEコード
上記のコードでは、Sequentialを使ってるのかな。これを一部変えてSubclassingにします。

Reparameterization Trick、エンコーダー、デコーダーに分けて書きます。

Reparameterization Trick

class Reparameterize(tf.keras.layers.Layer):

  def call(self,inputs):
    mean, logvar = inputs
    batch = tf.shape(mean)[0]
    dim = tf.shape(mean)[1]
    eps = tf.keras.backend.random_normal(shape=(batch, dim),mean=0., stddev=1)
    return mean + tf.exp(0.5*logvar) * eps	

エンコーダー

class Encode(tf.keras.layers.Layer):

  def __init__(self , latent_dim):
    super(Encode, self).__init__()
  
    self.conv_layer_1 = tf.keras.layers.Conv2D(filters=32, kernel_size=3, strides=(2, 2),padding='same', activation='relu')
    self.conv_layer_2 = tf.keras.layers.Conv2D(filters=64, kernel_size=3, strides=(2, 2),padding='same', activation='relu')


    self.flatten_layer = tf.keras.layers.Flatten()
    self.dense_1 = tf.keras.layers.Dense(32,activation='relu')
    self.dense_mean = tf.keras.layers.Dense(latent_dim)
    self.dense_logvar = tf.keras.layers.Dense(latent_dim)
    self.sampling = Reparameterize()

  def call(self,inputs):
    #pdb.set_trace()
    cx = self.conv_layer_1(inputs)
    cx = self.conv_layer_2(cx)
    x = self.flatten_layer(cx)
    x = self.dense_1(x)
    mean = self.dense_mean(x)
    logvar = self.dense_logvar(x)
    z = self.sampling((mean, logvar))
    return mean,logvar,z

デコーダー

class Decode(tf.keras.layers.Layer):

  def __init__(self):
    super(Decode,self).__init__()
    #元の次元に戻している
    self.dense_z_input = tf.keras.layers.Dense(units = 7*7*64, activation='relu')
    self.reshape_layer = tf.keras.layers.Reshape((7,7,64))
    self.conv_transpose_layer_1 = tf.keras.layers.Conv2DTranspose(filters=32, kernel_size=(3,3),  strides=(2,2),padding='same',activation = 'relu')
    self.conv_transpose_layer_2 = tf.keras.layers.Conv2DTranspose(filters=16, kernel_size=(3,3),  strides=(2,2),padding='same',activation = 'relu')
    self.conv_transpose_layer_3 = tf.keras.layers.Conv2DTranspose(filters=1, kernel_size=(3,3),  padding='same',activation = 'sigmoid')

  
  def call(self,z):
    #pdb.set_trace()
    x_output = self.dense_z_input(z)
    x_output = self.reshape_layer(x_output)
    x_output = self.conv_transpose_layer_1(x_output)
    x_output = self.conv_transpose_layer_2(x_output)
    x_output = self.conv_transpose_layer_3(x_output)
    return x_output

それぞれ書きました。これを最後に一つにします。

統合

class VAE(tf.keras.Model):

  def __init__(self,latent_dim):
    super(VAE,self).__init__()
    self.encoder = Encode(latent_dim)
    self.decoder = Decode()

これでモデルは一応動きます。あとは損失関数とか加えたら動きます。今度、そこらへんしっかり書こうかな。
Subclassingでどう書くかはいろいろなサイトで見てやりました。__init__関数に使うレイヤーを書いて、call関数でそれを呼び出す。これだけなのかな?そもそもこれはSubclassingなのかどうか自分ではわかりませんが、とりあえずこんな感じなのではないでしょうか。もし違ってたらタイトル詐欺になる。。。

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?