はじめに
この記事はTensorFlow2.0 Advent Calendar 2019の20日目の記事です。TensorFlow2.0の大きな変更点といえば、EagerExecutionがデフォルトになり、命令型言語での記述が可能になりより自由自在にPythonicな書き方ができるようになったことだと思います。しかし、その一方でパフォーマンスやポータビリティが犠牲になるという問題もあるのですが、それを解決して、1.xにおけるGraphモードと2.xのEagerモードの両方の恩恵を受けられるようにするために登場したのがtf.function
です。この記事ではtf.function
の使い方と使う際に知っていた方がいい注意点について紹介しようと思います。基本的は公式サイトのまとめになるので、詳しく知りたい方はそちらも参考にしてください。
tf.functionの使い方
使い方は簡単で、最適化したい重い処理を記述した関数に@tf.function
でコレータをつけるか、関数を定義してそれをtf.function
メソッドに食わせてGraphモード用の関数を別途作成するという方法です。
import tensorflow as tf
@tf.function
def f(x,y):
return x + y
#または
def g(x,y):
return x * y
h = tf.function(g)
また、@tf.function
内で他の関数を呼ぶ場合、スコープはその関数にもおよぶため、わざわざ全ての関数をチェックして@tf.function
をつける必要はありません(というかオススメしません)。そのためとりあえず重い処理の部分に添えるだけで簡単にGraphモードの恩恵を受けられるように思われます。しかし、チュートリアルに乗っているようなシンプルな書き方でしたらこれで全くも大ないのですが、少し複雑なことをしようとするとtf.function
の仕様を知っていないと思いもしないような動作をするため注意が必要です。公式サイトの冒頭には以下のような記述があります。
- Object MutationやPythonのlistのようなPythonの固有の挙動に依存するな
- tf.functionはNumpyやPythonのプリミティブ型を使うよりもTensorFlowのOpsを使った方がベストなパフォーマンスを発揮します
- 疑わしければ
for x in y
の書き方をしなさい
一部どういうことか解釈に苦しむ項目もありますが、具体例をみた方がわかりやすいと思いますのでみていきましょう。
実験
まずは簡単のために以下のようなシンプルな関数を用意します。
import tensorflow as tf
@tf.function
def double(a):
print("Tracing with, {}".format(a) )
return a + a
入力された引数を二倍にして返却するというシンプルな関数で、入力引数をprintする処理も入っています。引数の方は整数でも実数でも文字列でも動作します。これをいくつかのパターンで実行してみましょう。
print(double(3))
print()
print(double(2))
print()
print(double(tf.constant(1.1)))
print()
print(double(tf.constant('a')))
print()
print(double(tf.constant('b')))
結果は以下のようになります。
Tracing with, 3
tf.Tensor(6, shape=(), dtype=int32)
Tracing with, 2
tf.Tensor(4, shape=(), dtype=int32)
Tracing with, Tensor("a:0", shape=(), dtype=float32)
tf.Tensor(2.2, shape=(), dtype=float32)
Tracing with, Tensor("a:0", shape=(), dtype=string)
tf.Tensor(b'aa', shape=(), dtype=string)
tf.Tensor(b'bb', shape=(), dtype=string)
少し奇妙な結果になってしましました。一番最後のtf.constant('b')
を引数として実行した結果だけ print文が実行されません。上記のプログラムをもう一度実行してみるとさらにおかしな結果が得られます。
print(double(3))
print()
print(double(2))
print()
print(double(tf.constant(1.1)))
print()
print(double(tf.constant('a')))
print()
print(double(tf.constant('b')))
結果は以下のようになります。
tf.Tensor(6, shape=(), dtype=int32)
tf.Tensor(4, shape=(), dtype=int32)
tf.Tensor(2.2, shape=(), dtype=float32)
tf.Tensor(b'aa', shape=(), dtype=string)
tf.Tensor(b'bb', shape=(), dtype=string)
正しい値は返却されるのですが、途中に書かれたprint文は一切実行されません。これはどういうことなのでしょう?
Tracing
実はこの奇妙な動作はtf.function
が関数を計算グラフに構築して最適化する際にTracingという処理が関係しています。tf.function
はTensorFlow由来だけではないPython固有の処理も記述された関数を計算グラフに変換します。そして実際に計算グラフの実行に関係のないPython固有の処理(今回でいうとprint文)を省いてしまいます。しかしなぜ最初は実行されたのでしょうか?それはtf.function
が関数を計算グラフに変換するときにTracingという処理が走るからです。Pythonで書かれた関数は引数に明示的に型が存在しません。そのため色々な値を入力できて便利な一方、最適な計算グラフを作ろうとするtf.function
側からすると困り者です。そのため、引数に今まで入ってこなかった値や型が入力されて関数が最初に呼ばれるときにTracingといって、関数内のPython固有の処理を全て実行してみるという処理が走ります。「引数に今まで入ってこなかった値や型」と言いましたが、厳密には以下の基準となります。
- Pythonのプリミティブ型の場合、値が異なるものが入ってきたらTracing
- Pythonのオブジェクトの場合、idが異なるものが入ってきたらTracing
- TensorFlow由来のTensorの場合、型かshapeが異なるものが入ってきたらTracing
そのため先ほどの奇妙な挙動のカラクリは以下のようになります
print(double(3)) #初めてみる値なのでTracing
print()
print(double(2)) #初めてみる値なのでTracing
print()
print(double(tf.constant(1.1))) #初めてみる値なのでTracing
print()
print(double(tf.constant('a'))) #初めてみる型 shapeなのでTracing
print()
print(double(tf.constant('b'))) #以前見た型 shapeなので最適化されたグラフ実行
一回Tracing走るとTensorFlowはその結果構築された計算グラフを内部に保存します。そして次に以前Tracingした値や型/shapeの引数が入力された際は最適化された計算グラフを実行します。そのため上記プログラムでは最後の呼び出しではPythonのprint文が実行されずに、再び実行した際は全てのprint文が実行されなかったわけです。
ではどうするべきか?
そのため冒頭に戻りますが、
- Object MutationやPythonのlistのようなPythonの固有の挙動に依存するな
ということになります。先ほどのPrintですが代わりにtf.print
で記述すれば毎回実行されます。tf.summary
を用いたり、関数内で各種値のアップデートを行いたい場合はtf.Variable
を用いて行うなど、TensorFlow由来の機能をフル活用することでおかしな動作を防ぐこともできますしパフォーマンスも向上します。ただ、Python固有の処理を一切入れるなと言っているわけではないので注意してください。Pythonicな記述と併用できることによってより柔軟にプログラミングできるようになったメリットは大きいです。ただ、何でもかんでも考えなしに関数を定義してtf.function
をつけるとおかしな動作をするよということに注意してください。
まとめ
TensorFlow2.0になりGraphモードとEagerモードの両方の恩恵を受けられるようになったのは良いのですが、上記に紹介した落とし穴以外にもPython固有の機能に依存しすぎた関数を作ってしまうと思わぬバグを踏んでしまうことがあります。TensorFlowを使うならできるだけTensorFlow由来のメソッドを使って、Python固有の機能を使う際はAutoGraphの際の動作やTracingを意識した設計にしましょう。公式サイトにはその他様々な気をつけるべきことやこういう挙動を制御する方法などが多く記載されています。これから独自モデルを実装する必要などがありtf.function
を使う必要がある方はぜひご一読ください。