6
5

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

[TensorFlow 2] グラフモードでTensorの中身を確認する方法

Last updated at Posted at 2020-04-25

はじめに

TensorFlow (2.x) で学習時の損失関数やデータセットの変換部分に自作関数を使うときに、「思ったとおりの値が入っているのかな?」と確認しようと思って print() を呼び出しても、値が出力されないときがあります。
tfdbg などのデバッガーを使えば良いのでしょうが、もっとシンプルに、いわゆる「printデバッグ」をする方法をご紹介します。

tfdbgについては→tfdbg を使って Keras の nan や inf を潰す - Qiita

検証環境

  • Ubuntu 18.04
  • Python 3.6.9
  • TensorFlow 2.1.0 (CPU)

以下のコードでは、↓が書かれているものとします。

import tensorflow as tf 

Tensor に対して print() した場合の挙動

TensorFlow 2.xでは Eager Execution がデフォルトになりましたので、インタプリタで Tensor を組み立てている限りは、普通に print() すれば Tensor の値が表示されますし、.numpy()ndarray として値を得ることもできます。

値を確認できる場合

x = tf.constant(10)
y = tf.constant(20)
z = x + y
print(z)         # tf.Tensor(30, shape=(), dtype=int32)
print(z.numpy()) # 30

値を確認できない場合 (1)

学習・評価時に逐次 Tensor の値を評価すると遅いので、@tf.function デコレータをつけることにより、グラフモードで処理する関数を定義することができます。グラフモードで定義した演算は計算グラフとしてまとめて(Pythonの外で)処理されるため、計算過程の値を見ることができません。

@tf.function
def dot(x, y):
    tmp = x * y
    print(tmp)
    return tf.reduce_sum(tmp)

x = tf.constant([1, 2])
y = tf.constant([3, 1])
w = tf.constant([0, 2])
print(dot(x, y))
print(dot(x, w))
# Tensor("mul:0", shape=(2,), dtype=int32)
# tf.Tensor(5, shape=(), dtype=int32)
# tf.Tensor(4, shape=(), dtype=int32)

dot() の外で計算結果を出力することはできますが、中では値を見ることができません。
しかも、dot() を複数回呼び出しても、中の print() は基本的にはグラフを解析するタイミングで1回しか実行されません。

もちろん、この場合は @tf.function() を取り除けば値が見えるようにはなります。

def dot(x, y):
    tmp = x * y
    print(tmp)
    return tf.reduce_sum(tmp)

x = tf.constant([1, 2])
y = tf.constant([3, 1])
w = tf.constant([0, 2])
print(dot(x, y))
print(dot(x, w))
# tf.Tensor([3 2], shape=(2,), dtype=int32)
# tf.Tensor(5, shape=(), dtype=int32)
# tf.Tensor([0 4], shape=(2,), dtype=int32)
# tf.Tensor(4, shape=(), dtype=int32)

値を確認できない場合 (2)

tf.data.Dataset に対する map() 処理や、Kerasで自作の損失関数を使用する場合など、暗黙のうちにグラフモードで実行される場合があります。この場合は @tf.function をつけていなくても print() で途中の値を見ることができません。

def fourth_power(x):
    z = x * x
    print(z)
    z = z * z
    return z

ds = tf.data.Dataset.range(10).map(fourth_power)
for i in ds:
    print(i)

# Tensor("mul:0", shape=(), dtype=int64)
# tf.Tensor(0, shape=(), dtype=int64)
# tf.Tensor(1, shape=(), dtype=int64)
# tf.Tensor(16, shape=(), dtype=int64)
# :

tf.print()

グラフモードで実行されている処理の中で Tensor の値を確認するには tf.print() を使用します。
tf.print | TensorFlow Core v2.1.0

以下のように、値を表示することもできますし、tf.shape() を使って Tensor の次元やサイズを表示することもできます。自作関数の中で、なぜか次元やサイズが合わないと怒られてしまうときのデバッグにもご利用ください。

@tf.function
def dot(x, y):
    tmp = x * y
    tf.print(tmp)
    tf.print(tf.shape(tmp))
    return tf.reduce_sum(tmp)

x = tf.constant([1, 2])
y = tf.constant([3, 1])
w = tf.constant([0, 2])
print(dot(x, y))
print(dot(x, w))
# [3 2] 
# [2]
# tf.Tensor(5, shape=(), dtype=int32)
# [0 4]
# [2]
# tf.Tensor(4, shape=(), dtype=int32)

Dataset.map() のデバッグをしたい時でも tf.print() を使えば大丈夫です。

def fourth_power(x):
    z = x * x
    tf.print(z)
    z = z * z
    return z

ds = tf.data.Dataset.range(10).map(fourth_power)
for i in ds:
    print(i)
# 0
# tf.Tensor(0, shape=(), dtype=int64)
# 1
# tf.Tensor(1, shape=(), dtype=int64)
# 4
# tf.Tensor(16, shape=(), dtype=int64)
# :

tf.print() がうまく動かない例

では何も考えず Tensor の中身は tf.print() で出力すればよいのかというと、実はそうでもないのです。
tf.print() が実行されるのは、実際に処理が計算グラフとして実行されるときです。
つまり、**計算を実行するまでもなくグラフ解析時点で型や次元が一致しないことが分かってエラーが発生する場合、tf.print() の処理は実行されません。**ややこしいですね…。

@tf.function
def add(x, y):
    z = x + y
    tf.print(z) # このprintは実行されない
    return z

x = tf.constant([1, 2])
y = tf.constant([3, 4, 5])
ret = add(x, y)
# ValueError: Dimensions must be equal, but are 2 and 3 for 'add' (op: 'AddV2') with input shapes: [2], [3].

このようなケースでは、逆に普通の print() を使って、目的の形状のデータが渡ってきているか確認するのがよいでしょう。

@tf.function
def add(x, y):
    print(x) # このprintは計算グラフ解析時に実行される
    print(y)
    z = x + y
    return z

x = tf.constant([1, 2])
y = tf.constant([3, 4, 5])
ret = add(x, y)
# Tensor("x:0", shape=(2,), dtype=int32)
# Tensor("y:0", shape=(3,), dtype=int32)
# ValueError: Dimensions must be equal, but are 2 and 3 for 'add' (op: 'AddV2') with input shapes: [2], [3].

自作関数の中で変数を更新する方法

例えば、グラフモードで自作関数が呼び出された回数をカウントし、何回目の呼び出しかを表示することを考えます。

ダメな例

count = 0

@tf.function
def dot(x, y):
    global count
    tmp = x * y
    count += 1
    tf.print(count, tmp)
    return tf.reduce_sum(tmp)

x = tf.constant([1, 2])
y = tf.constant([3, 1])
w = tf.constant([0, 2])
print(dot(x, y))
print(dot(x, w))
# 1 [3 2]
# tf.Tensor(5, shape=(), dtype=int32)
# 1 [0 4]
# tf.Tensor(4, shape=(), dtype=int32)

2回目にも「1」が表示されてしまいます。
Pythonコードとしての count += 1 がグラフ解析時に1回しか実行されないことによります。

うまく動く例

以下のように tf.Variable()assign_add() などを使うのが正解です。
tf.Variable | TensorFlow Core v2.1.0

count = tf.Variable(0)

@tf.function
def dot(x, y):
    tmp = x * y
    count.assign_add(1)
    tf.print(count, tmp)
    return tf.reduce_sum(tmp)

x = tf.constant([1, 2])
y = tf.constant([3, 1])
w = tf.constant([0, 2])
print(dot(x, y))
print(dot(x, w))
# 1 [3 2]
# tf.Tensor(5, shape=(), dtype=int32)
# 2 [0 4]
# tf.Tensor(4, shape=(), dtype=int32)

参考記事

tf.function で性能アップ | TensorFlow Core(公式ドキュメント)

6
5
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
6
5

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?