9
7

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

[TensorFlow] 本当は怖いGradientTape

Posted at

はじめに

TensorFlow 2.xで自動微分をするときに使う tf.GradientTape ですが、エラーにならないのになぜか学習がうまくいかないという事態に遭遇したので、事の次第をメモ(そんな大げさなことでもありません)。

検証環境

  • Ubuntu 18.04
  • Python 3.6.9
  • TensorFlow 2.1.0 (CPU)

事例集

GradientTape がうまく動作する例、しない例を見ていきましょう。

OKな例

import tensorflow as tf

x = tf.Variable(1.0)
opt = tf.keras.optimizers.SGD(lr=0.1)

@tf.function
def step():
    with tf.GradientTape() as tape:
        square_x = x ** 2
        L = square_x
    grad = tape.gradient(L, x)
    tf.print(grad)
    opt.apply_gradients([(grad, x)])
    tf.print(x)

step()
# 2
# 0.8

$L(x) = x^2$ を損失関数として $x_0 = 1, x_{t+1} = x_t - 0.1 L'(x_t)$ で $x_t$ を更新することを想定しています。
$L'(x) = 2x$ なので $L'(1) = 2$ となり、$x_1 = 1.0 - 0.1 \times 2 = 0.8$ となります。
step() を繰り返し呼び出すと x ($x_t$) は0に収束します。

なお tf.print() についてはこちらでも記事を書いています。
[TensorFlow 2] グラフモードでTensorの中身を確認する方法 - Qiita

NGな例(エラーで気づける)

import tensorflow as tf

x = tf.Variable(1.0)
opt = tf.keras.optimizers.SGD(lr=0.1)

@tf.function
def step():
    square_x = x ** 2
    with tf.GradientTape() as tape:
        L = square_x
    grad = tape.gradient(L, x)
    tf.print(grad)
    opt.apply_gradients([(grad, x)])
    tf.print(x)

step()
# ValueError: No gradients provided for any variable: ['Variable:0'].

先ほどと同じく $L(x) = x^2$ を損失関数として $x_0 = 1, x_{t+1} = x_t - 0.1 L'(x_t)$ で $x$ を更新するように定義したつもりでしたが、with tf.GradientTape() の中で x を使った計算がされていないので、x で微分できないと怒られます。
でもこれはエラーで気づけるのでまだいいです。

NGな例(気づかずハマった)

エラーも出ないのになぜか結果がおかしくて悩む例。

import tensorflow as tf

x = tf.Variable(1.0)
y = tf.Variable(1.0)
opt = tf.keras.optimizers.SGD(lr=0.1)

@tf.function
def step():
    square_x = x ** 2
    with tf.GradientTape() as tape:
        square_x = square_x + 0 * x # Dummy
        square_y = y ** 2
        L = square_x + square_y
    grad = tape.gradient(L, [x, y])
    tf.print(grad)
    opt.apply_gradients(zip(grad, [x, y]))
    tf.print("L=", L, "x=", x, "y=", y)

step()
# [0, 2] ??
# L= 2 x= 1 y= 0.8

実際はもっと複雑なコードでしたが、問題の箇所だけを抜き出して簡略化しています。

$L(x, y) = x^2 + y^2$ の値を0に近づけることを想定していて、$\frac{\partial L}{\partial x} = 2x, \frac{\partial L}{\partial y} = 2y$ となると思っていたら、y は0に近づいていくのになぜか x の値が1から微動だにしません。

種明かしをすると、with tf.GradientTape() の中に xy を書いたときのみ微分の計算対象になるので、実は $\frac{\partial L}{\partial x} = 0$ になってしまっているよ、という話なのでした。
その証拠に、step() を何回呼んでも tf.print(grad) の実行結果の第1成分($\frac{\partial L}{\partial x}$ を表します)は常に0です。本来は1回目の呼び出しに対する出力が [2, 2] にならないといけないはず。
当の本人は square_x を使っているので、てっきり x で微分ができていると思い込んでいるのですね…。

square_x = square_x + 0 * x # Dummy

の行を取り除くと、GradientTapex の情報が入っていないという趣旨のWARNINGが出ますが、**y では微分できているのでエラーで止められることはありません。**まして、このダミーコードのように x を使った計算が GradientTape に1つでも入っていると、抜けている計算があってもWARNINGすら出ずに通ってしまいます。

つまり**「なぜかうまく収束しない」と思ってあれこれデバッグしてようやく問題に気づくということに…。(実話)**

対処法?

今回の場合、square_x の計算が手違いで GradientTape に入っていなかったために、 $\frac{\partial L}{\partial x} = 0$ になってしまっていたことがバグの原因でした。
そういうことにならないようにするための対処として、考えられるのは

  • 学習をKeras APIに任せる(自分で自動微分を書かない)
  • TF 1.x方式で学習する。tf.compat.v1 経由でAPIを呼び出す(これはあまりやりたくない)
  • GradientTape のブロックで囲む範囲を広めにとる(むやみに何でも記録するのは得策でないので、少しずつ範囲を狭くしていっておかしくならないかを見極める)
  • GradientTape の中身を何らかの手段でデバッグする

あたりかと思います。
ただ GradientTape の具体的な処理はCのライブラリを呼び出している(pywrap_tensorflow.TFE_Py_*)ようなので、Pythonから手軽にデバッグというわけにはいかないかもしれません…。

9
7
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
9
7

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?