はじめに
TensorFlow (2.x) で学習時の損失関数やデータセットの変換部分に自作関数を使うときに、「思ったとおりの値が入っているのかな?」と確認しようと思って print()
を呼び出しても、値が出力されないときがあります。
tfdbg
などのデバッガーを使えば良いのでしょうが、もっとシンプルに、いわゆる「printデバッグ」をする方法をご紹介します。
tfdbg
については→tfdbg を使って Keras の nan や inf を潰す - Qiita
検証環境
- Ubuntu 18.04
- Python 3.6.9
- TensorFlow 2.1.0 (CPU)
以下のコードでは、↓が書かれているものとします。
import tensorflow as tf
Tensor に対して print() した場合の挙動
TensorFlow 2.xでは Eager Execution がデフォルトになりましたので、インタプリタで Tensor
を組み立てている限りは、普通に print()
すれば Tensor
の値が表示されますし、.numpy()
で ndarray
として値を得ることもできます。
値を確認できる場合
x = tf.constant(10)
y = tf.constant(20)
z = x + y
print(z) # tf.Tensor(30, shape=(), dtype=int32)
print(z.numpy()) # 30
値を確認できない場合 (1)
学習・評価時に逐次 Tensor
の値を評価すると遅いので、@tf.function
デコレータをつけることにより、グラフモードで処理する関数を定義することができます。グラフモードで定義した演算は計算グラフとしてまとめて(Pythonの外で)処理されるため、計算過程の値を見ることができません。
@tf.function
def dot(x, y):
tmp = x * y
print(tmp)
return tf.reduce_sum(tmp)
x = tf.constant([1, 2])
y = tf.constant([3, 1])
w = tf.constant([0, 2])
print(dot(x, y))
print(dot(x, w))
# Tensor("mul:0", shape=(2,), dtype=int32)
# tf.Tensor(5, shape=(), dtype=int32)
# tf.Tensor(4, shape=(), dtype=int32)
dot()
の外で計算結果を出力することはできますが、中では値を見ることができません。
しかも、dot()
を複数回呼び出しても、中の print()
は基本的にはグラフを解析するタイミングで1回しか実行されません。
もちろん、この場合は @tf.function()
を取り除けば値が見えるようにはなります。
def dot(x, y):
tmp = x * y
print(tmp)
return tf.reduce_sum(tmp)
x = tf.constant([1, 2])
y = tf.constant([3, 1])
w = tf.constant([0, 2])
print(dot(x, y))
print(dot(x, w))
# tf.Tensor([3 2], shape=(2,), dtype=int32)
# tf.Tensor(5, shape=(), dtype=int32)
# tf.Tensor([0 4], shape=(2,), dtype=int32)
# tf.Tensor(4, shape=(), dtype=int32)
値を確認できない場合 (2)
tf.data.Dataset
に対する map()
処理や、Kerasで自作の損失関数を使用する場合など、暗黙のうちにグラフモードで実行される場合があります。この場合は @tf.function
をつけていなくても print()
で途中の値を見ることができません。
def fourth_power(x):
z = x * x
print(z)
z = z * z
return z
ds = tf.data.Dataset.range(10).map(fourth_power)
for i in ds:
print(i)
# Tensor("mul:0", shape=(), dtype=int64)
# tf.Tensor(0, shape=(), dtype=int64)
# tf.Tensor(1, shape=(), dtype=int64)
# tf.Tensor(16, shape=(), dtype=int64)
# :
tf.print()
グラフモードで実行されている処理の中で Tensor
の値を確認するには tf.print()
を使用します。
tf.print | TensorFlow Core v2.1.0
以下のように、値を表示することもできますし、tf.shape()
を使って Tensor
の次元やサイズを表示することもできます。自作関数の中で、なぜか次元やサイズが合わないと怒られてしまうときのデバッグにもご利用ください。
@tf.function
def dot(x, y):
tmp = x * y
tf.print(tmp)
tf.print(tf.shape(tmp))
return tf.reduce_sum(tmp)
x = tf.constant([1, 2])
y = tf.constant([3, 1])
w = tf.constant([0, 2])
print(dot(x, y))
print(dot(x, w))
# [3 2]
# [2]
# tf.Tensor(5, shape=(), dtype=int32)
# [0 4]
# [2]
# tf.Tensor(4, shape=(), dtype=int32)
Dataset.map()
のデバッグをしたい時でも tf.print()
を使えば大丈夫です。
def fourth_power(x):
z = x * x
tf.print(z)
z = z * z
return z
ds = tf.data.Dataset.range(10).map(fourth_power)
for i in ds:
print(i)
# 0
# tf.Tensor(0, shape=(), dtype=int64)
# 1
# tf.Tensor(1, shape=(), dtype=int64)
# 4
# tf.Tensor(16, shape=(), dtype=int64)
# :
tf.print() がうまく動かない例
では何も考えず Tensor
の中身は tf.print()
で出力すればよいのかというと、実はそうでもないのです。
tf.print()
が実行されるのは、実際に処理が計算グラフとして実行されるときです。
つまり、**計算を実行するまでもなくグラフ解析時点で型や次元が一致しないことが分かってエラーが発生する場合、tf.print()
の処理は実行されません。**ややこしいですね…。
@tf.function
def add(x, y):
z = x + y
tf.print(z) # このprintは実行されない
return z
x = tf.constant([1, 2])
y = tf.constant([3, 4, 5])
ret = add(x, y)
# ValueError: Dimensions must be equal, but are 2 and 3 for 'add' (op: 'AddV2') with input shapes: [2], [3].
このようなケースでは、逆に普通の print()
を使って、目的の形状のデータが渡ってきているか確認するのがよいでしょう。
@tf.function
def add(x, y):
print(x) # このprintは計算グラフ解析時に実行される
print(y)
z = x + y
return z
x = tf.constant([1, 2])
y = tf.constant([3, 4, 5])
ret = add(x, y)
# Tensor("x:0", shape=(2,), dtype=int32)
# Tensor("y:0", shape=(3,), dtype=int32)
# ValueError: Dimensions must be equal, but are 2 and 3 for 'add' (op: 'AddV2') with input shapes: [2], [3].
自作関数の中で変数を更新する方法
例えば、グラフモードで自作関数が呼び出された回数をカウントし、何回目の呼び出しかを表示することを考えます。
ダメな例
count = 0
@tf.function
def dot(x, y):
global count
tmp = x * y
count += 1
tf.print(count, tmp)
return tf.reduce_sum(tmp)
x = tf.constant([1, 2])
y = tf.constant([3, 1])
w = tf.constant([0, 2])
print(dot(x, y))
print(dot(x, w))
# 1 [3 2]
# tf.Tensor(5, shape=(), dtype=int32)
# 1 [0 4]
# tf.Tensor(4, shape=(), dtype=int32)
2回目にも「1」が表示されてしまいます。
Pythonコードとしての count += 1
がグラフ解析時に1回しか実行されないことによります。
うまく動く例
以下のように tf.Variable()
と assign_add()
などを使うのが正解です。
tf.Variable | TensorFlow Core v2.1.0
count = tf.Variable(0)
@tf.function
def dot(x, y):
tmp = x * y
count.assign_add(1)
tf.print(count, tmp)
return tf.reduce_sum(tmp)
x = tf.constant([1, 2])
y = tf.constant([3, 1])
w = tf.constant([0, 2])
print(dot(x, y))
print(dot(x, w))
# 1 [3 2]
# tf.Tensor(5, shape=(), dtype=int32)
# 2 [0 4]
# tf.Tensor(4, shape=(), dtype=int32)
参考記事
tf.function で性能アップ | TensorFlow Core(公式ドキュメント)