この記事は,ドコモアドベントカレンダー2日目の記事になります。
ドコモの酒井と申します。業務ではDeep Learningを用いた画像認識エンジンの研究開発に取り組んでいます。
TL;DR
Keras(バックエンドはtensorflow)からfrozen graphに変換して、uff形式に変換したうえでtensorrtで読み込んで速度計測したところ、処理速度が倍程度になりました。
動機
TensorRT5.0がNVIDIA社からリリースされました。TensorRTはcaffeやtensorflow、onnxなどの学習済みDeep Learningモデルを、GPU上で高速に推論できるように最適化してくれるライブラリです。TensorRTを使ってみた系の記事はありますが、結構頻繁にAPIが変わるようなので、5.0が出たのを機に一通り触ってみたいと思います。
環境
tensorrtのインストールに関しては、公式マニュアルをご参照ください。今回は以下のような環境でdocker上で動作確認しました。
- docker: 18.06.1-ce
- nvidia-docker: 2.0.3
- nvidia-driver: 410.48
- python: 3.5.2
- keras==2.2.4
- tensorflow==1.10.0
- tensorrt==5.0.2.6
利用したdockerfileは以下の通りです(不要なpytorchとかも入っています)。tensorrtのdevは公式サイト(要アカウント登録)から5.0.2.6をダウンロードしてください。
FROM tensorflow/tensorflow:1.10.0-devel-gpu-py3
COPY nv-tensorrt-repo-ubuntu1604-cuda9.0-trt5.0.2.6-ga-20181009_1-1_amd64.deb /tmp
#https://docs.nvidia.com/deeplearning/sdk/tensorrt-install-guide/index.html#installing-debian
RUN dpkg -i /tmp/nv-tensorrt-repo-ubuntu1604-cuda9.0-trt5.0.2.6-ga-20181009_1-1_amd64.deb && \
apt-key add /var/nv-tensorrt-repo-cuda9.0-trt5.0.2.6-ga-20181009/7fa2af80.pub
RUN apt-get update && apt-get install -y --no-install-recommends \
autoconf \
automake \
libtool \
pkg-config \
ca-certificates \
libprotobuf-dev \
protobuf-compiler \
cmake \
swig \
libglib2.0-0 \
wget \
libcudnn7 \
tensorrt \
libnvinfer-dev \
python3-libnvinfer-dev \
python3-libnvinfer \
uff-converter-tf && \
rm -rf /var/lib/apt/lists/* /tmp/*
RUN pip3 install --upgrade pip \
&& pip3 install keras pycuda
Kerasでのモデル読み込みと速度計測
今回はVGG19を対象とします。VGGモデルを読み込みます。意外と重要なのが、K.set_image_data_format('channels_first')
です。これを実行しないと、後々のuffを読み込むところで以下のようなエラーが出て進めないです。こちらのトピックにたどり着き、orderが原因だということがわかりました。
[TensorRT] ERROR: block1_conv1/convolution: kernel weights has count 1728 but 129024 was expected
[TensorRT] ERROR: UFFParser: Parser error: block1_conv1/BiasAdd: The input to the Scale Layer is required to have a minimum of 3 dimensions.
[TensorRT] ERROR: Network must have at least one output
from keras.applications.vgg19 import VGG19
from keras.preprocessing import image
from keras.applications.vgg19 import preprocess_input
import numpy as np
import keras.backend as K
import tensorflow as tf
K.set_image_data_format('channels_first')
model = VGG19(weights='imagenet')
model.summary()
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
input_1 (InputLayer) (None, 3, 224, 224) 0
_________________________________________________________________
block1_conv1 (Conv2D) (None, 64, 224, 224) 1792
_________________________________________________________________
block1_conv2 (Conv2D) (None, 64, 224, 224) 36928
_________________________________________________________________
block1_pool (MaxPooling2D) (None, 64, 112, 112) 0
_________________________________________________________________
block2_conv1 (Conv2D) (None, 128, 112, 112) 73856
_________________________________________________________________
block2_conv2 (Conv2D) (None, 128, 112, 112) 147584
_________________________________________________________________
block2_pool (MaxPooling2D) (None, 128, 56, 56) 0
_________________________________________________________________
block3_conv1 (Conv2D) (None, 256, 56, 56) 295168
_________________________________________________________________
block3_conv2 (Conv2D) (None, 256, 56, 56) 590080
_________________________________________________________________
block3_conv3 (Conv2D) (None, 256, 56, 56) 590080
_________________________________________________________________
block3_conv4 (Conv2D) (None, 256, 56, 56) 590080
_________________________________________________________________
block3_pool (MaxPooling2D) (None, 256, 28, 28) 0
_________________________________________________________________
block4_conv1 (Conv2D) (None, 512, 28, 28) 1180160
_________________________________________________________________
block4_conv2 (Conv2D) (None, 512, 28, 28) 2359808
_________________________________________________________________
block4_conv3 (Conv2D) (None, 512, 28, 28) 2359808
_________________________________________________________________
block4_conv4 (Conv2D) (None, 512, 28, 28) 2359808
_________________________________________________________________
block4_pool (MaxPooling2D) (None, 512, 14, 14) 0
_________________________________________________________________
block5_conv1 (Conv2D) (None, 512, 14, 14) 2359808
_________________________________________________________________
block5_conv2 (Conv2D) (None, 512, 14, 14) 2359808
_________________________________________________________________
block5_conv3 (Conv2D) (None, 512, 14, 14) 2359808
_________________________________________________________________
block5_conv4 (Conv2D) (None, 512, 14, 14) 2359808
_________________________________________________________________
block5_pool (MaxPooling2D) (None, 512, 7, 7) 0
_________________________________________________________________
flatten (Flatten) (None, 25088) 0
_________________________________________________________________
fc1 (Dense) (None, 4096) 102764544
_________________________________________________________________
fc2 (Dense) (None, 4096) 16781312
_________________________________________________________________
predictions (Dense) (None, 1000) 4097000
=================================================================
Total params: 143,667,240
Trainable params: 143,667,240
Non-trainable params: 0
_________________________________________________________________
以下のように速度計測しました。
:convert_to_uff.py
img_path = './cat.jpg' #適当な画像パスを設定してください
img = image.load_img(img_path, target_size=(224, 224))
x = image.img_to_array(img)
x = np.expand_dims(x, axis=0)
x = preprocess_input(x)
%timeit model.predict(x)
13.7 ms ± 46.8 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
Keras->frozen graph->uff
tensorrtをインストールすると入っている/usr/src/tensorrt/samples/python/end_to_end_tensorflow_mnist/model.py
を参考にmodelをfrozen graphとして保存します。output_namesに関してですが、本来model.output.op.name
を使えばよいです。しかしsoftmaxの処理が入っていると、最終的にuffをtensorrtが読み込む際にうまく読み込みができません。
[TensorRT] INFO: UFFParser: parsing predictions/kernel
[TensorRT] INFO: UFFParser: parsing predictions/MatMul
[TensorRT] INFO: UFFParser: parsing predictions/bias
[TensorRT] INFO: UFFParser: parsing predictions/BiasAdd
[TensorRT] INFO: UFFParser: parsing predictions/Softmax
[TensorRT] INFO: UFFParser: parsing MarkOutput_0
[TensorRT] ERROR: UFFParser: Parser error: MarkOutput_0: Order size is not matching the number dimensions of TensorRT
これは、devガイドに記載されている以下の影響と思われます。
- Note: If the input to a TensorFlow SoftMax op is not NHWC, TensorFlow will automatically insert a transpose layer with a non-constant permutation, causing the UFF converter to fail. It is therefore advisable to manually transpose SoftMax inputs to NHWC using a constant permutation.
そこで今回は、softmaxのひとつ前のopを出力とし、tensorrt側でsoftmaxを追加して対処することにしました。
for op in K.get_session().graph.get_operations():
if op.name.split('_')[0] not in ['Assign', 'Placeholder', 'IsVariableInitialized', 'init']:
print(op.name)
input_1
block1_conv1/random_uniform/shape
block1_conv1/random_uniform/min
----------中略-----------------
predictions/MatMul
predictions/BiasAdd
predictions/Softmax
本モデルのoutputはpredictions/Softmaxですので、そのひとつ前のBiasAddを出力として使うこととします。
:convert_to_uff.py
output_names = ['predictions/BiasAdd']
sess = K.get_session()
frozen_graph = tf.graph_util.convert_variables_to_constants(sess, sess.graph.as_graph_def(), output_names)
frozen_graph = tf.graph_util.remove_training_nodes(frozen_graph)
# Save the model
with open('vgg19.pb', "wb") as ofile:
ofile.write(frozen_graph.SerializeToString())
INFO:tensorflow:Froze 38 variables.
INFO:tensorflow:Converted 38 variables to const ops.
uff保存用のAPIでpbをuffに変換します(pythonAPI)。公式のdevガイドにはconvert-to-uffを使う用に指示がありますが、こちらのスレッドでは、pythonのAPIを使うよう指示があります。なぜか、pythoAPIを使った変換を前述のdocker環境上のjupyterで実行すると、処理完了後もmemoryが増えていってkernelが落ちてしまいました。他の環境でも同様になるのかまでは検証できていませんが、以下の処理までを1つのスクリプトとして実行すると良いと思います。
:convert_to_uff.py
import uff
uff.from_tensorflow_frozen_model('vgg19.pb', output_nodes=output_names, output_filename='vgg19.uff')
UFF Version 0.5.5
=== Automatically deduced input nodes ===
[name: "input_1"
op: "Placeholder"
attr {
key: "dtype"
value {
type: DT_FLOAT
}
}
attr {
key: "shape"
value {
shape {
dim {
size: -1
}
dim {
size: 3
}
dim {
size: 224
}
dim {
size: 224
}
}
}
}
]
=========================================
Using output node predictions/BiasAdd
Converting to UFF graph
DEBUG: convert reshape to flatten node
No. nodes: 103
UFF Output written to vgg19.uff
tensorrtで読み込み
モデル読み込みの入出力名を定義します。kerasのモデルの入出力の名前を持ってきます。
input_name = 'input_1'#model.input.op.name
output_name = 'predictions/BiasAdd'#model.output.op.name
モデル読み込みのための関数を定義します。
import tensorrt as trt
def build_engine(model_file):
TRT_LOGGER = trt.Logger(trt.Logger.INFO)
# For more information on TRT basics, refer to the introductory samples.
with trt.Builder(TRT_LOGGER) as builder, builder.create_network() as network, trt.UffParser() as parser:
# Parse the Uff Network
if not parser.register_input(input_name, (1, 3, 224, 224)):
raise Exception('error in registering input')
if not parser.register_output(output_name):
raise Exception('error in registering output')
if not parser.parse(model_file, network):
raise Exception('error in parsing')
# softmax layerを追加
softmax = network.add_softmax(network.get_output(0))
softmax.get_output(0).name = 'predictions/Softmax'
network.mark_output(softmax.get_output(0))
# Build and return an engine.
return builder.build_cuda_engine(network)
tensorrtで実行すると以下の通りです。
:trt_infer.py
import time
import sys
sys.path.append('/usr/src/tensorrt/samples/python')
import common
# おまじない
import pycuda.driver as cuda
import pycuda.autoinit
# 入力画像の準備
from keras.preprocessing import image
from keras.applications.vgg19 import preprocess_input
import numpy as np
img_path = './cat.jpg' #適当な画像パスを設定してください
img = image.load_img(img_path, target_size=(224, 224))
x = image.img_to_array(img)
x = np.expand_dims(x, axis=0)
x = preprocess_input(x)
with build_engine('./vgg19.uff') as engine:
inputs, outputs, bindings, stream = common.allocate_buffers(engine)
with engine.create_execution_context() as context:
start_time = time.time()
img = x.ravel()
np.copyto(inputs[0].host, img)
common.do_inference(context, bindings=bindings, inputs=inputs, outputs=outputs, stream=stream)
print('time: %fms'%((time.time()-start_time)*1000))
time: 8.572340ms
まとめ
いくつかコツなどがありましたが、無事、tensorrtでkerasのモデルを読み込んで推論速度を検証することができ、高速化されていることが確認できました。
まだまだインターネット上の資料が少なく、うまくいかなかった際にエラーの原因を探るのが大変でした。利用者が増えて知見がたまっていくとよいなと思います(あと、TensorRTのAPI仕様が安定してくれるとよいでが)。
[補足]試してみたがうまくいかなかったもの
- kerasでbatchnorm layerの入ったモデルは、変換できませんでした(ResNet50など)。
- pytorchをonnxに変換して読み込めるかと思いましたが、ダメでした。
- pytorchは、onnxのopset versionが以下の通りで、そのままではtensorrtが使えない(tensorrtはopset 7を対象としている)
- pytorch==0.4.1 -> opset==6
- pytorch==1.0 -> opset=9
- opsetの変換ツールがonnxでは用意されていますが、pytorch公式のvggモデルでは、6->7、9->7の変換ともに失敗してしまいました
- pytorchは、onnxのopset versionが以下の通りで、そのままではtensorrtが使えない(tensorrtはopset 7を対象としている)