はじめに
この記事はTF2.0のPreview版がPyPIに公開されたことを受け、正式版の公開やその利用に向けて備える記事です。前回(TensorFlowの消失)では大幅なAPI変更を警告する内容でした。今回の記事では、TF2.0ではどう書くべきかについて書かれた、"Effective TensorFlow 2.0"の内容を紹介します。
(内容が更新されることが予想されます。このバージョンのもので説明します。)
要約
要約の要約をすると、「tf.functionを使いこなせ」という感じです。
- 主要な変更
- API掃除
- tf.app, tf.flags, tf.logging → absl-py
- tf.contrib → リホーム
- tf直下のマイナー関数 → tf.math
- その他、重複するもの → tf.summary, tf.keras.metrics, tf.keras.optimizers
- 変換は
tf_upgrade_v2
を使え(使い方)
- Eager
- TF1系では計算グラフをsession.run経由で手動コンパイル → Eagerモードがデフォルト
- グラフやセッションは実装にすぎないと感じられる
- tf.control_dependenciesとはおさらば
- グローバル禁止
- tf.global_variables_initializerなど、全Variableを操作する系のものを削除
- Variableは自分で管理しろ
- そうしないと、VariableもGCされる
- kerasなら簡単に出来るからいいでしょ
- sessionからfunctionへ
- tf.function(AutoGraph)を使え(JITコンパイル)
- グラフが1つだけ
- グラフモードの恩恵はすべて受けられる
- パフォーマンス
- ポータビリティ(要確認)
- シーケンシャルなモデル、強化学習、その他独自訓練ループもAutoGraphできるかも(要確認)
- API掃除
- 推奨される書き方
- 小さい関数にリファクタしろ
- tf.functionはtrainなどの大きい単位で1回使えば良い
- tf.kerasを使って変数管理せよ
- variablesやtrainable_variablesのプロパティが便利
- tf.kerasの層は、checkpointやSavedModelとして保存できる
- 保存できるのはtf.train.Checkpointable(仮)の機能
- 訓練のイテレーションやRNN等をtf.functionでコンパイル
- tf.functionがforに対応しいる
- tf.keras.metricsでロスなどを集計して、tf.summaryで記録
- 小さい関数にリファクタしろ
コード例
"Effective TensorFlow 2.0"の内容を参考に例をいくつか紹介します。"Effective TensorFlow 2.0"にもコードの例があるので、興味がある人はそちらも確認してください。
tf.saved_modelも試してみたいですが、まだ未実装の部分があるため触れません。
訓練もどき
パラメータ$x$、損失$x$の最小構成で、訓練のようなことをしてみましょう。
import tensorflow as tf
@tf.function
def train(v, opt):
for _ in range(10):
opt.minimize(lambda: v, [v]) # 損失の値ではなく、損失を返す関数を第一引数とする。また、更新する変数のリストを指定する。
print(v)
train(tf.Variable(1.), tf.keras.optimizers.SGD(0.001))
注意時項としては、tf.functionの中でoptimizerの初期化ができませんでした。そういうものなのでしょう。もちろん、kerasのfitメソッドも健在ですし、また、別の書き方として、勾配を取ってapply_gradientsを適用する方法もあります。
import tensorflow as tf
@tf.function
def train(v, opt):
for _ in range(10):
with tf.GradientTape() as tape:
dv = tape.gradient(v, [v])
opt.apply_gradients(zip(dv, [v]))
print(dv[0], v)
train(tf.Variable(1.), tf.keras.optimizers.SGD(0.001))
fizzbuzz
次の例は、何の変哲もないFizzBuzzです。これもtf.functionでグラフ化できます。グラフ化の様子はtf.autograph.to_codeを使うと見れます。
import tensorflow as tf
def fizzbuzz(n):
if n % 15 == 0:
return "FizzBuzz"
if n % 3 == 0:
return "Fizz"
if n % 5 == 0:
return "Buzz"
return n
def fizzbuzz_upto_100():
i = 1
while i <= 100:
print(fizzbuzz(i))
i += 1
print(tf.autograph.to_code(fizzbuzz_upto_100)) # グラフ化のコード表示
tf_fizzbuzz = tf.function(fizzbuzz_upto_100) # tf.functionはこのような使い方もできる
tf_fizzbuzz()
print(tf.autograph.to_code(fizzbuzz_upto_100))
でグラフ化して表示している部分の出力は、次のようになります。
from __future__ import print_function
def tf__fizzbuzz(n):
try:
with ag__.function_scope('fizzbuzz'):
cond_2 = ag__.eq(n % 15, 0)
def if_true_2():
with ag__.function_scope('if_true_2'):
return__1 = 'FizzBuzz'
return return__1
def if_false_2():
with ag__.function_scope('if_false_2'):
cond_1 = ag__.eq(n % 3, 0)
def if_true_1():
with ag__.function_scope('if_true_1'):
return__1 = 'Fizz'
return return__1
def if_false_1():
with ag__.function_scope('if_false_1'):
cond = ag__.eq(n % 5, 0)
def if_true():
with ag__.function_scope('if_true'):
return__1 = 'Buzz'
return return__1
def if_false():
with ag__.function_scope('if_false'):
return__1 = n
return return__1
return__1 = ag__.if_stmt(cond, if_true, if_false)
return__1 = return__1
return return__1
return__1 = ag__.if_stmt(cond_1, if_true_1, if_false_1)
return__1 = return__1
return return__1
return__1 = ag__.if_stmt(cond_2, if_true_2, if_false_2)
return return__1
except:
ag__.rewrite_graph_construction_error(ag_source_map__)
tf__fizzbuzz.autograph_info__ = {}
def tf__fizzbuzz_upto_100():
try:
with ag__.function_scope('fizzbuzz_upto_100'):
i = 1
def loop_test(i_1):
with ag__.function_scope('loop_test'):
return ag__.lt_e(i_1, 100)
def loop_body(i_1):
with ag__.function_scope('loop_body'):
with ag__.utils.control_dependency_on_returns(ag__.print_(tf__fizzbuzz(i_1))):
i = ag__.utils.alias_tensors(i_1)
i += 1
return i,
i = ag__.while_stmt(loop_test, loop_body, (i,), ())
except:
ag__.rewrite_graph_construction_error(ag_source_map__)
tf__fizzbuzz_upto_100.autograph_info__ = {}
注意時項ですが、なんでもかんでもto_codeが使えるわけではないようです。この例では、while i <= 100:
でループしていますが、for i in range(1, 101):
では、to_codeできませんでした。tf.functionだけなら動くようなので、to_codeのためにあれこれする必要はないでしょうが・・・(実はバグ?未実装?)。
まとめ
これまでGraphモードで書くことが多かったのですが、tf.functionで普通にいいんじゃないかなと思えてきました。optimizerがkeras側に集約されましたが、少々使い方が変わった程度で済みそうです。一方で、未実装なモデルの保存がどこまで柔軟にできるかが気になります。saved_model関連のものが出来上がったら試してみたいです。