LoginSignup
3
3

More than 3 years have passed since last update.

tensorflow2.1 + tensorflow-tensorrt で Transformer はまだ早かった

Last updated at Posted at 2020-02-12

TL;DR

  • tensorflow 2.x 系では RNN や Transformer などの tftrt へのコンバートはできない
    • tftrt が convert_variables_to_constants_v2 を呼ぶ
    • convert_variables_to_constants_v2 は tf2.1 現在制御構文(if, while など)に対応していない
    • RNN 系や Transformer のデコード部分が辛い

背景

会社のプロジェクトでしゃべるロボットを作っており、その会話の応答速度を速めるために TensorRT を使う試行を行っていました。
会話の text-to-text の応答生成のモデルには Transformer ベースのオリジナルのネットワークが使われておりこの推論部分には以下の特徴があります。

  • Attention を何度も行う部分や推論時のデコード処理に繰り返し(while)がある
  • 過去の会話系列をベクトル化する処理と、応答生成のデコードで2種類の処理がある
    • どちらの処理をさせるかを指示するため if 文が入る

問題

これらのネットワークを tftrt のモデルにコンバートしようとすると以下のエラーが起こります。

import os
import time
import numpy as np
import tensorflow as tf
from tensorflow.python.compiler.tensorrt import trt_convert as trt
from tensorflow.python.saved_model import tag_constants

print('Converting to TF-TRT...')
trt_model_dir = '...'
conversion_params = trt.DEFAULT_TRT_CONVERSION_PARAMS._replace(
    precision_mode=trt.TrtPrecisionMode.FP16,
    max_workspace_size_bytes=8000000000
)

converter = trt.TrtGraphConverterV2(input_saved_model_dir='...',
                                    conversion_params=conversion_params)
converter.convert()
converter.save(output_saved_model_dir=trt_model_dir)
print('Done Converting to TF-TRT')
output
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-20-3a7a7c89e0ac> in <module>
      8 converter = trt.TrtGraphConverterV2(input_saved_model_dir=hp['base_dir'],
      9                                     conversion_params=conversion_params)
---> 10 converter.convert()
     11 converter.save(output_saved_model_dir=trt_model_dir)
     12 print('Done Converting to TF-TRT')

~/.local/share/virtualenvs/ml_sandbox-YhSVM9Gx/lib/python3.6/site-packages/tensorflow_core/python/compiler/tensorrt/trt_convert.py in convert(self, calibration_input_fn)
    978                                   self._input_saved_model_tags)
    979     func = self._saved_model.signatures[self._input_saved_model_signature_key]
--> 980     frozen_func = convert_to_constants.convert_variables_to_constants_v2(func)
    981     grappler_meta_graph_def = saver.export_meta_graph(
    982         graph_def=frozen_func.graph.as_graph_def(), graph=frozen_func.graph)

~/.local/share/virtualenvs/ml_sandbox-YhSVM9Gx/lib/python3.6/site-packages/tensorflow_core/python/framework/convert_to_constants.py in convert_variables_to_constants_v2(func, lower_control_flow)
    508         print('op:', name_to_node[input_name].op)
    509         print('node:', name_to_node[input_name])
--> 510         raise ValueError("Cannot find the Placeholder op that is an input "
    511                          "to the ReadVariableOp.")
    512       _save_placeholder(input_name, node.attr["dtype"])

ValueError: Cannot find the Placeholder op that is an input to the ReadVariableOp.

(一部デバッグのために convert_to_constants.py を書き換えてるので行数などは異なります)

調査

いろいろ試した結果どうも if が入るとこのエラーが出るということがわかったので、ただ dense layer を重ねたものに if 文で処理が分岐する再現サンプルを作りました。
https://gist.github.com/halhorn/2a44b9004aa0352fd4dfba0fb9942e8e

こちらでも同様のエラーが出ます。
エラー箇所 付近を見ると、指定された op が Placeholder ではない場合にこのエラーが起きます。

convert_variables_to_constants_v2 の doc を読んでみると、この関数は tf2.x 系用のもので、制御構文には今の所対応していないと書かれています。

The current implementation only works for graphs that do not contain any
control flow or embedding related ops.

エラー内容では入力のノードの op が Placeholder でないといけないと言っているわけですが、実際エラー箇所の if 文に入ってきたノードの op の種類を見てみると、 while 系の場合 Enter が、 if 文系の場合は Switch が来ていました。

TFTRT ではない本家の TensorRT もまだ tf2.x 系には対応していないと聞いていますし、 TensorRT 関係を使いたいなら tf2.x 系はまだ早かったかという印象があります。(我々はもう2.1にしてしまいましたが)

追記:パフォーマンス

現在の TensorRT は CNN 系に最適化されており Transformer などには効果があるかが不明でした。
そこで Transformer の計算量の多くを占めている(と私がかってに思ってる) Dense layer のみを大量に積んだネットワークでのパフォーマンスを測定しました。

精度はどちらも fp16 を使い、生の tensorflow モデルと tftrt コンバート済みのモデルでは、1推論あたり3.5ms と 1.31ms と、2倍以上の差がでました。(tf2.0 での tftrt (trt5.0系) では2倍の差は出なかったので、 tf2.1 になって trt のバージョンが上がったことで高速化されたのかもしれません。)

transformer にもし使えれば有効な高速化手段となっていたと思うので大変残念です。

3
3
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
3
3