custom_graident を使用する場合の通常の書き方
@tf.custom_gradient
def gradient_reversal(x):
y = x
def grad(dy):
return - dy
return y, grad
# model 内で使用する場合
class MyModel(tf.keras.Model):
def __init__(self):
super(MyModel, self).__init__()
def call(self, x):
return gradient_reversal(x)
custom_gradient 内でスコープ外の変数(self等)を使用したい場合
class MyModel2(tf.keras.Model):
def __init__(self):
super(MyModel2, self).__init__()
self.alpha = self.add_weight(name="alpha", initializer=tf.keras.initializers.Ones())
@tf.custom_gradient
def forward(self, x):
y = self.alpha * x
def backward(w, variables=None):
with tf.GradientTape() as tape:
tape.watch(w)
z = - self.alpha * w
grads = tape.gradient(z, [w])
return z, grads
return y, backward
def call(self, x):
return self.forward(x)
-
ドキュメント 内での引数が
dy
になっているので計算済みのものが渡ってくるかと思いきや、Backpropagation 時の実行関数として指定できる (上記のbackward
メソッド) -
bakward
メソッド外のスコープの変数を使用する場合は、variables=None
を受け取るようにしないと以下のようなエラーが発生する (ドキュメントのArgs内でも説明されている)
TypeError: If using @custom_gradient with a function that uses variables, then grad_fn must accept a keyword argument 'variables'.
検証用コード
import tensorflow as tf
optimizer = tf.keras.optimizers.Adam(learning_rate=0.1)
class MyModel(tf.keras.Model):
def __init__(self):
super(MyModel, self).__init__()
self.alpha = self.add_weight(name="alpha", initializer=tf.keras.initializers.Ones())
@tf.custom_gradient
def forward(self, x):
y = self.alpha * x
tf.print("forward")
tf.print(" y: ", y)
def backward(w, variables=None):
z = self.alpha * w
tf.print("backward")
tf.print(" z: ", z)
tf.print(" variables: ", variables)
return z, variables
return y, backward
def call(self, x):
return self.forward(x)
class MyModel2(tf.keras.Model):
def __init__(self):
super(MyModel2, self).__init__()
self.alpha = self.add_weight(name="alpha", initializer=tf.keras.initializers.Ones())
@tf.custom_gradient
def forward(self, x):
y = self.alpha * x
tf.print("forward")
tf.print(" y: ", y)
def backward(w, variables=None):
with tf.GradientTape() as tape:
tape.watch(w)
z = - self.alpha * w
grads = tape.gradient(z, [w])
tf.print("backward")
tf.print(" z: ", z)
tf.print(" variables: ", variables)
tf.print(" alpha: ", self.alpha)
tf.print(" grads: ", grads)
return z, grads
return y, backward
def call(self, x):
return self.forward(x)
for model in [MyModel(), MyModel2()]:
print()
print()
print()
print(model.name)
for i in range(10):
with tf.GradientTape() as tape:
x = tf.Variable(1.0, tf.float32)
y = model(x)
grads = tape.gradient(y, model.trainable_variables)
optimizer.apply_gradients(zip(grads, model.trainable_variables))
tf.print("step")
tf.print(" y:", y)
tf.print(" grads:", grads)
print()