0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

GradientTapeを使ってfitするとエラーが出た。そして解決した。

Last updated at Posted at 2022-04-02

GradientTapeを使ってfitするとエラー出た

Variatioanal AutoEncoderを実行する際に普段はself.add_loss()を使い学習させていました。
以下のようなコードです。

 #モデルを構築
class VAE(tf.keras.Model):

  def __init__(self,latent_dim):
    super(VAE,self).__init__()
    self.encoder = Encode(latent_dim)
    self.decoder = Decode()
  
  #lossを計算
  def VAE_loss(self,x):
    mean,logvar,z = self.encoder(x)
    x_sigmoid = self.decoder(z)

    #reshape
    shape = tf.shape(x)
    x = tf.reshape(x,[shape[0],-1])
    x_sigmoid = tf.reshape(x_sigmoid,[shape[0],-1])

    #復元誤差
    reconstruction_loss = tf.keras.losses.BinaryCrossentropy(reduction=tf.keras.losses.Reduction.NONE)(x,x_sigmoid)
    #MNISTのshape
    image_shape=28*28
    reconstruction_loss *= image_shape
   
    #KL-divergence(正規分布との差)
    kl_divergence =-0.5* tf.reduce_sum(1 + logvar - tf.square(mean) - tf.exp(logvar),axis=1)
    vae_loss = tf.reduce_mean(reconstruction_loss + kl_divergence)
    return vae_loss

  def call(self,inputs):
    loss = self.VAE_loss(inputs)
    self.add_loss(loss,inputs=inputs)
    return x

encoder、decoderのコードは以下から見てください。
Variatinal AutoEoncoderをSubclassing APIで書いてみた

call関数内でVAE_lossという独自の損失関数を定義し、そのlossをself.add_lossに入れます。これでmodel.compile、model.fitをすることでトレーニング開始してくれます。
しかし、このself.add_lossは、研究室で先輩方や同期に聞いても知らないという人が多くて、GradientTapeを使って学習させてると。ならば自分もGradientTapeを使おうと思ったわけですよ。
上のコードを書き換えます。

 #モデルを構築
class VAE(tf.keras.Model):

  def __init__(self,latent_dim):
    super(VAE,self).__init__()
    self.encoder = Encode(latent_dim)
    self.decoder = Decode()
  
  def call(self,x):
    mean,logvar,z = self.encoder(x)
    y_pred = self.decoder(z)
    return mean,logvar,z,y_pred
    
  #lossを計算
  def VAE_loss(self,x,x_sigmoid,mean,logvar,z):
    #reshape
    shape = tf.shape(x)
    x = tf.reshape(x,[shape[0],-1])
    x_sigmoid = tf.reshape(x_sigmoid,[shape[0],-1])
    #復元誤差
    reconstruction_loss = tf.keras.losses.BinaryCrossentropy(reduction=tf.keras.losses.Reduction.NONE)(x,x_sigmoid)
    image_shape=28*28
    reconstruction_loss *= image_shape

    #KL-divergence(正規分布との差)
    kl_divergence =-0.5* tf.reduce_sum(1 + logvar - tf.square(mean) - tf.exp(logvar),axis=1)
    vae_loss = tf.reduce_mean(reconstruction_loss + kl_divergence)
    return vae_loss

  def train_step(self,x):
    with tf.GradientTape() as tape:
        mean,logvar,z,y_pred= self(x, training=True)
        loss = self.VAE_loss(x,y_pred,mean,logvar,z)
    gradients = tape.gradient(loss, self.trainable_variables)
    self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))
    return {"loss":loss}

さて、これでOKと思い、実行すると

TypeError: Dimension value must be integer or None or have an __index__ method, got TensorShape([None, 28, 28, 1])

え?どゆこと?特におかしくなるようにいじったわけではないが、なぜかエラーが出る。Google様でこのエラー文を検索しても解決策が出てこない...

解決しました

さて問題が解決しました。なにがダメだったのかをみると、入力データの型でした。

 #train_stepに入る前の型
type(train_images)
<class 'numpy.ndarray'>
 #train_stepに入ったあとの型
type(x)
<class 'tuple'>

型がなんかtupleに変わってますね。なんでtupleになってんだ?

if isinstance(x, tuple):
  x = x[0]
 #TensorShape([None, 28, 28, 1])

これでtupleからTensorShapeに変わりました。
このコードをwith tf.GradientTape() as tape:の上に書くと、動きました。これに気づくまでかなり時間かかりました。ずっとほかの部分を見てました。同じようなエラーで悩んでる方は参考にしてください。

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?