0
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

TensorFlowメモ: 複数のoptimizerをtf.functionで使うときに出る"ValueError: tf.function-decorated function tried to create variables on non-first call."

Last updated at Posted at 2021-04-01
ValueError: tf.function-decorated function tried to create variables on non-first call.

@tf.functionを使ってtrain_stepを高速化しようとしていたところ,
optimizer.apply_gradients() を使うところで上記エラーが出る場合がありました.

をみて理解した解決策のメモです.

環境

python: 3.8.0
tensorflow: 2.4.0

コード

2つのモデルとオプティマイザを作ってそれぞれを@tf.functionでデコレートされたtrain_step関数を使って学習させようとしています.2つめのモデルを学習しようとするところで上記エラーが出ます.

import tensorflow as tf


def compute_loss(y,t):
  return tf.reduce_mean(tf.square(y-t))

@tf.function
def train_step(model, optimizer, x, t):
  with tf.GradientTape() as tape:
    y = model(x)
    l = compute_loss(y,t)
    
  grad =tape.gradient(l, model.trainable_variables)
  optimizer.apply_gradients(zip(grad, model.trainable_variables))


model_1 = tf.keras.Sequential([tf.keras.layers.Dense(1)])
model_1.build(input_shape=[None,2])
optimizer_1 = tf.keras.optimizers.Adam()


model_2 = tf.keras.Sequential([tf.keras.layers.Dense(1)])
model_2.build(input_shape=[None,2])
optimizer_2 = tf.keras.optimizers.Adam()

x = tf.random.uniform(shape=(3,2))
t = tf.random.uniform(shape=(3,2))

train_step(model_1, optimizer_1, x,t )# ここではErrorはでない.

# ここでError
train_step(model_2, optimizer_2, x,t )

原因

@tf.functionは事前に計算グラフを作っておき,グラフモードで計算するためのデコレータ(公式ガイド )

  • @tf.function でデコレートされた関数は1度目の呼び出しの時,グラフを作成する(トレーシング).入力データのサイズが異なる場合など,必要に応じて再トレーシングが行われる.
  • 2回目以降の呼び出しのとき,内部で新たに変数を生成させようとするとエラーになる.
  • tf.kerasのoptimizerは初めて呼び出されるときに,内部変数を作る.(Modelのbuildと同じようなことを行う)
  • よって新しいoptimizerを既に使ったことがあるデコレートされた関数に入れて使おうとすると,optimizerが内部変数を生成しようとしてエラーとなる.

また,model_2 のbuildをしていない場合もエラーがでます.その場合はmodel2.buildを事前に行うなど,事前にmodelの内部変数を生成しておけばエラーはでません.しかし一方,

  • optimizerにはmodel.buildのような仕組みが備わっていないので,あらかじめ内部変数を作っておく簡単なやり方がない.

ので,下記のような回避策をとっておく必要があるようです.

回避策

それぞれのモデル,オプティマイザの組ごとに,別の関数を作っておくというのがよいようです.

1 ラッパーを使う

def wrapper():
  @tf.function
  def train_step(model, optimizer, x, t):
    with tf.GradientTape() as tape:
      y = model(x)
      l = compute_loss(y,t)

    grad =tape.gradient(l, model.trainable_variables)
    optimizer.apply_gradients(zip(grad, model.trainable_variables))
  return train_step


train_step_1 = wrapper()
train_step_2 = wrapper()

train_step_1(model_1, optimizer_1, x,t )
train_step_2(model_2, optimizer_2, x,t )

2 @デコレータを利用せず,モデルごとにtf.function()でデコレートされた関数を作る.

def train_step(model, optimizer, x, t):
  with tf.GradientTape() as tape:
    y = model(x)
    l = compute_loss(y,t)
    
  grad =tape.gradient(l, model.trainable_variables)
  optimizer.apply_gradients(zip(grad, model.trainable_variables))


train_step_1 = tf.function(train_step)
train_step_2 = tf.function(train_step)

train_step_1(model_1, optimizer_1, x,t )
train_step_2(model_2, optimizer_2, x,t )
0
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?