search
LoginSignup
17

More than 1 year has passed since last update.

posted at

Tensorflow2でmodel.predictの推論が遅い!ので他の方法を検証した

はじめに

動画に対してリアルタイム解析を行う場合、どうしても推論速度を意識しないとなりません。30fpsで解析するためには前処理・後処理を合わせて33msec以内に行う必要があります。
最近tensorflow2に移行したのですが、どうもtf.keras.model.predictが遅いような気がして、ほかに早く推論できる方法はないかと調べてみました。

すでに公知の事実のようでしたが、結果として1枚の画像を解析する場合にtf.keras.model.predictは悪手のようです。

参考サイト

検証条件
CPU:Ryzen9 5950x
GPU: RTX3070
Memory: 64GB
OS: Windows10 Pro (jupyter notebook使用)
Python: 3.7.9
tensorflow: 2.4.0
CUDA: 11.0
cuDNN: 8.0.5

検証の方法

今回はtf.keras.applications.VGG16を特に変更せず使用しました。
速度を検証する手法は以下の3種類です。
1.model.predict()
2.model()
3.model.predict_on_batch()

検証用のデータは下記の2種類です。
A. 1枚の画像を想定したnp.random.rand(1,224,224,3)で生成したデータ
B. 複数の画像(4枚)を想定したnp.random.rand(4,224,224,3)で生成したデータ

python3
import tensorflow as tf
physical_devices = tf.config.list_physical_devices('GPU')
if len(physical_devices) > 0:
    for device in physical_devices:
        tf.config.experimental.set_virtual_device_configuration(device, [tf.config.experimental.VirtualDeviceConfiguration(memory_limit=3072)])
else:
    print("Not enough GPU hardware devices available")

import numpy as np
import timeit

model = tf.keras.applications.VGG16()
arr = np.random.rand(1,224,224,3)
batch = np.random.rand(4,224,224,3)
#注:以下はjupyter notebookでの表記です。
%timeit  -n 50 model.predict(arr)
%timeit  -n 50 model(arr,training=False)
%timeit  -n 50 model.predict_on_batch(arr)

%timeit  -n 50 model.predict(batch)
%timeit  -n 50 model(batch,training=False)
%timeit  -n 50 model.predict_on_batch(batch)

結果

データAに対する推論速度 データBに対する推論速度
model.predict() 21.8 ms ± 216 µs 28.2 ms ± 267 µs
model() 6.33 ms ± 53.9 µs 14.9 ms ± 27.3 µs
model.predict_on_batch() 7.24 ms ± 49.5 µs 14 ms ± 140 µs

上記結果をみて驚いたのですが、結構差がありました(ご存じの方にはすみません・・・)

ということで1枚の画像に対しては
model() < model.predict_on_batch() < model.predict()の順であり

複数の画像に対しては
model.predict_on_batch() < model() < model.predict()の順でした。

いずれの場合でもmodel.predict()は最も遅く、 model()model.predict_on_batch()がほぼ同程度の速度でした。

考察

そもそもtensorflow2のmodel.predict()は1つのデータに対して実行するのに適していないとのことです(ソース)

具体的には、1つのデータに対してはmodel(x)もしくはmodel(x, training=False)が適しているということでした。
またmodel.predict_on_batch()も同様に早いということです(ソース)

ということでmodel()もしくはmodel.predict_on_batch()で推論する方が良いようです。
複数枚画像の処理においてはmodel.predict_on_batch()の方が高速ですので、
常にmodel.predict_on_batch()でよいような気がします。。。

おわりに

今更ながらmodel.predict()が遅くいことを知りました。。。
そのあたりAIの実装においては結構重要だと思うのですが、公式のチュートリアルに書いてあったかなぁ・・・?

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
What you can do with signing up
17