Using self in backpropagation (tf.custom_gradient) (TensorFlow)

Posted at 2020-05-30
The usual way to write a custom_gradient
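import tensorflow as tf


@tf.custom_gradient
def gradient_reversal(x):
  # Forward pass: identity.
  y = x
  def grad(dy):
    # Backward pass: flip the sign of the upstream gradient dy.
    return -dy
  return y, grad

# When used inside a model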
class MyModel(tf.keras.Model):
    def __init__(self):
        super(MyModel, self).__init__()

    def call(self, x):
        return gradient_reversal(x)
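To see what the gradient reversal above does during backpropagation, here is a minimal sketch that runs it under a `tf.GradientTape` in eager mode. The input values and the factor `2.0` are assumptions chosen only for illustration; `tf.identity` is used for the forward pass instead of `y = x`, which should behave the same way.

```python
import tensorflow as tf


@tf.custom_gradient
def gradient_reversal(x):
    # Forward pass: identity.
    y = tf.identity(x)

    def grad(dy):
        # Backward pass: flip the sign of the upstream gradient.
        return -dy

    return y, grad


x = tf.constant([1.0, 2.0, 3.0])
with tf.GradientTape() as tape:
    tape.watch(x)  # constants are not watched automatically
    loss = tf.reduce_sum(gradient_reversal(x) * 2.0)

# Without the reversal every entry would be 2.0; with it the sign is flipped.
dx = tape.gradient(loss, x)
print(dx.numpy())  # [-2. -2. -2.]
```

The forward output is unchanged, so only the direction of the gradient flowing back through this op differs from an ordinary identity.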
When you want to use out-of-scope variables (such as self) inside custom_gradient
class MyModel2(tf.keras.Model):
    def __init__(self):
        super(MyModel2, self).__init__()
        self.alpha = self.add_weight(name="alpha", initializer=tf.keras.initializers.Ones())

    @tf.custom_gradient
    def forward(self, x):
        y = self.alpha * x

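        def backward(w, variables=None):
            # w is the upstream gradient; variables lists the variables used
            # in forward that are not inputs (here, self.alpha).
            with tf.GradientTape() as tape:
                tape.watch(w)
                z = - self.alpha * w

            # The second return value needs one entry per item in variables.
            grads = tape.gradient(z, [w])
            return z, grads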

        return y, backward

    def call(self, x):
        return self.forward(x)
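  • The argument is named dy in the documentation, so I expected an already-computed value to be passed in, but what you specify is in fact the function that is executed during backpropagation (the backward method above).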

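  • If variables from outside the scope of the backward method are used, backward must accept variables=None; otherwise the following error occurs (this is also explained in the Args section of the documentation).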

TypeError: If using @custom_gradient with a function that uses variables, then grad_fn must accept a keyword argument 'variables'.
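The two points above can be sketched without a model class. This is a minimal example under stated assumptions, not the article's own code: `alpha`, `scale`, and `scale_without_variables_kwarg` are hypothetical names, and the variable is captured from the enclosing scope rather than through `self`. Per the `tf.custom_gradient` documentation, when the forward function uses variables that are not inputs, the gradient function takes `variables` as a keyword argument and returns a pair: the gradients for the inputs, and a list with one gradient per variable.

```python
import tensorflow as tf

# Hypothetical variable that lives outside the decorated function.
alpha = tf.Variable(2.0)


@tf.custom_gradient
def scale(x):
    y = alpha * x

    def grad(dy, variables=None):
        # dy is the upstream gradient, delivered when backpropagation runs.
        grad_x = dy * alpha   # derivative of y with respect to x
        grad_vars = [dy * x]  # derivative of y with respect to alpha
        return grad_x, grad_vars

    return y, grad


x = tf.constant(3.0)
with tf.GradientTape() as tape:
    tape.watch(x)
    y = scale(x)

dx, dalpha = tape.gradient(y, [x, alpha])
print(float(dx), float(dalpha))  # 2.0 3.0


# Same forward pass, but the gradient function omits the variables keyword.
@tf.custom_gradient
def scale_without_variables_kwarg(x):
    y = alpha * x

    def grad(dy):
        return dy * alpha

    return y, grad


try:
    scale_without_variables_kwarg(x)
    raised = False
except TypeError:
    # The exact message wording may differ between TensorFlow versions.
    raised = True
print(raised)
```

In eager mode the `TypeError` appears as soon as the decorated function is called, before any gradient is requested.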

Verification code

import tensorflow as tf


optimizer = tf.keras.optimizers.Adam(learning_rate=0.1)


class MyModel(tf.keras.Model):
    def __init__(self):
        super(MyModel, self).__init__()
        self.alpha = self.add_weight(name="alpha", initializer=tf.keras.initializers.Ones())

    @tf.custom_gradient
    def forward(self, x):
        y = self.alpha * x
        tf.print("forward")
        tf.print("  y: ", y)

        def backward(w, variables=None):
            z = self.alpha * w
            tf.print("backward")
            tf.print("  z: ", z)
            tf.print("  variables: ", variables)
            return z, variables

        return y, backward

    def call(self, x):
        return self.forward(x)


class MyModel2(tf.keras.Model):
    def __init__(self):
        super(MyModel2, self).__init__()
        self.alpha = self.add_weight(name="alpha", initializer=tf.keras.initializers.Ones())

    @tf.custom_gradient
    def forward(self, x):
        y = self.alpha * x
        tf.print("forward")
        tf.print("  y: ", y)

        def backward(w, variables=None):
            with tf.GradientTape() as tape:
                tape.watch(w)
                z = - self.alpha * w

            grads = tape.gradient(z, [w])

            tf.print("backward")
            tf.print("  z: ", z)
            tf.print("  variables: ", variables)
            tf.print("  alpha: ", self.alpha)
            tf.print("  grads: ", grads)
            return z, grads

        return y, backward

    def call(self, x):
        return self.forward(x)


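for model in [MyModel(), MyModel2()]:
    # Use a fresh optimizer per model: some newer Keras versions reject
    # variables that the optimizer was not originally built with.
    optimizer = tf.keras.optimizers.Adam(learning_rate=0.1)
    print("\n\n")
    print(model.name)
    for i in range(10):
        with tf.GradientTape() as tape:
            # dtype must be passed by keyword; the second positional
            # argument of tf.Variable is trainable.
            x = tf.Variable(1.0, dtype=tf.float32)
            y = model(x)

        grads = tape.gradient(y, model.trainable_variables)
        optimizer.apply_gradients(zip(grads, model.trainable_variables))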
        tf.print("step")
        tf.print("  y:", y)
        tf.print("  grads:", grads)
        print()
