- PINTO_model_zoo
- TensorflowLite-bin
- Bazel_bin
1.Introduction
I have always wanted to perform Full Integer Quantization myself to improve model performance. This time, completely ignoring the official tutorial, I put together my own Full Integer Quantization procedure. If you follow the official tutorial in the order Freeze Graph
-> Optimization
-> Convert to saved_model format
-> Add post-process
-> Quantization
then, when the model contains custom operations, the Quantization step fails with the error Check failed: dim_size >= 1 (0 vs. 1).
This time I deliberately skip that Pipeline, quantize the raw model as-is, and measure its performance. However, to maximize Tensorflow Lite performance, I install Ubuntu 19.10 aarch64 (64bit) on the RaspberryPi4 instead of Raspbian, which makes this a fairly tricky experiment. As a result, 64bit Ubuntu delivered roughly four times the performance of 32bit Raspbian. Note that this article only covers Full Integer Quantization and performance measurement; accuracy is not verified. The official tutorials are all Keras-based and I could not find any article summarizing the procedure for quantizing an already-trained model, so I stubbornly worked it out myself. As usual, for no particular reason, I insist on CPU-only inference, using neither the Neural Compute Stick 2 nor the EdgeTPU.
2.Environment
- Ubuntu 18.04 x86_64 (work PC)
- Tensorflow-GPU v1.15.0
- Tensorflow Datasets (COCO Dataset, 91 classes including 1 background class)
- RaspberryPi4 (Ubuntu 19.10 aarch64, 64bit)
- MobileNetV3-SSD
- Tensorflow Lite v1.15.0
- Bazel 0.29.1
3.Procedure
Carry out the steps below in order to finally apply Full Integer Quantization to the MobileNetV3-SSD model. The calibration dataset generation at the very end may be wrong, so please bear with me. If you notice a mistake, I would be very happy to receive a fix request.
3−1.Pre-environment preparation
Install the minimum set of packages required for the quantization work.
$ cd ~
$ sudo pip3 install tensorflow-gpu==1.15.0
$ git clone --depth 1 https://github.com/tensorflow/models.git
$ cd models/research
$ git clone https://github.com/cocodataset/cocoapi.git
$ cd cocoapi/PythonAPI
$ make
$ cp -r pycocotools ../..
$ cd ../..
$ wget -O protobuf.zip https://github.com/google/protobuf/releases/download/v3.0.0/protoc-3.0.0-linux-x86_64.zip
$ unzip protobuf.zip
$ ./bin/protoc object_detection/protos/*.proto --python_out=.
$ sudo apt-get install -y protobuf-compiler python3-pil python3-lxml python3-tk
$ sudo -H pip3 install Cython contextlib2 jupyter matplotlib
3−2.Download Tensorflow official trained model
Download and extract the official pre-trained MobileNetV3-SSD models that I have stashed in my Google Drive. Both the Small and Large variants are downloaded.
$ export PYTHONPATH=${PWD}:${PWD}/object_detection:${PWD}/slim:${PYTHONPATH}
$ mkdir -p ssd_mobilenet_v3_small_coco_2019_08_14 && cd ssd_mobilenet_v3_small_coco_2019_08_14
$ curl -sc /tmp/cookie "https://drive.google.com/uc?export=download&id=1uqaC0Y-yRtzkpu1EuZ3BzOyh9-i_3Qgi" > /dev/null
$ CODE="$(awk '/_warning_/ {print $NF}' /tmp/cookie)"
$ curl -Lb /tmp/cookie "https://drive.google.com/uc?export=download&confirm=${CODE}&id=1uqaC0Y-yRtzkpu1EuZ3BzOyh9-i_3Qgi" -o ssd_mobilenet_v3_small_coco_2019_08_14.tar.gz
$ tar -zxvf ssd_mobilenet_v3_small_coco_2019_08_14.tar.gz
$ rm ssd_mobilenet_v3_small_coco_2019_08_14.tar.gz
$ cd ..
$ mkdir -p ssd_mobilenet_v3_large_coco_2019_08_14 && cd ssd_mobilenet_v3_large_coco_2019_08_14
$ curl -sc /tmp/cookie "https://drive.google.com/uc?export=download&id=1NGLjKRWDQZ_kibQHlLZ7Eetuuz1waC7X" > /dev/null
$ CODE="$(awk '/_warning_/ {print $NF}' /tmp/cookie)"
$ curl -Lb /tmp/cookie "https://drive.google.com/uc?export=download&confirm=${CODE}&id=1NGLjKRWDQZ_kibQHlLZ7Eetuuz1waC7X" -o ssd_mobilenet_v3_large_coco_2019_08_14.tar.gz
$ tar -zxvf ssd_mobilenet_v3_large_coco_2019_08_14.tar.gz
$ rm ssd_mobilenet_v3_large_coco_2019_08_14.tar.gz
$ cd ..
3−3.Create a conversion script from checkpoint format to saved_model format
Create a script (freeze_the_saved_model.py) that performs model optimization and conversion to saved_model format in one go. This is the custom sequence that does not appear in the official tutorial.
import tensorflow as tf
import os
import shutil
from tensorflow.python.saved_model import tag_constants
from tensorflow.python.tools import freeze_graph
from tensorflow.python import ops
from tensorflow.tools.graph_transforms import TransformGraph
def freeze_model(saved_model_dir, output_node_names, output_filename):
output_graph_filename = os.path.join(saved_model_dir, output_filename)
initializer_nodes = ''
freeze_graph.freeze_graph(
input_saved_model_dir=saved_model_dir,
output_graph=output_graph_filename,
saved_model_tags = tag_constants.SERVING,
output_node_names=output_node_names,
initializer_nodes=initializer_nodes,
input_graph=None,
input_saver=False,
input_binary=False,
input_checkpoint=None,
restore_op_name=None,
filename_tensor_name=None,
clear_devices=True,
input_meta_graph=False,
)
def get_graph_def_from_file(graph_filepath):
tf.reset_default_graph()
with ops.Graph().as_default():
with tf.gfile.GFile(graph_filepath, 'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
return graph_def
def optimize_graph(model_dir, graph_filename, transforms, input_name, output_names, outname='optimized_model.pb'):
input_names = [input_name] # change this as per how you have saved the model
graph_def = get_graph_def_from_file(os.path.join(model_dir, graph_filename))
optimized_graph_def = TransformGraph(
graph_def,
input_names,
output_names,
transforms)
tf.train.write_graph(optimized_graph_def,
logdir=model_dir,
as_text=False,
name=outname)
print('Graph optimized!')
def convert_graph_def_to_saved_model(export_dir, graph_filepath, input_name, outputs):
graph_def = get_graph_def_from_file(graph_filepath)
with tf.Session(graph=tf.Graph()) as session:
tf.import_graph_def(graph_def, name='')
tf.compat.v1.saved_model.simple_save(
session,
export_dir,# change input_image to node.name if you know the name
inputs={input_name: session.graph.get_tensor_by_name('{}:0'.format(node.name))
for node in graph_def.node if node.op=='Placeholder'},
outputs={t.rstrip(":0"):session.graph.get_tensor_by_name(t) for t in outputs}
)
print('Optimized graph converted to SavedModel!')
tf.compat.v1.enable_eager_execution()
# Look up the name of the placeholder for the input node
graph_def=get_graph_def_from_file('./ssd_mobilenet_v3_small_coco_2019_08_14/frozen_inference_graph.pb')
input_name_small=""
for node in graph_def.node:
if node.op=='Placeholder':
print("##### ssd_mobilenet_v3_small_coco_2019_08_14 - Input Node Name #####", node.name) # this will be the input node
input_name_small=node.name
# Look up the name of the placeholder for the input node
graph_def=get_graph_def_from_file('./ssd_mobilenet_v3_large_coco_2019_08_14/frozen_inference_graph.pb')
input_name_large=""
for node in graph_def.node:
if node.op=='Placeholder':
print("##### ssd_mobilenet_v3_large_coco_2019_08_14 - Input Node Name #####", node.name) # this will be the input node
input_name_large=node.name
# ssd_mobilenet_v3 output names
output_node_names = ['raw_outputs/class_predictions','raw_outputs/box_encodings']
outputs = ['raw_outputs/class_predictions:0','raw_outputs/box_encodings:0']
# Optimizing the graph via TensorFlow library
transforms = []
optimize_graph('./ssd_mobilenet_v3_small_coco_2019_08_14', 'frozen_inference_graph.pb', transforms, input_name_small, output_node_names, outname='optimized_model_small.pb')
optimize_graph('./ssd_mobilenet_v3_large_coco_2019_08_14', 'frozen_inference_graph.pb', transforms, input_name_large, output_node_names, outname='optimized_model_large.pb')
# convert this to a TF Serving compatible model - ssd_mobilenet_v3_small_coco_2019_08_14
shutil.rmtree('./ssd_mobilenet_v3_small_coco_2019_08_14/0', ignore_errors=True)
convert_graph_def_to_saved_model('./ssd_mobilenet_v3_small_coco_2019_08_14/0',
'./ssd_mobilenet_v3_small_coco_2019_08_14/optimized_model_small.pb', input_name_small, outputs)
# convert this to a TF Serving compatible model - ssd_mobilenet_v3_large_coco_2019_08_14
shutil.rmtree('./ssd_mobilenet_v3_large_coco_2019_08_14/0', ignore_errors=True)
convert_graph_def_to_saved_model('./ssd_mobilenet_v3_large_coco_2019_08_14/0',
'./ssd_mobilenet_v3_large_coco_2019_08_14/optimized_model_large.pb', input_name_large, outputs)
Run the model optimization and the conversion to saved_model format. The saved_model files are generated inside a folder named "0".
$ python3 freeze_the_saved_model.py
Analyze the generated saved_model files and take a peek at the INPUT and OUTPUT structures. They can be visualized with the saved_model_cli command that ships with Tensorflow.
$ saved_model_cli show --dir ./ssd_mobilenet_v3_small_coco_2019_08_14/0 --all
MetaGraphDef with tag-set: 'serve' contains the following SignatureDefs:
signature_def['serving_default']:
The given SavedModel SignatureDef contains the following input(s):
inputs['normalized_input_image_tensor'] tensor_info:
dtype: DT_FLOAT
shape: (1, 320, 320, 3)
name: normalized_input_image_tensor:0
The given SavedModel SignatureDef contains the following output(s):
outputs['raw_outputs/box_encodings'] tensor_info:
dtype: DT_FLOAT
shape: (1, 2034, 4)
name: raw_outputs/box_encodings:0
outputs['raw_outputs/class_predictions'] tensor_info:
dtype: DT_FLOAT
shape: (1, 2034, 91)
name: raw_outputs/class_predictions:0
Method name is: tensorflow/serving/predict
$ saved_model_cli show --dir ./ssd_mobilenet_v3_large_coco_2019_08_14/0 --all
MetaGraphDef with tag-set: 'serve' contains the following SignatureDefs:
signature_def['serving_default']:
The given SavedModel SignatureDef contains the following input(s):
inputs['normalized_input_image_tensor'] tensor_info:
dtype: DT_FLOAT
shape: (1, 320, 320, 3)
name: normalized_input_image_tensor:0
The given SavedModel SignatureDef contains the following output(s):
outputs['raw_outputs/box_encodings'] tensor_info:
dtype: DT_FLOAT
shape: (1, 2034, 4)
name: raw_outputs/box_encodings:0
outputs['raw_outputs/class_predictions'] tensor_info:
dtype: DT_FLOAT
shape: (1, 2034, 91)
name: raw_outputs/class_predictions:0
Method name is: tensorflow/serving/predict
The COCO dataset is a huge dataset of roughly 25GB, so seriously converting it to TFRecords format from scratch would take several hours. To save the trouble, download only the Test preset portion of the TFRecord dataset I generated from my Google Drive.
$ curl -sc /tmp/cookie "https://drive.google.com/uc?export=download&id=1Uk9F4Tc-9UgnvARIVkloSoePUynyST6E" > /dev/null
$ CODE="$(awk '/_warning_/ {print $NF}' /tmp/cookie)"
$ curl -Lb /tmp/cookie "https://drive.google.com/uc?export=download&confirm=${CODE}&id=1Uk9F4Tc-9UgnvARIVkloSoePUynyST6E" -o TFDS.tar.gz
$ tar -zxvf TFDS.tar.gz
$ rm TFDS.tar.gz
Create Python scripts that run Weight Quantization, Integer Quantization and Full Integer Quantization in one go (saved as quantization_ssd_mobilenet_v3_small_coco_2019_08_14.py and quantization_ssd_mobilenet_v3_large_coco_2019_08_14.py below). To write an efficient program, all of the literal file paths should really be turned into variables, but I hammered this out by feel, so please forgive me.
import tensorflow as tf
import tensorflow_datasets as tfds
import numpy as np
def representative_dataset_gen():
for data in raw_test_data.take(100):
image = data['image'].numpy()
image = tf.image.resize(image, (320, 320))
image = image[np.newaxis,:,:,:]
yield [image]
tf.compat.v1.enable_eager_execution()
# Generating a calibration data set
#raw_test_data, info = tfds.load(name="coco/2017", with_info=True, split="test", data_dir="./TFDS")
raw_test_data, info = tfds.load(name="coco/2017", with_info=True, split="test", data_dir="./TFDS", download=False)
print(info)
# Weight Quantization - Input/Output=float32
converter = tf.lite.TFLiteConverter.from_saved_model('./ssd_mobilenet_v3_small_coco_2019_08_14/0')
converter.optimizations = [tf.lite.Optimize.OPTIMIZE_FOR_SIZE]
tflite_quant_model = converter.convert()
with open('./ssd_mobilenet_v3_small_coco_2019_08_14/mobilenet_v3_small_weight_quant.tflite', 'wb') as w:
w.write(tflite_quant_model)
print("Weight Quantization complete! - mobilenet_v3_small_weight_quant.tflite")
# Integer Quantization - Input/Output=float32
converter = tf.lite.TFLiteConverter.from_saved_model('./ssd_mobilenet_v3_small_coco_2019_08_14/0')
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = representative_dataset_gen
tflite_quant_model = converter.convert()
with open('./ssd_mobilenet_v3_small_coco_2019_08_14/mobilenet_v3_small_integer_quant.tflite', 'wb') as w:
w.write(tflite_quant_model)
print("Integer Quantization complete! - mobilenet_v3_small_integer_quant.tflite")
# Full Integer Quantization - Input/Output=int8
converter = tf.lite.TFLiteConverter.from_saved_model('./ssd_mobilenet_v3_small_coco_2019_08_14/0')
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = representative_dataset_gen
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
converter.inference_input_type = tf.uint8
converter.inference_output_type = tf.uint8
tflite_quant_model = converter.convert()
with open('./ssd_mobilenet_v3_small_coco_2019_08_14/mobilenet_v3_small_full_integer_quant.tflite', 'wb') as w:
w.write(tflite_quant_model)
print("Full Integer Quantization complete! - mobilenet_v3_small_full_integer_quant.tflite")
import tensorflow as tf
import tensorflow_datasets as tfds
import numpy as np
def representative_dataset_gen():
for data in raw_test_data.take(100):
image = data['image'].numpy()
image = tf.image.resize(image, (320, 320))
image = image[np.newaxis,:,:,:]
yield [image]
tf.compat.v1.enable_eager_execution()
# Generating a calibration data set
#raw_test_data, info = tfds.load(name="coco/2017", with_info=True, split="test", data_dir="./TFDS")
raw_test_data, info = tfds.load(name="coco/2017", with_info=True, split="test", data_dir="./TFDS", download=False)
# Weight Quantization - Input/Output=float32
converter = tf.lite.TFLiteConverter.from_saved_model('./ssd_mobilenet_v3_large_coco_2019_08_14/0')
converter.optimizations = [tf.lite.Optimize.OPTIMIZE_FOR_SIZE]
tflite_quant_model = converter.convert()
with open('./ssd_mobilenet_v3_large_coco_2019_08_14/mobilenet_v3_large_weight_quant.tflite', 'wb') as w:
w.write(tflite_quant_model)
print("Weight Quantization complete! - mobilenet_v3_large_weight_quant.tflite")
# Integer Quantization - Input/Output=float32
converter = tf.lite.TFLiteConverter.from_saved_model('./ssd_mobilenet_v3_large_coco_2019_08_14/0')
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = representative_dataset_gen
tflite_quant_model = converter.convert()
with open('./ssd_mobilenet_v3_large_coco_2019_08_14/mobilenet_v3_large_integer_quant.tflite', 'wb') as w:
w.write(tflite_quant_model)
print("Integer Quantization complete! - mobilenet_v3_large_integer_quant.tflite")
# Full Integer Quantization - Input/Output=int8
converter = tf.lite.TFLiteConverter.from_saved_model('./ssd_mobilenet_v3_large_coco_2019_08_14/0')
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = representative_dataset_gen
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
converter.inference_input_type = tf.uint8
converter.inference_output_type = tf.uint8
tflite_quant_model = converter.convert()
with open('./ssd_mobilenet_v3_large_coco_2019_08_14/mobilenet_v3_large_full_integer_quant.tflite', 'wb') as w:
w.write(tflite_quant_model)
print("Full Integer Quantization complete! - mobilenet_v3_large_full_integer_quant.tflite")
Run Weight Quantization, Integer Quantization and Full Integer Quantization against both models in one go.
$ python3 quantization_ssd_mobilenet_v3_small_coco_2019_08_14.py
$ python3 quantization_ssd_mobilenet_v3_large_coco_2019_08_14.py
Install Ubuntu 19.10 aarch64 on the RaspberryPi4 by referring to How to install Ubuntu 19.10 aarch64 (64bit) on RaspberryPi4.
Next, copy the Full Integer Quantization'd .tflite file created above, mobilenet_v3_small_full_integer_quant.tflite, to the HOME directory (e.g. /home/pi) on the RaspberryPi4.
Measure the performance of mobilenet_v3_small_full_integer_quant.tflite using the TFLite Model Benchmark Tool, a standard Tensorflow tool. Because this model does not include the Post-Process, it does less work than the officially published model and its performance is slightly higher. All of the steps below are performed on the RaspberryPi4.
$ sudo apt-get install wget curl git zip unzip python3-pil \
python3-opencv python3-pip libhdf5-dev openjdk-8-jdk net-tools
$ sudo -H pip3 install pip --upgrade
## Bazel for RaspberryPi3/4 Ubuntu 19.10 install
$ wget https://github.com/PINTO0309/Bazel_bin/raw/master/0.29.1/Ubuntu1910_aarch64/openjdk-8-jdk/install.sh
$ ./install.sh
## Clone Tensorflow v1.15.0
$ git clone -b v1.15.0 --depth 1 https://github.com/tensorflow/tensorflow.git
$ cd tensorflow
## Build and run TFLite Model Benchmark Tool
$ bazel run -c opt tensorflow/lite/tools/benchmark:benchmark_model -- \
--graph=${HOME}/mobilenet_v3_small_full_integer_quant.tflite \
--num_threads=4 \
--warmup_runs=1 \
--enable_op_profiling=true
Below are the results of measuring the model's performance. The avg=28121.2 value is expressed in microseconds, i.e. the average execution time of a single inference is about 28 ms. In practice you will implement things such as drawing bounding boxes in the UI before and after inference, so keep in mind that the end-to-end time will be slower by the cost of that pre- and post-processing.
Number of nodes executed: 176
============================== Summary by node type ==============================
[Node type] [count] [avg ms] [avg %] [cdf %] [mem KB] [times called]
CONV_2D 61 10.255 36.582% 36.582% 0.000 61
DEPTHWISE_CONV_2D 27 5.058 18.043% 54.625% 0.000 27
MUL 26 5.056 18.036% 72.661% 0.000 26
ADD 14 4.424 15.781% 88.442% 0.000 14
QUANTIZE 13 1.633 5.825% 94.267% 0.000 13
HARD_SWISH 10 0.918 3.275% 97.542% 0.000 10
LOGISTIC 1 0.376 1.341% 98.883% 0.000 1
AVERAGE_POOL_2D 9 0.199 0.710% 99.593% 0.000 9
CONCATENATION 2 0.084 0.300% 99.893% 0.000 2
RESHAPE 13 0.030 0.107% 100.000% 0.000 13
Timings (microseconds): count=50 first=28827 curr=28176 min=27916 max=28827 avg=28121.2 std=165
Memory (bytes): count=0
176 nodes observed
The structure of the model after Full Integer Quantization is shown in the figure below. Since quantization was done in one shot without the Pipeline described in the official tutorial, the post-processing part keeps the structure of the original model. You therefore need to decompose the output Tensors yourself and extract the inference results, using https://github.com/tensorflow/models/blob/master/research/object_detection/export_tflite_ssd_graph.py and https://github.com/tensorflow/models/blob/master/research/object_detection/export_tflite_ssd_graph_lib.py as references; see the interpreter sketch after the figure below.
**Graph structure of MobileNetV3-SSD**
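As a starting point, here is a minimal sketch (my own illustration, not part of the original procedure) of pulling the raw tensors out of the quantized model with tf.lite.Interpreter. The model path matches the file generated in section 3, but test.jpg is a hypothetical placeholder image, and the simple dequantization shown here does not claim to reproduce the exact preprocessing of the original training pipeline.

```python
import numpy as np
from PIL import Image
import tensorflow as tf

# Path of the Full Integer Quantization model generated in section 3;
# 'test.jpg' is a placeholder image of your own
MODEL_PATH = './ssd_mobilenet_v3_small_coco_2019_08_14/mobilenet_v3_small_full_integer_quant.tflite'

interpreter = tf.lite.Interpreter(model_path=MODEL_PATH)
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

# The Full Integer model expects a uint8 tensor of shape [1, 320, 320, 3]
image = Image.open('test.jpg').convert('RGB').resize((320, 320))
input_data = np.expand_dims(np.asarray(image, dtype=np.uint8), axis=0)
interpreter.set_tensor(input_details[0]['index'], input_data)
interpreter.invoke()

# Raw outputs: raw_outputs/box_encodings [1, 2034, 4] and
# raw_outputs/class_predictions [1, 2034, 91], still quantized as uint8
for detail in output_details:
    raw = interpreter.get_tensor(detail['index'])
    scale, zero_point = detail['quantization']
    if raw.dtype == np.uint8 and scale > 0:
        # Dequantize back to float32 using the scale/zero-point reported by the interpreter
        raw = (raw.astype(np.float32) - zero_point) * scale
    print(detail['name'], raw.shape, raw.dtype)
```

The dequantized arrays can then be decoded into boxes and scores with the helpers shown in the Appendix.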
4.Finally
I spent a ridiculous amount of time on this, but ran out of energy before implementing visualization of the inference results... If I feel like it, I will try to tackle the Post-process implementation and the accuracy measurement.
5.Reference articles
- [deeplab] what's the parameters of the mobilenetv3 pretrained model?
- When you want to fine-tune DeepLab on other datasets, there are a few cases
- [deeplab] Training deeplab model with ADE20K dataset
- Running DeepLab on PASCAL VOC 2012 Semantic Segmentation Dataset
- Quantize DeepLab model for faster on-device inference
- https://github.com/tensorflow/models/blob/master/research/deeplab/g3doc/model_zoo.md
- https://github.com/tensorflow/models/blob/master/research/deeplab/g3doc/quantize.md
- the quantized form of Shape operation is not yet implemented
- Post-training quantization
- Converter command line reference
- Quantization-aware training
- Converting a .pb file to .meta in TF 1.3
- Minimal code to load a trained TensorFlow model from a checkpoint and export it with SavedModelBuilder
- How to restore Tensorflow model from .pb file in python?
- Error with tag-sets when serving model using tensorflow_model_server tool
- ValueError: No 'serving_default' in the SavedModel's SignatureDefs. Possible values are 'name_of_my_model'
- Steps to deploy a Keras model - How to create a Signature
- Loading the graph of a model trained with TensorFlow using tf.train.import_meta_graph
- Tensorflow graph operations Part 1
- Configure input_map when importing a tensorflow model from metagraph file
- TFLite Model Benchmark Tool
- How to install Ubuntu 19.10 aarch64 (64bit) on RaspberryPi4
- https://zhuanlan.zhihu.com/p/90690452
6.Appendix
1. Extract and save the anchors values
import tensorflow as tf
from tensorflow.python.platform import gfile
from tensorflow.python.framework import tensor_util
import numpy as np
GRAPH_PB_PATH = './tflite_graph.pb' #path to your .pb file
with tf.Session() as sess:
with gfile.FastGFile(GRAPH_PB_PATH,'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
sess.graph.as_default()
tf.import_graph_def(graph_def, name='')
graph_nodes=[n for n in graph_def.node]
wts = [n for n in graph_nodes if n.op=='Const']
for n in wts:
if n.name == 'anchors':
print("Name of the node - %s" % n.name)
print("Value - ")
anchors = tensor_util.MakeNdarray(n.attr['value'].tensor)
print("anchors.shape =", anchors.shape)
print(anchors)
np.save('./anchors.npy', anchors)
np.savetxt('./anchors.csv', anchors, delimiter=',')
break
2. Load the anchors values
import numpy as np
anchors = np.load('./anchors.npy')
print(anchors)
3. Decode raw_outputs/box_encodings
https://stackoverflow.com/questions/54436186/how-to-decode-raw-outputs-box-encodings-from-tensorflow-object-detection-ssd-mob
import math
import numpy as np
y_scale = 10.0
x_scale = 10.0
h_scale = 5.0
w_scale = 5.0
# ssdlite_mobilenet_v2
### box_encoding = [1917, 4]
### anchors = [1917, 4]
### num_boxes = 1917
# (for the MobileNetV3-SSD models in this article: box_encoding = [2034, 4], anchors = [2034, 4], num_boxes = 2034)
# box_encoding[i][0] = box_centersize.y
# box_encoding[i][1] = box_centersize.x
# box_encoding[i][2] = box_centersize.h
# box_encoding[i][3] = box_centersize.w
# anchors[i][0] = anchor.y
# anchors[i][1] = anchor.x
# anchors[i][2] = anchor.h
# anchors[i][3] = anchor.w
def decode_box_encodings(box_encoding, anchors, num_boxes):
decoded_boxes = np.zeros((num_boxes, 4), dtype=np.float32)
for i in range(num_boxes):
ycenter = box_encoding[i][0] / y_scale * anchors[i][2] + anchors[i][0]
xcenter = box_encoding[i][1] / x_scale * anchors[i][3] + anchors[i][1]
half_h = 0.5 * math.exp((box_encoding[i][2] / h_scale)) * anchors[i][2]
half_w = 0.5 * math.exp((box_encoding[i][3] / w_scale)) * anchors[i][3]
decoded_boxes[i][0] = (ycenter - half_h) # ymin
decoded_boxes[i][1] = (xcenter - half_w) # xmin
decoded_boxes[i][2] = (ycenter + half_h) # ymax
decoded_boxes[i][3] = (xcenter + half_w) # xmax
return decoded_boxes
import numpy as np
max_detections = 10
non_max_suppression_score_threshold = 0.3
intersection_over_union_threshold = 0.6
def Non_Maximum_Suprression(box_encoding, class_predictions):
    # Best score and class index per box, excluding the background class (column 0);
    # the resulting class_idx is therefore offset by 1 from the full label map
    val, idx = class_predictions[:,1:].max(axis=1), \
               class_predictions[:,1:].argmax(axis=1)
    thresh_val, thresh_idx = np.array(val)[val>=non_max_suppression_score_threshold], \
                             np.array(idx)[val>=non_max_suppression_score_threshold]
    thresh_box = np.array(box_encoding)[val>=non_max_suppression_score_threshold]
    # Columns: [ymin, xmin, ymax, xmax, class_idx, prob]
    thresh_box_stack = np.hstack((thresh_box, thresh_idx[:, np.newaxis], thresh_val[:, np.newaxis]))
    # Process candidates in descending score order
    thresh_box_desc = thresh_box_stack[np.argsort(thresh_box_stack[:, 5])[::-1]]
    # Column 6: 1 = still an active candidate, 0 = already selected or suppressed
    active_box_candidate = np.ones((thresh_box_desc.shape[0], 1))
    thresh_box_stack = np.hstack((thresh_box_desc, active_box_candidate))
    num_boxes_kept = thresh_box_stack.shape[0]
    num_active_candidate = num_boxes_kept
    output_size = min(num_active_candidate, max_detections)
    num_selected_count = 0
    selected = []
    for i in range(num_boxes_kept):
        if (num_active_candidate == 0 or num_selected_count >= output_size):
            break
        if (thresh_box_stack[i, 6] == 1):
            # Select this box
            selected.append(i)
            thresh_box_stack[i, 6] = 0
            num_active_candidate -= 1
            num_selected_count += 1
        else:
            continue
        # thresh_box_stack = [ymin, xmin, ymax, xmax, class_idx, prob, active]
        for j in range(i + 1, num_boxes_kept):
            if (thresh_box_stack[j, 6] == 1):
                intersection_over_union = ComputeIntersectionOverUnion(thresh_box_stack[i], thresh_box_stack[j])
                if (intersection_over_union > intersection_over_union_threshold):
                    # Suppress the lower-scoring overlapping box
                    thresh_box_stack[j, 6] = 0
                    num_active_candidate -= 1
    return thresh_box_stack[selected]
def ComputeIntersectionOverUnion(box_i, box_j):
area_i = (box_i[2] - box_i[0]) * (box_i[3] - box_i[1])
area_j = (box_j[2] - box_j[0]) * (box_j[3] - box_j[1])
if (area_i <= 0 or area_j <= 0):
return 0.0
intersection_ymin = max(box_i[0], box_j[0])
intersection_xmin = max(box_i[1], box_j[1])
intersection_ymax = min(box_i[2], box_j[2])
intersection_xmax = min(box_i[3], box_j[3])
intersection_area = max(intersection_ymax - intersection_ymin, 0.0) * max(intersection_xmax - intersection_xmin, 0.0)
return intersection_area / (area_i + area_j - intersection_area)
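To close the loop, here is a hedged usage sketch of how the pieces in this Appendix might be chained together. The raw_box_encodings.npy / raw_class_predictions.npy file names are hypothetical placeholders for the dequantized interpreter outputs (with the batch dimension removed), anchors.npy is the file saved in Appendix 1, and decode_box_encodings / Non_Maximum_Suprression are the functions defined above.

```python
import numpy as np

# Hypothetical dumps of the dequantized raw outputs, batch dimension removed
box_encodings = np.load('./raw_box_encodings.npy')          # float32, [2034, 4]
class_predictions = np.load('./raw_class_predictions.npy')  # float32, [2034, 91]
anchors = np.load('./anchors.npy')                          # [2034, 4], from Appendix 1

decoded_boxes = decode_box_encodings(box_encodings, anchors, box_encodings.shape[0])
detections = Non_Maximum_Suprression(decoded_boxes, class_predictions)

# Each kept row: [ymin, xmin, ymax, xmax, class_idx, score, active_flag];
# boxes are in normalized coordinates, and class_idx skips the background class,
# so add 1 to map back to the 91-class label indices
for det in detections:
    print('class:', int(det[4]) + 1, 'score:', float(det[5]), 'box:', det[:4])
```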