はじめに
TensorFlow 2.xで自動微分をするときに使う tf.GradientTape
ですが、エラーにならないのになぜか学習がうまくいかないという事態に遭遇したので、事の次第をメモ(そんな大げさなことでもありません)。
検証環境
- Ubuntu 18.04
- Python 3.6.9
- TensorFlow 2.1.0 (CPU)
事例集
GradientTape
がうまく動作する例、しない例を見ていきましょう。
OKな例
import tensorflow as tf
x = tf.Variable(1.0)
opt = tf.keras.optimizers.SGD(lr=0.1)
@tf.function
def step():
with tf.GradientTape() as tape:
square_x = x ** 2
L = square_x
grad = tape.gradient(L, x)
tf.print(grad)
opt.apply_gradients([(grad, x)])
tf.print(x)
step()
# 2
# 0.8
$L(x) = x^2$ を損失関数として $x_0 = 1, x_{t+1} = x_t - 0.1 L'(x_t)$ で $x_t$ を更新することを想定しています。
$L'(x) = 2x$ なので $L'(1) = 2$ となり、$x_1 = 1.0 - 0.1 \times 2 = 0.8$ となります。
step()
を繰り返し呼び出すと x
($x_t$) は0に収束します。
なお tf.print()
についてはこちらでも記事を書いています。
[TensorFlow 2] グラフモードでTensorの中身を確認する方法 - Qiita
NGな例(エラーで気づける)
import tensorflow as tf
x = tf.Variable(1.0)
opt = tf.keras.optimizers.SGD(lr=0.1)
@tf.function
def step():
square_x = x ** 2
with tf.GradientTape() as tape:
L = square_x
grad = tape.gradient(L, x)
tf.print(grad)
opt.apply_gradients([(grad, x)])
tf.print(x)
step()
# ValueError: No gradients provided for any variable: ['Variable:0'].
先ほどと同じく $L(x) = x^2$ を損失関数として $x_0 = 1, x_{t+1} = x_t - 0.1 L'(x_t)$ で $x$ を更新するように定義したつもりでしたが、with tf.GradientTape()
の中で x
を使った計算がされていないので、x
で微分できないと怒られます。
でもこれはエラーで気づけるのでまだいいです。
NGな例(気づかずハマった)
エラーも出ないのになぜか結果がおかしくて悩む例。
import tensorflow as tf
x = tf.Variable(1.0)
y = tf.Variable(1.0)
opt = tf.keras.optimizers.SGD(lr=0.1)
@tf.function
def step():
square_x = x ** 2
with tf.GradientTape() as tape:
square_x = square_x + 0 * x # Dummy
square_y = y ** 2
L = square_x + square_y
grad = tape.gradient(L, [x, y])
tf.print(grad)
opt.apply_gradients(zip(grad, [x, y]))
tf.print("L=", L, "x=", x, "y=", y)
step()
# [0, 2] ??
# L= 2 x= 1 y= 0.8
実際はもっと複雑なコードでしたが、問題の箇所だけを抜き出して簡略化しています。
$L(x, y) = x^2 + y^2$ の値を0に近づけることを想定していて、$\frac{\partial L}{\partial x} = 2x, \frac{\partial L}{\partial y} = 2y$ となると思っていたら、y
は0に近づいていくのになぜか x
の値が1から微動だにしません。
種明かしをすると、with tf.GradientTape()
の中に x
や y
を書いたときのみ微分の計算対象になるので、実は $\frac{\partial L}{\partial x} = 0$ になってしまっているよ、という話なのでした。
その証拠に、step()
を何回呼んでも tf.print(grad)
の実行結果の第1成分($\frac{\partial L}{\partial x}$ を表します)は常に0です。本来は1回目の呼び出しに対する出力が [2, 2]
にならないといけないはず。
当の本人は square_x
を使っているので、てっきり x
で微分ができていると思い込んでいるのですね…。
square_x = square_x + 0 * x # Dummy
の行を取り除くと、GradientTape
に x
の情報が入っていないという趣旨のWARNINGが出ますが、**y
では微分できているのでエラーで止められることはありません。**まして、このダミーコードのように x
を使った計算が GradientTape
に1つでも入っていると、抜けている計算があってもWARNINGすら出ずに通ってしまいます。
つまり**「なぜかうまく収束しない」と思ってあれこれデバッグしてようやく問題に気づくということに…。(実話)**
対処法?
今回の場合、square_x
の計算が手違いで GradientTape
に入っていなかったために、 $\frac{\partial L}{\partial x} = 0$ になってしまっていたことがバグの原因でした。
そういうことにならないようにするための対処として、考えられるのは
- 学習をKeras APIに任せる(自分で自動微分を書かない)
- TF 1.x方式で学習する。
tf.compat.v1
経由でAPIを呼び出す(これはあまりやりたくない) -
GradientTape
のブロックで囲む範囲を広めにとる(むやみに何でも記録するのは得策でないので、少しずつ範囲を狭くしていっておかしくならないかを見極める) -
GradientTape
の中身を何らかの手段でデバッグする
あたりかと思います。
ただ GradientTape
の具体的な処理はCのライブラリを呼び出している(pywrap_tensorflow.TFE_Py_*
)ようなので、Pythonから手軽にデバッグというわけにはいかないかもしれません…。