

[English ver.] [Tensorflow Lite] Various Neural Network Model quantization methods for Tensorflow Lite (Weight Quantization, Integer Quantization, Full Integer Quantization, Float16 Quantization, EdgeTPU). As of May 05, 2020.

Posted at 2020-05-09




1. Introduction

In this article, I'd like to share the quantization workflow I have been working on for the past six months. It is a distillation of the know-how for converting Tensorflow checkpoints (.ckpt/.meta), FreezeGraph (.pb), saved_model (.pb), keras_model (.h5), Tensorflow.js models, and PyTorch checkpoints (.pth) into quantized models for Tensorflow Lite. Since Tensorflow was upgraded from v1.x to v2.x, special steps are needed to absorb the differences between versions, and I often felt there was not enough material to get started on the conversion. Tensorflow, Tensorflow Lite, Keras, ONNX, PyTorch, and OpenVINO (OpenCV) are all used in combination.

I work hard on quantization of Neural Networks every day. I use lightweight models to mass-produce quantized models, with the goal of fast inference without a GPU on edge devices such as the RaspberryPi4. As an example, I performed 8-bit integer quantization on two models, and the result of multistage inference between them using only the RaspberryPi4 CPU is shown in the following video. The two quantized models, Object Detection (MobileNetV2-SSDLite dm=0.5) and Head Pose Estimation, are run in series.


If you find the video too small to watch, PINTO_model_zoo has an enlarged sample GIF, which can be viewed over Wi-Fi or wired.

2. Table of contents

1. Introduction
2. Table of contents
3. Environment
4. Procedure
 4-1. Check the model's INPUT and OUTPUT names and types, and change the batch size and type
  4-1-1. In the case of a Tensorflow checkpoint
  4-1-2. In the case of Tensorflow Freeze_Graph
  4-1-3. In the case of Tensorflow saved_model
  4-1-4. In the case of Tensorflow/Keras .h5/.json

 4-2. Various quantization procedures
  4-2-1. Quantization from a Tensorflow checkpoint (.ckpt)
   4-2-1-1. Generating .meta from .index and .data-00000-of-00001
   4-2-1-2. Generate Freeze_Graph from checkpoint (.meta)
   4-2-1-3. Generate a saved_model from Freeze_Graph
   4-2-1-4. Weight Quantization from saved_model (weight-only quantization)
   4-2-1-5. Integer Quantization from saved_model (8-bit integer quantization)
   4-2-1-6. saved_model to Full Integer Quantization (all 8-bit integer quantization)
   4-2-1-7. Float16 Quantization from saved_model (Float16 quantization)
   4-2-1-8. Full Integer Quantization to EdgeTPU convert

  4-2-2. Quantization from a Tensorflow checkpoint (.meta)

  4-2-3. Quantization from Tensorflow Freeze_Graph (.pb)

  4-2-4. Quantization from Tensorflow saved_model (.pb)

  4-2-5. Quantization from Tensorflow/Keras (.h5/.json)
   4-2-5-1. Weight Quantization from .h5/.json (weight quantization)
   4-2-5-2. Generating the calibration data set
   4-2-5-3. Integer Quantization from .h5/.json (8-bit integer quantization)
   4-2-5-4. Full Integer Quantization from .h5/.json (all 8-bit integer quantization)
   4-2-5-5. Float16 Quantization from .h5/.json (Float16 quantization)
   4-2-5-6. Full Integer Quantization to EdgeTPU convert

  4-2-6. Quantization from a model for Tensorflow.js
   4-2-6-1. Advance preparation
   4-2-6-2. Generating a saved_model from Tensorflow.js
   4-2-6-3. Import saved_model generated by Tensorflow v2.x into Tensorflow v1.x and process the input shape
   4-2-6-4. Installation of Tensorflow v2.2.0
   4-2-6-5. Weight Quantization from saved_model (Weight-only quantization)
   4-2-6-6. Integer Quantization from saved_model (8-bit integer quantization)
   4-2-6-7. Full Integer Quantization from saved_model (All 8-bit integer quantization)
   4-2-6-8. Float16 Quantization from saved_model (Float16 quantization)
   4-2-6-9. Full Integer Quantization to EdgeTPU convert

  4-2-7. Quantize the model generated by the TensorFlow Object Detection API
   4-2-7-1. Generating a .pb file with Post-Process
   4-2-7-2. Weight Quantization from Freeze_Graph (Weight-only quantization)
   4-2-7-3. Integer Quantization from Freeze_Graph (8-bit integer quantization)
   4-2-7-4. Full Integer Quantization from Freeze_Graph (All 8-bit integer quantization)
   4-2-7-5. Float16 Quantization from Freeze_Graph (Float16 quantization)
   4-2-7-6. Full Integer Quantization to EdgeTPU convert

  4-2-8. Quantize models containing operations that are not supported by Tensorflow Lite but are supported by Tensorflow
   4-2-8-1. Generate Mask-RCNN Inception V2 .pb file
   4-2-8-2. Weight Quantization of Mask-RCNN Inception V2 (Weight-only quantization)
   4-2-8-3. Float16 Quantization in Mask-RCNN Inception V2 (Float16 quantization)
   4-2-8-4. Running a model with Flex Delegate (Tensorflow Select Ops) enabled

  4-2-9. Quantization from a model for PyTorch
   4-2-9-1. Advance preparation (PyTorch->ONNX)
   4-2-9-2. ONNX->Keras conversion by onnx2keras
   4-2-9-3. Weight Quantization from saved_model (Weight-only quantization)
   4-2-9-4. Integer Quantization from saved_model (8-bit integer quantization)
   4-2-9-5. Full Integer Quantization from saved_model (All 8-bit integer quantization)
   4-2-9-6. Float16 Quantization from saved_model (Float16 quantization)
   4-2-9-7. Full Integer Quantization to EdgeTPU convert

  4-2-10. Quantization of MediaPipe's model BlazeFace(.tflite)
   4-2-10-1. Build flatc and download schema.fbs
   4-2-10-2. Download MediaPipe's BlazeFace model (.tflite)
   4-2-10-3. Converting BlazeFace(.tflite) to saved_model(.pb)
   4-2-10-4. Weight Quantization from saved_model (weight-only quantization)
   4-2-10-5. Integer Quantization from saved_model (8-bit integer quantization)
   4-2-10-6. Full Integer Quantization from saved_model (All 8-bit integer quantization)
   4-2-10-7. Float16 Quantization from saved_model (Float16 quantization)
   4-2-10-8. Full Integer Quantization to EdgeTPU convert

 4-3. Performance benchmarks for the quantization model (.tflite)
  4-3-1. Building the TFLite Model Benchmark Tool
  4-3-2. Options for the TFLite Model Benchmark Tool
  4-3-3. Benchmark example of a model that includes only standard Tensorflow Lite operations (No XNNPACK, 4 Threads)
  4-3-4. Benchmark example of a model that includes only standard Tensorflow Lite operations (XNNPACK available, 4 Threads)
  4-3-5. Benchmark examples of models with non-standard Tensorflow Lite operations (Flex enabled, no XNNPACK, 4 Threads)
  4-3-6. Benchmark examples of models with non-standard Tensorflow Lite operations (Flex enabled, with XNNPACK, 4 Threads)
  4-3-7. Execution log sample of Benchmark_Tool

5. Finally

6. Reference articles

3. Environment

  1. Tensorflow-GPU v1.15.2
  2. Tensorflow v2.1.0, v2.2.0 or tf-nightly
  3. Tensorflow Lite with an accelerated and tuned Python API
  4. PyTorch
  5. Caffe
  6. OpenVINO 2020.2
  7. OpenCV 4.2
  8. onnx2keras
  9. Netron
  10. RaspberryPi4 + Ubuntu aarch64

- Go to Table of contents -

4. Procedure

4-1. Check the model's INPUT and OUTPUT names and types, and change the batch size and type

This is the hardest and most time-consuming first step. The amount of effort depends on the pattern of the conversion.

- Go to Table of contents -

4-1-1. In the case of a Tensorflow checkpoint

This pattern applies when neither Freeze_Graph nor saved_model is provided. The quickest approach in this pattern is to read or run the sample code for the inference test. Using Netron or Tensorflow's official visualization tools (Tensorboard or summarize_graph), it is not impossible to see the structure of the model, but visualization often fails because many operations needed only for training remain in the file; even when visualization succeeds, the graph is so huge that it is hard to find the INPUT and OUTPUT.

Now, as an example, let's check the INPUT/OUTPUT of the White-box-Cartoonization model, which converts live-action photos into anime-style images. In order to cover the various elements comprehensively, I have deliberately selected a fairly difficult model with many steps involved. Note that you need to install Tensorflow v1.15.2 (the v1.x line) to work on this model. If you have already installed Tensorflow v2.x, you need to uninstall it temporarily and install the v1.x line again. Alternatively, you can work in a Docker environment with Tensorflow already installed, without polluting your environment.

Just in case, let's check the checkpoint that is provided. Yes, for some reason only the .meta file has not been committed. This is the point where most people lose motivation drastically. That's right, but if you have started reading this article, you are not an ordinary person. Anyway, since this is not a particular problem for the work ahead, we will proceed as it is.
FireShot Capture 015 - White-box-Cartoonization_test_code_saved_models at master · SystemErr_ - github.com.png

First of all, read cartoonize.py, the logic for the inference test. Then look at the contents of the cartoonize() method, which is called right after main() at the beginning of the program. With simple, beginner-friendly test code like this, it only takes a minute to find the INPUT definition. Tensorflow input operations are defined using placeholder. However, on closer inspection, the "name" attribute is not defined. In that case you should give it a name yourself, because an unnamed placeholder makes the subsequent conversion and verification work difficult. Also, the tensor definition contains None. For quantization, it is essential that the input resolution is fixed, so the input resolution is replaced with a fixed value. This time, I chose 720x720.

【Before】_White-box-Cartoonization/blob/master/test_code/cartoonize.py#L25-L32
def cartoonize(load_folder, save_folder, model_path):
    input_photo = tf.placeholder(tf.float32, [1, None, None, 3]) #<--- This is INPUT
    network_out = network.unet_generator(input_photo)
    final_out = guided_filter.guided_filter(input_photo, network_out, r=1, eps=5e-3)

    all_vars = tf.trainable_variables()
    gene_vars = [var for var in all_vars if 'generator' in var.name]
    saver = tf.train.Saver(var_list=gene_vars)
【After】_White-box-Cartoonization/blob/master/test_code/cartoonize.py#L25-L32
def cartoonize(load_folder, save_folder, model_path):
    input_photo = tf.placeholder(tf.float32, [1, 720, 720, 3], name='input') #<--- This is INPUT
    network_out = network.unet_generator(input_photo)
    final_out = guided_filter.guided_filter(input_photo, network_out, r=1, eps=5e-3)

    all_vars = tf.trainable_variables()
    gene_vars = [var for var in all_vars if 'generator' in var.name]
    saver = tf.train.Saver(var_list=gene_vars)

The following is a note about placeholder in quantization.
 1. N (batch size), H (height), W (width) and C (RGB channel) must all be fixed with integer values.
 2. If the model is defined in NCHW (channel-first) order, it should be redefined or converted to NHWC (channel-last).
 3. The type of the placeholder should be tf.float32 (many models define it as tf.uint8).
 4. The tf.cast operation does not support quantization and should be removed if possible.

By the way, the reason I chose 720x720 this time is that the image preprocessing logic, shown below, resizes the image so that neither height nor width goes below 720 pixels. Depending on the model there may be a limit to the size that can be specified, so you have to read the inference test logic or the training logic. White-box-Cartoonization raised an error and could not run inference when the resolution was smaller than 720x720. If you don't want to read the logic, experiment repeatedly to find how small the resolution can be.

White-box-Cartoonization/blob/master/test_code/cartoonize.py#L11-L22
def resize_crop(image):
    h, w, c = np.shape(image)
    if min(h, w) > 720:
        if h > w:
            h, w = int(720*h/w), 720
        else:
            h, w = 720, int(720*w/h)
    image = cv2.resize(image, (w, h),
                       interpolation=cv2.INTER_AREA)
    h, w = (h//8)*8, (w//8)*8
    image = image[:h, :w, :]
    return image

Since many published models deal with images, the placeholder is often defined as a uint8 type that takes RGB values of 0-255, as shown below. In most cases, the step right after the placeholder converts the type to float32 with a tf.cast operation. As mentioned above, tf.cast causes an error when quantizing, so you should remove it at this point. When you actually run inference, the image data read by OpenCV or Pillow will be of type uint8, so you need to cast it to float32 in your own code just before handing the image to Tensorflow.

A_common_example_of_uint8_being_specified_in_the_placeholder
input_photo = tf.placeholder(tf.uint8, [1, 720, 720, 3], name='input')
casted_photo = tf.cast(input_photo, tf.float32)
Example_of_replacing_the_placeholder_with_tf.float32_and_removing_tf.cast
input_photo = tf.placeholder(tf.float32, [1, 720, 720, 3], name='input')
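
The cast to float32 then moves into your own preprocessing code. The following is a minimal sketch, assuming OpenCV, the fixed 720x720 input above, and an already-restored session; the file name sample.jpg is only a placeholder.

Example_of_casting_on_the_NumPy_side_instead_of_tf.cast
import cv2
import numpy as np

# Read with OpenCV (uint8), resize to the fixed 720x720 INPUT,
# then normalize and cast to float32 on the NumPy side.
image = cv2.imread('sample.jpg')
image = cv2.resize(image, (720, 720))
image = image.astype(np.float32) / 127.5 - 1.0   # uint8 -> float32, range [-1, 1]
image = image[np.newaxis, :, :, :]               # [720, 720, 3] -> [1, 720, 720, 3]

# output = sess.run(final_out, feed_dict={input_photo: image})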

Now let's look up the name of the OUTPUT. The inference test logic of White-box-Cartoonization is so simple that the final OUTPUT is defined just below the INPUT definition line, in a variable helpfully named final_out. There are two ways to look up the name of an OUTPUT without using Netron, summarize_graph, or Tensorboard.

 1. Digging deeper into the program and visually identifying the end of the model structure
 2. Anyway, run a test run and find out the name of the final operation in the debug print

Since I love to cut corners, I took method 2.

White-box-Cartoonization/blob/master/test_code/cartoonize.py#L25-L32
def cartoonize(load_folder, save_folder, model_path):
    input_photo = tf.placeholder(tf.float32, [1, 720, 720, 3], name='input') 
    network_out = network.unet_generator(input_photo)
    final_out = guided_filter.guided_filter(input_photo, network_out, r=1, eps=5e-3) #<--- Here's an OUTPUT
    print("input_photo.name =", input_photo.name) #<--- Added one line for debug printing of INPUT name.
    print("input_photo.shape =", input_photo.shape) #<--- Added one more line for debug printing of INPUT shapes.
    print("final_out.name =", final_out.name) #<--- Added one line for the debug print of OUTPUT name.
    print("final_out.shape =", final_out.shape) #<--- Added one line for debug printing of OUTPUT shape.

    all_vars = tf.trainable_variables()
    gene_vars = [var for var in all_vars if 'generator' in var.name]
    saver = tf.train.Saver(var_list=gene_vars)

When I ran the inference test program, the names and shapes of the INPUT and OUTPUT were printed by the debug lines added earlier. Apparently, the name of the OUTPUT is add_1:0. The name and shape of the placeholder I modified earlier are also reflected correctly.
Screenshot 2020-05-04 13:55:45.png

This concludes the procedure for checking INPUT/OUTPUT names in a checkpoint model. In the steps that follow, I'll explain how to generate the .meta file, and then how to generate Freeze_Graph and saved_model.

- Go to Table of contents -

4-1-2. In the case of Tensorflow Freeze_Graph

This is a working pattern when the model is provided in Freeze_Graph (.pb) format. It is very easy to identify the name of INPUT/OUTPUT in this case. Let's check the model of Semantic Segmentation Mobile-DeeplabV3-plus (MobileNetV2) for example.

First, download the Freeze_Graph (.pb) file from the above repository. The only thing to note here is that models with special processing such as ASPP will fail to quantize, so you should pick a model with as simple a structure as possible.
FireShot Capture 016 - nolanliou_mobile-deeplab-v3-plus_ Deeplab-V3+ model with MobilenetV2__ - github.com.png
When using the command line to download materials from Google Drive, it is necessary to bypass the confirmation dialog, so you can download by executing the 3-line command as shown below.

Download_the_model_from_Google_Drive_using_the_command_line
$ curl -sc /tmp/cookie "https://drive.google.com/uc?export=download&id=1VF5yMz_tIkTOVfgmIgg7tPAJJEEcZ49B" > /dev/null
$ CODE="$(awk '/_warning_/ {print $NF}' /tmp/cookie)"
$ curl -Lb /tmp/cookie "https://drive.google.com/uc?export=download&confirm=${CODE}&id=1VF5yMz_tIkTOVfgmIgg7tPAJJEEcZ49B" -o deeplab_v3_plus_mnv2_decoder_256.pb

Screenshot 2020-05-04 14:49:51.png
Check the structure of the model you have downloaded. There is a super useful site called Netron, which you can access first.
https://lutzroeder.github.io/netron/
Screenshot 2020-05-04 14:53:01.png
"Open Model..." button and open the file deeplab_v3_plus_mnv2_decoder_256.pb which was downloaded earlier. The names of INPUT/OUTPUT are immediately known. INPUT is Input, shape and type is Float32 [?, 256, 256, 3], OUTPUT is ArgMax, shape and type is Float32 [?, 256, 256, 3], OUTPUT is ArgMax, and the shape and type is Float32 [?, 256, 256]. At first glance, you may think that ExpandDims is appropriate for the final OUTPUT, but in fact, it's not a problem if you select ArgMax, which is almost the same as Semantic Segmentation's model. Also, the ? part of the shape is synonymous with None in variable batches. However, this is not a problem for the quantization operation without immobilization. It is automatically converted to 1 when performing the quantization of the subsequent work.
Screenshot 2020-05-04 14:55:39.png
Screenshot 2020-05-04 14:56:30.png
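
If Netron is not available, a rough Python sketch like the following (Tensorflow v1.x; the file name is the one downloaded above) can also help. It lists Placeholder nodes as INPUT candidates and the last few nodes of the graph, which usually include the OUTPUT.

list_freeze_graph_nodes.py
import tensorflow as tf

# Load the Freeze_Graph and list candidate INPUT (Placeholder) nodes
# plus the last few nodes, which usually contain the OUTPUT.
graph_def = tf.compat.v1.GraphDef()
with tf.io.gfile.GFile('./deeplab_v3_plus_mnv2_decoder_256.pb', 'rb') as f:
    graph_def.ParseFromString(f.read())

for node in graph_def.node:
    if node.op == 'Placeholder':
        print('INPUT candidate :', node.name)
for node in graph_def.node[-5:]:
    print('OUTPUT candidate:', node.name, node.op)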
Now, the work up to this point has been too easy, hasn't it? Let's raise the difficulty a bit. Next, let's try the following Python implementation of Posenet v1, which was originally a Tensorflow.js model.

First, clone the repository to get the Freeze_Graph and issue the following command. The --image_dir argument can be any folder path containing image files of people. I created a folder called images and put 24 images of people in it.
Screenshot 2020-05-04 16:02:27.png

Convert the Tensorflow.js model to a Tensorflow model. The following steps require Tensorflow v1.15.2.

Running_git_clone_and_getting_Freeze_Graph
$ sudo pip3 uninstall tensorboard-plugin-wit tb-nightly \
                      tf-estimator-nightly tensorflow-gpu \
                      tensorflow tf-nightly tensorflow_estimator
$ sudo pip3 install tensorflow-gpu==1.15.2

$ git clone https://github.com/rwightman/posenet-python.git
$ cd posenet-python
$ python3 image_demo.py \
    --model 101 \
    --image_dir ./images \
    --output_dir ./output

When the process is finished, the checkpoint files, a .pb, and a .pbtxt will be created under the _models folder.
Screenshot 2020-05-04 16:07:12.png
It's based on MobileNetV1, so the accuracy is not very good. As an aside, the repository linked at the top of this article already contains the quantized model of the high-precision but slow Posenet v2, converted from a ResNet50 base.
tennis_in_crowd.jpg
The figure below shows the result of Posenet v2 (ResNet50 backbone) inference on the same image. It's hard to tell, but the accuracy seems to have improved slightly.
80100848-80801180-85ab-11ea-9464-f5574c9bc5e7.jpg

Now, let's visualize the .pb file using Netron as before. The INPUT of the Freeze_Graph (.pb) has name=image and type/shape Float32 [1, ?, ?, 3]. As it is, the quantization operation will fail. I will now explain how to convert the INPUT shape of a Freeze_Graph model whose H (height) and W (width) are None. Note that this procedure can only be performed with Tensorflow v1.x and does not work equally well for all models.
Screenshot 2020-05-04 16:15:43.png
This section describes a program to convert the shape of the INPUT of Freeze_Graph. The general flow is as follows.
 1. Defines a placeholder with a shape to be set after conversion.
 2. Load a Freeze_Graph that you already have at hand.
 3. Import the placeholder defined in 1.
 4. ( Check the debug print to make sure the placeholder is captured correctly )
 5. Remove all unnecessary Node with TransformGraph, a Tensorflow v1.x-based tool.
 6. Output a processed Freeze_Graph to a .pb file.

replacement_of_input_placeholder_float32_mobilenet.py
### tensorflow-gpu==1.15.2

import tensorflow as tf
from tensorflow.tools.graph_transforms import TransformGraph

with tf.compat.v1.Session() as sess:

    # shape=[1, ?, ?, 3] -> shape=[1, 513, 513, 3]
    # name='image' specifies the placeholder name of the converted model
    inputs = tf.compat.v1.placeholder(tf.float32, shape=[1, 513, 513, 3], name='image')
    with tf.io.gfile.GFile('./model-mobilenet_v1_101.pb', 'rb') as f:
        graph_def = tf.compat.v1.GraphDef()
        graph_def.ParseFromString(f.read())

    # 'image:0' specifies the placeholder name of the model before conversion
    tf.graph_util.import_graph_def(graph_def, input_map={'image:0': inputs}, name='')
    print([n for n in tf.compat.v1.get_default_graph().as_graph_def().node if n.name == 'image'])

    # Delete Placeholder "image" before conversion
    # see: https://github.com/tensorflow/tensorflow/tree/master/tensorflow/tools/graph_transforms
    # TransformGraph(
    #     graph_def(),
    #     input_op_name,
    #     output_op_names,
    #     conversion options
    # )
    optimized_graph_def = TransformGraph(
                              tf.compat.v1.get_default_graph().as_graph_def(),
                              'image',
                              ['heatmap','offset_2','displacement_fwd_2','displacement_bwd_2'],
                              ['strip_unused_nodes(type=float, shape="1,513,513,3")'])

    tf.io.write_graph(optimized_graph_def, './', 'model-mobilenet_v1_101_513.pb', as_text=False)

See the Graph Transform Tool documentation for the specifications and usage of TransformGraph. Execute the INPUT shape conversion program you created. The INPUT shape of the Freeze_Graph model-mobilenet_v1_101_513.pb generated after execution becomes [1, 513, 513, 3]. If you want to convert it to a slightly smaller shape, just change it, for example to shape=[1, 257, 257, 3].

Execute_the_INPUT_shape_conversion_program_of_Freeze_Graph
$ python3 replacement_of_input_placeholder_float32_mobilenet.py

Screenshot 2020-05-04 17:31:30.png
Check the shape of the generated model-mobilenet_v1_101_513.pb with Netron.
Screenshot 2020-05-04 17:33:56.png
This is the end of the procedure to find INPUT/OUTPUT names in a Freeze_Graph model and to convert a Freeze_Graph shape.

- Go to Table of contents -

4-1-3. In the case of Tensorflow saved_model

If the model is provided in the saved_model (.pb) format, this is the pattern. Again, it is very easy to identify the name of the INPUT/OUTPUT. So far, however, I don't see many examples of pre-trained models being offered in this format. This time, let's check the names of INPUT and OUTPUT based on the following Head Pose Estimation.

First, please clone the repository.

Clone,_the_repository_for_Head_Pose_Estimation
$ git clone https://github.com/yinguobing/head-pose-estimation.git
$ cd head-pose-estimation/assets

If you display a saved_model with a very large number of operators in the Web version of Netron, a warning is shown as below, and rendering may take more than 5 minutes.
Screenshot 2020-05-04 20:02:14.png
In the case of the installed desktop version, it is displayed as shown below.
Screenshot 2020-05-04 20:00:45.png

So, if you want to check a saved_model's INPUT/OUTPUT easily, you can use the standard saved_model_cli command. There is a folder called pose_model under the assets folder with saved_model.pb deployed inside it, so run the analysis command against the pose_model folder directly under assets.
Screenshot 2020-05-04 19:55:12.png

$ saved_model_cli show --dir pose_model --all

MetaGraphDef with tag-set: 'serve' contains the following SignatureDefs:

signature_def['predict']:
  The given SavedModel SignatureDef contains the following input(s):
    inputs['image'] tensor_info:
        dtype: DT_UINT8
        shape: (-1, -1, -1, 3)
        name: image_tensor:0
  The given SavedModel SignatureDef contains the following output(s):
    outputs['output'] tensor_info:
        dtype: DT_FLOAT
        shape: (-1, 136)
        name: layer6/final_dense:0
  Method name is: tensorflow/serving/predict

signature_def['serving_default']:
  The given SavedModel SignatureDef contains the following input(s):
    inputs['image'] tensor_info:
        dtype: DT_UINT8
        shape: (-1, -1, -1, 3)
        name: image_tensor:0
  The given SavedModel SignatureDef contains the following output(s):
    outputs['output'] tensor_info:
        dtype: DT_FLOAT
        shape: (-1, 136)
        name: layer6/final_dense:0
  Method name is: tensorflow/serving/predict

Looking at the definition of signature_def['serving_default'], the INPUT is defined as image Uint8[-1, -1, -1, 3] and the OUTPUT as output Float32[-1, 136]. As in the previous example, N (batch size), H (height), and W (width) of the INPUT are set to -1, so the shape needs to be rewritten (-1 is synonymous with ? and None). Rather than my half-hearted explanation, the article Summary of SavedModel - Qiita - t_shimmura is very helpful, so please see it. In addition, as far as Head Pose Estimation is concerned, the training script and the export script to saved_model are provided in the following repository.

For saved_model, I have only described how to check the structure with the saved_model_cli command. I'll touch on this again later along with the checkpoint -> saved_model and Freeze_Graph -> saved_model conversion scripts.
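
If you prefer to stay in Python, the same information can be pulled out programmatically. The following is a minimal sketch, assuming Tensorflow v1.15.2 and the pose_model directory used above.

inspect_saved_model_signatures.py
import tensorflow as tf

# Load the saved_model into a fresh graph and print its SignatureDefs.
with tf.compat.v1.Session(graph=tf.Graph()) as sess:
    meta_graph = tf.compat.v1.saved_model.loader.load(sess, ['serve'], './pose_model')
    for sig_name, sig in meta_graph.signature_def.items():
        print('signature:', sig_name)
        for key, tensor_info in sig.inputs.items():
            print('  input :', key, tensor_info.name, tensor_info.tensor_shape)
        for key, tensor_info in sig.outputs.items():
            print('  output:', key, tensor_info.name, tensor_info.tensor_shape)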

- Go to Table of contents -

4-1-4. In the case of Tensorflow/Keras .h5/.json

If the model is provided in Keras (.h5/.json) format, this is the pattern.

For example, to check the names of the INPUTs and OUTPUTs of such a model, you will need a .json file saved after training, as shown below.

model = Model(inputs=xxxx,outputs=yyyy)
# model save
model_json = model.to_json()
open(model_path + 'model.json', 'w').write(model_json)
model.save_weights(model_path + 'weights.h5')

If you open the above example model.json with Netron, you can visualize it as shown below.
Screenshot 2020-05-04 22:52:00.png
However, the syntax of Keras is very simple, so it may be faster to look directly at the program that defines the model.
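
If you just need the INPUT/OUTPUT names and shapes, a minimal sketch like the following is enough (it assumes the model.json/weights.h5 pair saved above).

print_keras_model_io.py
import tensorflow as tf

# Load the architecture and weights, then print the INPUT/OUTPUT tensors
# and a layer summary instead of visualizing the graph.
model = tf.keras.models.model_from_json(open('model.json').read())
model.load_weights('weights.h5')
print(model.inputs)    # e.g. [<tf.Tensor 'input_1:0' shape=(None, H, W, 3) dtype=float32>]
print(model.outputs)
model.summary()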

If you want to change the shape of the input tensor, the Q&A Keras - Transfer Learning - Change Input Tensor Shape is helpful. The easiest way seems to be to redefine an empty model that differs only in INPUT size and then transfer only the weights.

Example_of_changing_the_shape_of_the_INPUT_tensor
inputs = Input((None, None, 3))
.....
model = Model(inputs=[inputs], outputs=[outputs])
model.compile(optimizer='adam', loss='mean_squared_error')
model.load_weights('my_model_name.h5')

inputs2 = Input((512, 512, 3))
....
model2 = Model(inputs=[inputs2], outputs=[outputs])
model2.compile(optimizer='adam', loss='mean_squared_error')
model2.set_weights(model.get_weights())

- Go to Table of contents -

4-2. Various quantization procedures

In this section I will quantize various models in various patterns. So that the procedure can be adapted to any situation, please forgive some duplication of content between sections, and note that some steps are not strictly necessary if you only want the shortest path to quantization.

- Go to Table of contents -

4-2-1. Quantization from a Tensorflow checkpoint (.ckpt)

Let's quantize White-box-Cartoonization, a model that converts live-action photos into anime-style images, as an example. As described in 4-1-1. In the case of a Tensorflow checkpoint, the trained checkpoint of this model is released in a special state in which only the .meta file is missing. A normal checkpoint consists of three files, .index .data-00000-of-00001 .meta, but this one is missing one of them. I'm going to explain the procedure from the point of creating .meta, but this is only because of the special circumstance that .meta doesn't exist, and it is not necessary otherwise. Please take it as an example only. Of the steps described below, 4-2-1-2. Generate Freeze_Graph from checkpoint (.meta) is performed with Tensorflow v1.15.2 due to model constraints, and 4-2-1-3. Generate a saved_model from Freeze_Graph to 4-2-1-7. Float16 Quantization from saved_model (Float16 quantization) are performed with the latest Tensorflow v2.2.x or tf-nightly to support the latest operators and avoid bugs in Tensorflow itself.

- Go to Table of contents -

4-2-1-1. Generating .meta from .index and .data-00000-of-00001

This step is not necessary if all three checkpoint files (.index, .data-00000-of-00001, .meta) are provided. This section describes how to create .meta from .index and .data-00000-of-00001, reusing the test code of White-box-Cartoonization as much as possible. The modified logic works as follows:
 1. Create a folder called export for temporary checkpoint output.
 2. Build the model.
 3. Debug-print the names and shapes of INPUT/OUTPUT.
 4. Restore the checkpoint.
 5. Immediately save it to the export folder.

【Before】_White-box-Cartoonization/blob/master/test_code/cartoonize.py#L25-L39
def cartoonize(load_folder, save_folder, model_path):
    input_photo = tf.placeholder(tf.float32, [1, None, None, 3])
    network_out = network.unet_generator(input_photo)
    final_out = guided_filter.guided_filter(input_photo, network_out, r=1, eps=5e-3)


    all_vars = tf.trainable_variables()
    gene_vars = [var for var in all_vars if 'generator' in var.name]
    saver = tf.train.Saver(var_list=gene_vars)

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)


    sess.run(tf.global_variables_initializer())
    saver.restore(sess, tf.train.latest_checkpoint(model_path))
【After】_White-box-Cartoonization/blob/master/test_code/cartoonize.py#L25-L39
def cartoonize(load_folder, save_folder, model_path):
    import sys
    import shutil
    shutil.rmtree('./export', ignore_errors=True)

    input_photo = tf.placeholder(tf.float32, [1, 720, 720, 3], name='input')
    network_out = network.unet_generator(input_photo)
    final_out = guided_filter.guided_filter(input_photo, network_out, r=1, eps=5e-3)
    print("input_photo.name =", input_photo.name)
    print("input_photo.shape =", input_photo.shape)
    print("final_out.name =", final_out.name)
    print("final_out.shape =", final_out.shape)
    all_vars = tf.trainable_variables()
    gene_vars = [var for var in all_vars if 'generator' in var.name]
    saver = tf.train.Saver(var_list=gene_vars)
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    sess.run(tf.global_variables_initializer())
    saver.restore(sess, tf.train.latest_checkpoint(model_path))
    saver.save(sess, './export/model.ckpt')
    sys.exit(0)

Let's run it.

$ python3 cartoonize.py

Three types of checkpoints, .index .data-00000-of-00001 .meta, have been successfully generated.
Screenshot 2020-05-05 09:35:51.png
- Go to Table of contents -

4-2-1-2. Generate Freeze_Graph from checkpoint (.meta)

Now let's create a Freeze_Graph from a checkpoint, using the .meta file we just created. Here is a sample of the modified cartoonize method. It is also fine to cut out the newly added part and run it as a separate .py file.

cartoonize.py
def cartoonize(load_folder, save_folder, model_path):
    import sys
    #import shutil
    #shutil.rmtree('./export', ignore_errors=True)

    #input_photo = tf.placeholder(tf.float32, [1, 720, 720, 3], name='input')
    #network_out = network.unet_generator(input_photo)
    #final_out = guided_filter.guided_filter(input_photo, network_out, r=1, eps=5e-3)
    #print("input_photo.name =", input_photo.name)
    #print("input_photo.shape =", input_photo.shape)
    #print("final_out.name =", final_out.name)
    #print("final_out.shape =", final_out.shape)
    #all_vars = tf.trainable_variables()
    #gene_vars = [var for var in all_vars if 'generator' in var.name]
    #saver = tf.train.Saver(var_list=gene_vars)
    #config = tf.ConfigProto()
    #config.gpu_options.allow_growth = True
    #sess = tf.Session(config=config)
    #sess.run(tf.global_variables_initializer())
    #saver.restore(sess, tf.train.latest_checkpoint(model_path))
    #saver.save(sess, './export/model.ckpt')
    #sys.exit(0)
    graph = tf.get_default_graph()
    sess = tf.Session()
    saver = tf.train.import_meta_graph('./export/model.ckpt.meta')
    saver.restore(sess, './export/model.ckpt')
    tf.train.write_graph(sess.graph_def, './export', 'white_box_cartoonization_freeze_graph.pbtxt', as_text=True)
    tf.train.write_graph(sess.graph_def, './export', 'white_box_cartoonization_freeze_graph.pb', as_text=False)
    sys.exit(0)

Let's run it.

$ python3 cartoonize.py

The white_box_cartoonization_freeze_graph.pb has been successfully generated.
Screenshot 2020-05-05 10:03:03.png
Checking the structure with Netron doesn't seem to be a problem.
Screenshot 2020-05-05 10:06:33.png

- Go to Table of contents -

4-2-1-3. Generate a saved_model from Freeze_Graph

Generate a saved_model from the Freeze_Graph. The script is written to work with both Tensorflow v1.x and Tensorflow v2.x. For input_name= and outputs=, specify the INPUT and OUTPUT names identified in 4-1-1. In the case of a Tensorflow checkpoint. It can probably be used with any model, as long as you know the names of the INPUT and OUTPUT and have a Freeze_Graph at hand.

freeze_the_saved_model.py
import tensorflow as tf
import os
import shutil
from tensorflow.python import ops

def get_graph_def_from_file(graph_filepath):
  tf.compat.v1.reset_default_graph()
  with ops.Graph().as_default():
    with tf.compat.v1.gfile.GFile(graph_filepath, 'rb') as f:
      graph_def = tf.compat.v1.GraphDef()
      graph_def.ParseFromString(f.read())
      return graph_def

def convert_graph_def_to_saved_model(export_dir, graph_filepath, input_name, outputs):
  graph_def = get_graph_def_from_file(graph_filepath)
  with tf.compat.v1.Session(graph=tf.Graph()) as session:
    tf.import_graph_def(graph_def, name='')
    tf.compat.v1.saved_model.simple_save(
        session,
        export_dir,# change input_image to node.name if you know the name
        inputs={input_name: session.graph.get_tensor_by_name('{}:0'.format(node.name))
            for node in graph_def.node if node.op=='Placeholder'},
        outputs={t.rstrip(":0"):session.graph.get_tensor_by_name(t) for t in outputs}
    )
    print('Graph converted to SavedModel!')

tf.compat.v1.enable_eager_execution()

input_name="input"
outputs = ['add_1:0']
shutil.rmtree('./saved_model', ignore_errors=True)
convert_graph_def_to_saved_model('./saved_model', './white_box_cartoonization_freeze_graph.pb', input_name, outputs)

"""
$ saved_model_cli show --dir saved_model --all

MetaGraphDef with tag-set: 'serve' contains the following SignatureDefs:

signature_def['serving_default']:
  The given SavedModel SignatureDef contains the following input(s):
    inputs['input'] tensor_info:
        dtype: DT_FLOAT
        shape: (1, 720, 720, 3)
        name: input:0
  The given SavedModel SignatureDef contains the following output(s):
    outputs['add_1'] tensor_info:
        dtype: DT_FLOAT
        shape: (1, 720, 720, 3)
        name: add_1:0
  Method name is: tensorflow/serving/predict
"""

Let's run it.

Generate_saved_model_from_Freeze_Graph
$ python3 freeze_the_saved_model.py

It was successfully generated.
Screenshot 2020-05-05 10:52:41.png
Screenshot 2020-05-05 10:53:28.png
Let's check the structure of saved_model. It seems to have worked out well.

Checking_the_structure_of_saved_model
$ saved_model_cli show --dir saved_model --all

MetaGraphDef with tag-set: 'serve' contains the following SignatureDefs:

signature_def['serving_default']:
  The given SavedModel SignatureDef contains the following input(s):
    inputs['input'] tensor_info:
        dtype: DT_FLOAT
        shape: (1, 720, 720, 3)
        name: input:0
  The given SavedModel SignatureDef contains the following output(s):
    outputs['add_1'] tensor_info:
        dtype: DT_FLOAT
        shape: (1, 720, 720, 3)
        name: add_1:0
  Method name is: tensorflow/serving/predict

- Go to Table of contents -

4-2-1-4. Weight Quantization from saved_model (weight-only quantization)

Finally, we come to the main topic: quantization. Create a program that performs Weight Quantization from the saved_model and generates a .tflite that can be run with Tensorflow Lite.

weight_quantization.py
import tensorflow as tf

tf.compat.v1.enable_eager_execution()

# Weight Quantization - Input/Output=float32
converter = tf.lite.TFLiteConverter.from_saved_model('./saved_model')
converter.experimental_new_converter = True   #<--- Not necessary if you are using Tensorflow v2.2.x or later.
converter.optimizations = [tf.lite.Optimize.OPTIMIZE_FOR_SIZE]
tflite_quant_model = converter.convert()
with open('./white_box_cartoonization_weight_quant.tflite', 'wb') as w:
    w.write(tflite_quant_model)
print("Weight Quantization complete! - white_box_cartoonization_weight_quant.tflite")

Execute it.

Executing_Weight_Quantization
$ python3 weight_quantization.py

It's been generated safely. The file size has been reduced to a quarter of the original Freeze_Graph. Note, however, that even though the file size is reduced to a quarter, this does not mean inference becomes four times faster. Weight-only quantization is essentially just compression of the file size. It depends on the environment in which inference is performed, but, as an example, if you want to improve performance when running inference on the RaspberryPi4 CPU, you need to perform 4-2-1-5. Integer Quantization from saved_model (8-bit integer quantization), described in the next section.
Screenshot 2020-05-05 11:01:40.png
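
Before moving on, the generated .tflite can be sanity-checked with the standard Tensorflow Lite Python interpreter. The following is a minimal sketch that only feeds a dummy input to the file generated above.

check_weight_quant_tflite.py
import numpy as np
import tensorflow as tf

# Load the weight-quantized model and run a single dummy inference.
interpreter = tf.lite.Interpreter(model_path='white_box_cartoonization_weight_quant.tflite')
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

dummy = np.zeros(input_details[0]['shape'], dtype=np.float32)  # [1, 720, 720, 3]
interpreter.set_tensor(input_details[0]['index'], dummy)
interpreter.invoke()
print(interpreter.get_tensor(output_details[0]['index']).shape)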

- Go to Table of contents -

4-2-1-5. Integer Quantization from saved_model (8-bit integer quantization)

Create a program that performs Integer Quantization from the saved_model and generates a .tflite that can be run with Tensorflow Lite. For Integer Quantization, it is necessary to provide image data for calibration in the process of converting Float32 values to UInt8. If possible, it is better to use the images used during training, but this time I used a dataset that is easy to prepare. Google hosts training datasets in the cloud as Tensorflow Datasets, so if you write tfds.load(...), the dataset is downloaded automatically. You only need to do the download once, and I recommend changing download=True to download=False for the second and subsequent runs. The sample logic below is set to automatically download the Pascal-VOC 2007 image dataset, but if you want to use other image datasets, you can find them in the left pane of the Tensorflow Datasets Catalog page. Most image datasets are available, but some may need to be downloaded manually due to copyright or sensitive content (e.g., datasets of face images). Among the various quantization operations, Integer Quantization shows the best performance when running inference on the RaspberryPi4 CPU alone. If you are interested, please see 3. TFLite Model Benchmark, which contains benchmark results of the Integer Quantization model on Ubuntu 19.10 aarch64 on the RaspberryPi4. The benchmark results in Post-training quantization with TF2.0 Keras - nb.o's Diary - Nextremer_nb_o are also very helpful. Very importantly, since the number of operations that support quantization is increasing all the time, I recommend using the latest Tensorflow (Tensorflow v2.2.x / tf-nightly) for Integer Quantization and for the Full Integer Quantization described below.

The flow of processing in representative_dataset_gen() is as follows.
 1. Convert the data acquired by Tensorflow Datasets to Numpy
 2. Resize the image size to INPUT size 720x720
 3. Normalize image data to the range from -1 to 1
 4. To match the shape of the INPUT [1, 720, 720, 3], add one dimension for batch size to the beginning of the image data in [720, 720, 3]
 5. Return one image

integer_quantization.py
import tensorflow as tf
import tensorflow_datasets as tfds
import numpy as np

def representative_dataset_gen():
  for data in raw_test_data.take(100):
    image = data['image'].numpy()
    image = tf.image.resize(image, (720, 720))
    image = image / 127.5 - 1
    image = image[np.newaxis,:,:,:]
    yield [image]

tf.compat.v1.enable_eager_execution()

raw_test_data, info = tfds.load(name="voc/2007", 
                                 with_info=True, 
                                 split="validation", 
                                 data_dir="~/TFDS", 
                                 download=True)

# Integer Quantization - Input/Output=float32
converter = tf.lite.TFLiteConverter.from_saved_model('./saved_model')
converter.experimental_new_converter = True   #<--- Not necessary if you are using Tensorflow v2.2.x or later.
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS,tf.lite.OpsSet.SELECT_TF_OPS]
converter.representative_dataset = representative_dataset_gen
tflite_quant_model = converter.convert()
with open('./white_box_cartoonization_integer_quant.tflite', 'wb') as w:
    w.write(tflite_quant_model)
print("Integer Quantization complete! - white_box_cartoonization_integer_quant.tflite")

Let's run it.

Executing_Integer_Quantization
$ python3 integer_quantization.py

It's been generated safely.
Screenshot 2020-05-05 11:36:53.png

- Go to Table of contents -

4-2-1-6. saved_model to Full Integer Quantization (all 8-bit integer quantization)

The notes and logic structure are almost the same as Integer Quantization and are omitted. Now let's write a program to execute Full Integer Quantization.

※ Unfortunately, as of 05/05/2020, the operation Div in the model of White-box-Cartoonization does not support Full Integer Quantization, so the script below will abort. However, other models that do not include Div work fine, so I'll describe the logic on the premise that it can be diverted to other models.

The .tflite file generated by Full Integer Quantization is the file you will need to generate the model for EdgeTPU. The performance of CPU inference on RaspberryPi4 is exactly the same as the Integer Quantization model.

full_integer_quantization.py
import tensorflow as tf
import tensorflow_datasets as tfds
import numpy as np

def representative_dataset_gen():
  for data in raw_test_data.take(100):
    image = data['image'].numpy()
    image = tf.image.resize(image, (720, 720))
    image = image / 127.5 - 1
    image = image[np.newaxis,:,:,:]
    yield [image]

tf.compat.v1.enable_eager_execution()

raw_test_data, info = tfds.load(name="voc/2007", 
                                 with_info=True, 
                                 split="validation", 
                                 data_dir="~/TFDS", 
                                 download=False)

# Full Integer Quantization - Input/Output=uint8
converter = tf.lite.TFLiteConverter.from_saved_model('./saved_model')
converter.experimental_new_converter = True   #<--- Not necessary if you are using Tensorflow v2.2.x or later.
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.inference_input_type = tf.uint8
converter.inference_output_type = tf.uint8
converter.representative_dataset = representative_dataset_gen
tflite_quant_model = converter.convert()
with open('./white_box_cartoonization_full_integer_quant.tflite', 'wb') as w:
    w.write(tflite_quant_model)
print("Full Integer Quantization complete! - white_box_cartoonization_full_integer_quant.tflite")

- Go to Table of contents -

4-2-1-7. Float16 Quantization from saved_model (Float16 quantization)

Generate a Float16 quantized model optimized for inference on a GPU. The program is described below.

float16_quantization.py
import tensorflow as tf

tf.compat.v1.enable_eager_execution()

# Float16 Quantization - Input/Output=float32
converter = tf.lite.TFLiteConverter.from_saved_model('./saved_model')
converter.experimental_new_converter = True   #<--- Not necessary if you are using Tensorflow v2.2.x or later.
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.target_spec.supported_types = [tf.float16]
tflite_quant_model = converter.convert()
with open('./white_box_cartoonization_float16_quant.tflite', 'wb') as w:
    w.write(tflite_quant_model)
print("Float16 Quantization complete! - white_box_cartoonization_float16_quant.tflite")

Execute it.

Executing_Float16_Quantization
$ python3 float16_quantization.py

It seems to have been generated safely.
Screenshot 2020-05-05 12:03:30.png

- Go to Table of contents -

4-2-1-8. Full Integer Quantization to EdgeTPU convert

These are the steps to take once Full Integer Quantization has succeeded. Compile the model for use with the Google Coral EdgeTPU. Compilation will abort if the model contains unsupported or poorly implemented operations. In my opinion, the compiler is still unstable.

The Edge TPU Compiler can be installed by following the Edge TPU Compiler documentation.

Executing_EdgeTPU_compilation
$ edgetpu_compiler -s white_box_cartoonization_full_integer_quant.tflite

Incidentally, the latest compiler, 2.1.302470888, is said to support efficient inference across multiple TPUs. Currently only a C++ API is provided, which is a burden for me as a Python user. I've owned three of them... :confounded:
04.jpg
@iwatake2222 was one of the first to implement pipelining, in this repository.
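
Once compilation succeeds, the generated _edgetpu.tflite can be loaded from Python with the EdgeTPU delegate. The following is a minimal sketch, assuming the Coral runtime (libedgetpu) is installed and an EdgeTPU device is attached; it only feeds a dummy input, and the file name follows the compile command above.

run_edgetpu_model.py
import numpy as np
from tflite_runtime.interpreter import Interpreter, load_delegate

# Load the EdgeTPU-compiled model with the EdgeTPU delegate.
interpreter = Interpreter(
    model_path='white_box_cartoonization_full_integer_quant_edgetpu.tflite',
    experimental_delegates=[load_delegate('libedgetpu.so.1')])
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

# Dummy uint8 input matching the quantized model's INPUT shape.
dummy = np.zeros(input_details[0]['shape'], dtype=np.uint8)
interpreter.set_tensor(input_details[0]['index'], dummy)
interpreter.invoke()
print(interpreter.get_tensor(output_details[0]['index']).shape)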

- Go to Table of contents -

4-2-2. Quantization from a Tensorflow checkpoint (.meta)

Same as 4-2-1-2. Generate Freeze_Graph from checkpoint (.meta) to 4-2-1-8. Full Integer Quantization to EdgeTPU convert. Change the image dataset used for calibration according to the characteristics of the model.

- Go to Table of contents -

4-2-3. Quantization from Tensorflow Freeze_Graph (.pb)

Same as 4-2-1-3. Generate a saved_model from Freeze_Graph to 4-2-1-8. Full Integer Quantization to EdgeTPU convert. Change the image dataset used for calibration according to the characteristics of the model.

- Go to Table of contents -

4-2-4. Quantization from Tensorflow saved_model (.pb)

Same as 4-2-1-4. Weight Quantization from saved_model (weight-only quantization) to 4-2-1-8. Full Integer Quantization to EdgeTPU convert. Change the image dataset used for calibration according to the characteristics of the model.

- Go to Table of contents -

4-2-5. Quantization from Tensorflow/Keras (.h5/.json)

If the model is provided in Keras' .h5 and .json formats, this is the pattern. This time we'll use Faster-Grad-CAM as an example. Older Tensorflow versions, up to v1.x and v2.1, seem to have an OOM (Out of Memory) bug in Keras quantization, so it is recommended to install Tensorflow v2.2.0 or tf-nightly for this work.


Thankfully, all the material needed for quantization is committed, except for the image dataset needed for calibration.
FireShot Capture 019 - Faster-Grad-CAM_model at master · shinmura0_Faster-Grad-CAM - github.com.png
The missing calibration dataset will be generated in a later step.

- Go to Table of contents -

4-2-5-1. Weight Quantization from .h5/.json (weight quantization)

You can do Weight Quantization immediately because the material is available. First, we clone the repository.

$ git clone https://github.com/shinmura0/Faster-Grad-CAM.git
$ cd Faster-Grad-CAM/model

The following is the program to execute Weight Quantization. The differences from the earlier quantization pattern starting from a Tensorflow checkpoint are how the model and weights are loaded and which conversion method is used. TFLiteConverter can accept a model object with the weights already loaded, as is.

 1. model = tf.keras.models.model_from_json(open('model.json').read())
 2. model.load_weights('weights.h5')
 3. tf.lite.TFLiteConverter.from_keras_model(model)

weight_quantization.py
import tensorflow as tf

tf.compat.v1.enable_eager_execution()

# Weight Quantization - Input/Output=float32
# INPUT  = input_1 (float32, 1 x 96 x 96 x 3)
# OUTPUT = block_16_expand_relu, global_average_pooling2d_1
model = tf.keras.models.model_from_json(open('model.json').read())
model.load_weights('weights.h5')
converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.experimental_new_converter = True   #<--- Not necessary if you are using Tensorflow v2.2.x or later.
converter.optimizations = [tf.lite.Optimize.OPTIMIZE_FOR_SIZE]
tflite_quant_model = converter.convert()
with open('./weights_weight_quant.tflite', 'wb') as w:
    w.write(tflite_quant_model)
print("Weight Quantization complete! - weights_weight_quant.tflite")

Let's run it.

Executing_Weight_Quantization
$ python3 weight_quantization.py

It seems to have been generated safely.
Screenshot 2020-05-05 14:11:48.png
- Go to Table of contents -

4-2-5-2. Generating the calibration data set

Clone the dataset images for calibration from @karaage0703's repository. Faster-Grad-CAM's trained model was trained with the gu (rock) images in this dataset. It is tedious to handle the scattered image files one by one, so in the next step I'll pack them into a single file.

$ git clone https://github.com/karaage0703/janken_dataset.git
$ cd janken_dataset/gu

Create a program that packs the images into a Numpy binary format file.

image_to_npy.py
from PIL import Image
import os, glob
import numpy as np

dataset = []

files = glob.glob("*.JPG")
for file in files:
    image = Image.open(file)
    image = image.convert("RGB")
    data = np.asarray(image)
    dataset.append(data)

dataset = np.array(dataset)
np.save("janken_dataset", dataset)

Now, let's give it a go.

Generating_.npy_format_datasets
$ python3 image_to_npy.py

It seems to have been generated without incident.
Screenshot 2020-05-05 14:27:08.png
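
As a quick sanity check, the packed file can be reloaded and inspected. This is a small sketch that assumes all images share the same resolution (otherwise the saved array will have dtype=object).

check_npy_dataset.py
import numpy as np

# Reload the packed calibration dataset and confirm its shape and dtype.
dataset = np.load('janken_dataset.npy')
print(dataset.shape, dataset.dtype)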
- Go to Table of contents -

4-2-5-3. Integer Quantization from .h5/.json (8-bit integer quantization)

The method of Integer Quantization is almost the same as before; only the handling of the calibration dataset is slightly different.

integer_quantization.py
import tensorflow as tf
import numpy as np

def representative_dataset_gen():
    raw_test_data = np.load('janken_dataset.npy')
    for image in raw_test_data:
        image = tf.image.resize(image, (96, 96))
        image = image / 255
        calibration_data = image[np.newaxis, :, :, :]
        yield [calibration_data]

tf.compat.v1.enable_eager_execution()

# Integer Quantization - Input/Output=float32
# INPUT  = input_1 (float32, 1 x 96 x 96 x 3)
# OUTPUT = block_16_expand_relu, global_average_pooling2d_1
model = tf.keras.models.model_from_json(open('model.json').read())
model.load_weights('weights.h5')
converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.experimental_new_converter = True   #<--- Not necessary if you are using Tensorflow v2.2.x or later.
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = representative_dataset_gen
tflite_quant_model = converter.convert()
with open('./weights_integer_quant.tflite', 'wb') as w:
    w.write(tflite_quant_model)
print("Integer Quantization complete! - weights_integer_quant.tflite")

Let's run it.

Executing_Integer_Quantization
$ python3 integer_quantization.py

It seems to have been generated without incident.
Screenshot 2020-05-05 14:35:07.png
- Go to Table of contents -

4-2-5-4. Full Integer Quantization from .h5/.json (all 8-bit integer quantization)

Write a program to do Full Integer Quantization.

full_integer_quantization.py
import tensorflow as tf
import numpy as np

def representative_dataset_gen():
    raw_test_data = np.load('janken_dataset.npy')
    for image in raw_test_data:
        image = tf.image.resize(image, (96, 96))
        image = image / 255
        calibration_data = image[np.newaxis, :, :, :]
        yield [calibration_data]

tf.compat.v1.enable_eager_execution()

# Full Integer Quantization - Input/Output=uint8
# INPUT  = input_1 (float32, 1 x 96 x 96 x 3)
# OUTPUT = block_16_expand_relu, global_average_pooling2d_1
model = tf.keras.models.model_from_json(open('model.json').read())
model.load_weights('weights.h5')
converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.experimental_new_converter = True   #<--- Not necessary if you are using Tensorflow v2.2.x or later.
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.inference_input_type = tf.uint8
converter.inference_output_type = tf.uint8
converter.representative_dataset = representative_dataset_gen
tflite_quant_model = converter.convert()
with open('./weights_full_integer_quant.tflite', 'wb') as w:
    w.write(tflite_quant_model)
print("Full Integer Quantization complete! - weights_full_integer_quant.tflite")

Let's run it.

Executing_Full_Integer_Quantization
$ python3 full_integer_quantization.py

It seems to have been generated without incident.
Screenshot 2020-05-05 14:49:41.png
- Go to Table of contents -

4-2-5-5. Float16 Quantization from .h5/.json (Float16 quantization)

Write a program to do Float16 Quantization.

float16_quantization.py
import tensorflow as tf

tf.compat.v1.enable_eager_execution()

# Float16 Quantization - Input/Output=float32
# INPUT  = input_1 (float32, 1 x 96 x 96 x 3)
# OUTPUT = block_16_expand_relu, global_average_pooling2d_1
model = tf.keras.models.model_from_json(open('model.json').read())
model.load_weights('weights.h5')
converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.experimental_new_converter = True   #<--- Not necessary if you are using Tensorflow v2.2.x or later.
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.target_spec.supported_types = [tf.float16]
tflite_quant_model = converter.convert()
with open('./weights_float16_quant.tflite', 'wb') as w:
    w.write(tflite_quant_model)
print("Float16 Quantization complete! - weights_float16_quant.tflite")

Let's run it.

Executing_Float16_Quantization
$ python3 float16_quantization.py

It seems to have been generated safely.
Screenshot 2020-05-05 14:54:34.png
- Go to Table of contents -

4-2-5-6. Full Integer Quantization to EdgeTPU convert

The model produced by Full Integer Quantization is used to generate an EdgeTPU-compatible model.

Executing_EdgeTPU_compilation
$ edgetpu_compiler -s weights_full_integer_quant.tflite
Logging_EdgeTPU_compilation
Edge TPU Compiler version 2.1.302470888

Model compiled successfully in 359 ms.

Input model: weights_full_integer_quant.tflite
Input size: 1.00MiB
Output model: weights_full_integer_quant_edgetpu.tflite
Output size: 1.06MiB
On-chip memory used for caching model parameters: 1014.00KiB
On-chip memory remaining for caching model parameters: 6.78MiB
Off-chip memory used for streaming uncached model parameters: 0.00B
Number of Edge TPU subgraphs: 1
Total number of operations: 71
Operation log: weights_full_integer_quant_edgetpu.log

Model successfully compiled but not all operations are supported by the Edge TPU. A percentage of the model will instead run on the CPU, which is slower. If possible, consider updating your model to use only operations supported by the Edge TPU. For details, visit g.co/coral/model-reqs.
Number of operations that will run on Edge TPU: 68
Number of operations that will run on CPU: 3

Operator                       Count      Status

DEPTHWISE_CONV_2D              17         Mapped to Edge TPU
DEQUANTIZE                     2          Operation is working on an unsupported data type
MEAN                           1          Mapped to Edge TPU
ADD                            10         Mapped to Edge TPU
QUANTIZE                       1          Operation is otherwise supported, but not mapped due to some unspecified limitation
PAD                            5          Mapped to Edge TPU
CONV_2D                        35         Mapped to Edge TPU

It seems to have been generated safely.
Screenshot 2020-05-05 14:59:25.png
Screenshot 2020-05-05 15:00:57.png
- Go to Table of contents -

4-2-6. Quantization from a model for Tensorflow.js

Now it's time to tackle a more challenging task: quantizing Posenet V2 ResNet50 for Tensorflow.js, which was recently released by Google. To begin with, the checkpoint and the training code are not published, so it is a tricky task. This is the repository to work from. You will need to switch between Tensorflow v2.1.0 and Tensorflow v1.15.2. If you don't want to go through the trouble, please use PINTO_model_zoo, where quantized models of all patterns are already committed.



- Go to Table of contents -

4-2-6-1. Advance preparation

Installation_of_Tensorflow_v2.1.0
$ sudo pip3 uninstall tensorboard-plugin-wit tb-nightly \
                      tf-estimator-nightly tensorflow-gpu \
                      tensorflow tf-nightly tensorflow_estimator
$ sudo pip3 install tensorflow==2.1.0
Installation_of_tfjs-to-tf
$ git clone https://github.com/patlevin/tfjs-to-tf.git 
$ cd tfjs-to-tf 
$ sudo pip3 install . --no-deps
$ cd ..
Clone_in_posenet-python
$ git clone https://github.com/atomicbits/posenet-python.git
$ cd posenet-python
$ mkdir -p output

- Go to Table of contents -

4-2-6-2. Generating a saved_model from Tensorflow.js

Create a saved_model by the following command. By specifying stride, it is possible to generate a model that balances accuracy and speed. You must save one or more sample images in the folder path specified in the argument of --image_dir in advance. Let's run it.

Create_and_execute_saved_model
$ python3 image_demo.py \
    --model resnet50 \
    --stride 16 \
    --image_dir ./images \
    --output_dir ./output

Generating the saved_model with an output stride of 16 creates a saved_model under tf_models/posenet/resnet50_float/stride16/.
Screenshot 2020-05-05 15:59:19.png
Check the structure with Netron. The INPUT is sub_2, but its shape is Float32 [1, ?, ?, 3]. As it is, quantization will fail, so I'm going to make an adjustment here.

※Actually, you can fix the INPUT shape of sub_2 just by editing tfjs_models/posenet/resnet50_float/stride16/model-stride16.json by hand and re-converting. However, I'm deliberately taking the long way around in order to explain the special processing required to import a saved_model generated by Tensorflow v2.x into Tensorflow v1.x.
Screenshot 2020-05-05 16:23:57.png
Screenshot 2020-05-05 16:01:54.png
- Go to Table of contents -

4-2-6-3. Import saved_model generated by Tensorflow v2.x into Tensorflow v1.x and process the input shape

Import saved_model into Tensorflow v1.15.2 and process the input shape. The only reason I use Tensorflow v1.15.2 is because I want to use the TransformGraph tool.

Installation_of_Tensorflow_v1.15.2
$ sudo pip3 uninstall tensorboard-plugin-wit tb-nightly \
                      tf-estimator-nightly tensorflow-gpu \
                      tensorflow tf-nightly tensorflow_estimator
$ sudo pip3 install tensorflow==1.15.2

The following is a program to change the input shape of the saved_model. It changes the input shape from [1, ?, ?, 3] to [1, 513, 513, 3] while replacing the name sub_2 with the name image. If you want a smaller input shape, change 513 to, for example, 257. The part that differs slightly from the programs introduced so far is that the logic for reading the .pb file has been changed in order to import a saved_model generated by Tensorflow v2.x into Tensorflow v1.x.

replacement_of_input_placeholder_float32_resnet.py
### tensorflow-gpu==1.15.2

import sys
import tensorflow as tf
from tensorflow.tools.graph_transforms import TransformGraph
from tensorflow.python.platform import gfile
from tensorflow.core.protobuf import saved_model_pb2
from tensorflow.python.util import compat

with tf.compat.v1.Session() as sess:

    # shape=[1, ?, ?, 3] -> shape=[1, 513, 513, 3]
    # name='image' specifies the placeholder name of the converted model

    inputs = tf.compat.v1.placeholder(tf.float32, shape=[1, 513, 513, 3], name='image')
    #inputs = tf.compat.v1.placeholder(tf.float32, shape=[1, 385, 385, 3], name='image')
    #inputs = tf.compat.v1.placeholder(tf.float32, shape=[1, 321, 321, 3], name='image')
    #inputs = tf.compat.v1.placeholder(tf.float32, shape=[1, 257, 257, 3], name='image')
    #inputs = tf.compat.v1.placeholder(tf.float32, shape=[1, 225, 225, 3], name='image')

    with gfile.FastGFile('_tf_models/posenet/resnet50_float/stride32/saved_model.pb', 'rb') as f:
        data = compat.as_bytes(f.read())
        sm = saved_model_pb2.SavedModel()
        sm.ParseFromString(data)
        if 1 != len(sm.meta_graphs):
            print('More than one graph found. Not sure which to write')
            sys.exit(1)

    # 'sub_2:0' specifies the placeholder name of the model before conversion
    tf.graph_util.import_graph_def(sm.meta_graphs[0].graph_def, input_map={'sub_2:0': inputs}, name='')
    print([n for n in tf.compat.v1.get_default_graph().as_graph_def().node if n.name == 'image'])

    # Delete Placeholder "image" before conversion
    # see: https://github.com/tensorflow/tensorflow/tree/master/tensorflow/tools/graph_transforms
    # TransformGraph(
    #     graph_def(),
    #     input_name,
    #     output_names,
    #     conversion options
    # )
    optimized_graph_def = TransformGraph(
                              tf.compat.v1.get_default_graph().as_graph_def(),
                              'image',
                              ['float_heatmaps','float_short_offsets','resnet_v1_50/displacement_fwd_2/BiasAdd','resnet_v1_50/displacement_bwd_2/BiasAdd'],
                              ['strip_unused_nodes(type=float, shape="1,513,513,3")'])

    tf.io.write_graph(optimized_graph_def, './', 'posenet_resnet50_32_513.pb', as_text=False)
Some_changes_from_TFv2.x_to_TFv1.x_to_import_into_saved_model
    with gfile.FastGFile('_tf_models/posenet/resnet50_float/stride32/saved_model.pb', 'rb') as f:
        data = compat.as_bytes(f.read())
        sm = saved_model_pb2.SavedModel()
        sm.ParseFromString(data)
        if 1 != len(sm.meta_graphs):
            print('More than one graph found. Not sure which to write')
            sys.exit(1)

Let's run it.

Execute_input_shape_change_of_saved_model_of_TFv2.x
$ python3 replacement_of_input_placeholder_float32_resnet.py

It seems to have been generated safely.
Screenshot 2020-05-05 17:07:00.png
Screenshot 2020-05-05 17:06:12.png
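
Note that the quantization scripts in the next sections read from saved_model directories such as saved_model_posenet_resnet50_32_513, whereas the program above writes a Freeze_Graph (.pb). If you are following along rather than using the converted models in PINTO_model_zoo, the modified Freeze_Graph has to be wrapped into a saved_model again, in the same way as the Freeze_Graph-to-saved_model conversion earlier in this article. The following is only a minimal sketch, still on Tensorflow v1.15.2, and assumes the input/output names used in the TransformGraph call above.

freeze_graph_to_saved_model_resnet.py (a minimal sketch)
### tensorflow-gpu==1.15.2

import tensorflow as tf
from tensorflow.python.platform import gfile

with tf.compat.v1.Session() as sess:
    with gfile.FastGFile('posenet_resnet50_32_513.pb', 'rb') as f:
        graph_def = tf.compat.v1.GraphDef()
        graph_def.ParseFromString(f.read())
    tf.graph_util.import_graph_def(graph_def, name='')
    graph = tf.compat.v1.get_default_graph()
    # A Freeze_Graph contains no variables, so only the graph and signature are exported
    tf.compat.v1.saved_model.simple_save(
        sess,
        'saved_model_posenet_resnet50_32_513',
        inputs={'image': graph.get_tensor_by_name('image:0')},
        outputs={
            'float_heatmaps': graph.get_tensor_by_name('float_heatmaps:0'),
            'float_short_offsets': graph.get_tensor_by_name('float_short_offsets:0'),
            'displacement_fwd': graph.get_tensor_by_name('resnet_v1_50/displacement_fwd_2/BiasAdd:0'),
            'displacement_bwd': graph.get_tensor_by_name('resnet_v1_50/displacement_bwd_2/BiasAdd:0')})
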
- Go to Table of contents -

4-2-6-4. Installation of Tensorflow v2.2.0

Change Tensorflow v1.15.2 to Tensorflow v2.2.0 before quantization.

Installation_of_Tensorflow_v2.2.0
$ sudo pip3 uninstall tensorboard-plugin-wit tb-nightly \
                      tf-estimator-nightly tensorflow-gpu \
                      tensorflow tf-nightly tensorflow_estimator
$ sudo pip3 install tensorflow==2.2.0

4-2-6-5. Weight Quantization from saved_model (Weight-only quantization)

The quantization procedure from here on is the same as the previous one. The program for quantization is described below.

weight_quantization_resnet.py
### tensorflow==2.2.0

import tensorflow as tf
import numpy as np

# Weight Quantization - Input/Output=float32
converter = tf.lite.TFLiteConverter.from_saved_model('saved_model_posenet_resnet50_16_225')
converter.optimizations = [tf.lite.Optimize.OPTIMIZE_FOR_SIZE]
tflite_quant_model = converter.convert()
with open('posenet_resnet50_16_225_weight_quant.tflite', 'wb') as w:
    w.write(tflite_quant_model)
print("Weight Quantization complete! - posenet_resnet50_16_225_weight_quant.tflite")

# Weight Quantization - Input/Output=float32
converter = tf.lite.TFLiteConverter.from_saved_model('saved_model_posenet_resnet50_16_257')
converter.optimizations = [tf.lite.Optimize.OPTIMIZE_FOR_SIZE]
tflite_quant_model = converter.convert()
with open('posenet_resnet50_16_257_weight_quant.tflite', 'wb') as w:
    w.write(tflite_quant_model)
print("Weight Quantization complete! - posenet_resnet50_16_257_weight_quant.tflite")

# Weight Quantization - Input/Output=float32
converter = tf.lite.TFLiteConverter.from_saved_model('saved_model_posenet_resnet50_16_321')
converter.optimizations = [tf.lite.Optimize.OPTIMIZE_FOR_SIZE]
tflite_quant_model = converter.convert()
with open('posenet_resnet50_16_321_weight_quant.tflite', 'wb') as w:
    w.write(tflite_quant_model)
print("Weight Quantization complete! - posenet_resnet50_16_321_weight_quant.tflite")

# Weight Quantization - Input/Output=float32
converter = tf.lite.TFLiteConverter.from_saved_model('saved_model_posenet_resnet50_16_385')
converter.optimizations = [tf.lite.Optimize.OPTIMIZE_FOR_SIZE]
tflite_quant_model = converter.convert()
with open('posenet_resnet50_16_385_weight_quant.tflite', 'wb') as w:
    w.write(tflite_quant_model)
print("Weight Quantization complete! - posenet_resnet50_16_385_weight_quant.tflite")

# Weight Quantization - Input/Output=float32
converter = tf.lite.TFLiteConverter.from_saved_model('saved_model_posenet_resnet50_16_513')
converter.optimizations = [tf.lite.Optimize.OPTIMIZE_FOR_SIZE]
tflite_quant_model = converter.convert()
with open('posenet_resnet50_16_513_weight_quant.tflite', 'wb') as w:
    w.write(tflite_quant_model)
print("Weight Quantization complete! - posenet_resnet50_16_513_weight_quant.tflite")


# Weight Quantization - Input/Output=float32
converter = tf.lite.TFLiteConverter.from_saved_model('saved_model_posenet_resnet50_32_225')
converter.optimizations = [tf.lite.Optimize.OPTIMIZE_FOR_SIZE]
tflite_quant_model = converter.convert()
with open('posenet_resnet50_32_225_weight_quant.tflite', 'wb') as w:
    w.write(tflite_quant_model)
print("Weight Quantization complete! - posenet_resnet50_32_225_weight_quant.tflite")

# Weight Quantization - Input/Output=float32
converter = tf.lite.TFLiteConverter.from_saved_model('saved_model_posenet_resnet50_32_257')
converter.optimizations = [tf.lite.Optimize.OPTIMIZE_FOR_SIZE]
tflite_quant_model = converter.convert()
with open('posenet_resnet50_32_257_weight_quant.tflite', 'wb') as w:
    w.write(tflite_quant_model)
print("Weight Quantization complete! - posenet_resnet50_32_257_weight_quant.tflite")

# Weight Quantization - Input/Output=float32
converter = tf.lite.TFLiteConverter.from_saved_model('saved_model_posenet_resnet50_32_321')
converter.optimizations = [tf.lite.Optimize.OPTIMIZE_FOR_SIZE]
tflite_quant_model = converter.convert()
with open('posenet_resnet50_32_321_weight_quant.tflite', 'wb') as w:
    w.write(tflite_quant_model)
print("Weight Quantization complete! - posenet_resnet50_32_321_weight_quant.tflite")

# Weight Quantization - Input/Output=float32
converter = tf.lite.TFLiteConverter.from_saved_model('saved_model_posenet_resnet50_32_385')
converter.optimizations = [tf.lite.Optimize.OPTIMIZE_FOR_SIZE]
tflite_quant_model = converter.convert()
with open('posenet_resnet50_32_385_weight_quant.tflite', 'wb') as w:
    w.write(tflite_quant_model)
print("Weight Quantization complete! - posenet_resnet50_32_385_weight_quant.tflite")

# Weight Quantization - Input/Output=float32
converter = tf.lite.TFLiteConverter.from_saved_model('saved_model_posenet_resnet50_32_513')
converter.optimizations = [tf.lite.Optimize.OPTIMIZE_FOR_SIZE]
tflite_quant_model = converter.convert()
with open('posenet_resnet50_32_513_weight_quant.tflite', 'wb') as w:
    w.write(tflite_quant_model)
print("Weight Quantization complete! - posenet_resnet50_32_513_weight_quant.tflite")

- Go to Table of contents -

4-2-6-6. Integer Quantization from saved_model (8-bit integer quantization)

The method of Integer Quantization is the same as before. The images used for the calibration data set are 100 images that contain only people. Note that image_size in representative_dataset_gen must match the input resolution of the model being converted (225 in the blocks enabled below).

integer_quantization_resnet.py
import tensorflow as tf
import tensorflow_datasets as tfds
import numpy as np
from PIL import Image
import os
import glob

## Generating a calibration data set
def representative_dataset_gen():
    folder = ["images"]
    image_size = 225
    raw_test_data = []
    for name in folder:
        dir = "./" + name
        files = glob.glob(dir + "/*.jpg")
        for file in files:
            image = Image.open(file)
            image = image.convert("RGB")
            image = image.resize((image_size, image_size))
            image = np.asarray(image).astype(np.float32)
            image = image[np.newaxis,:,:,:]
            raw_test_data.append(image)

    for data in raw_test_data:
        yield [data]

# Integer Quantization - Input/Output=float32
converter = tf.lite.TFLiteConverter.from_saved_model('saved_model_posenet_resnet50_16_225')
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = representative_dataset_gen
tflite_quant_model = converter.convert()
with open('posenet_resnet50_16_225_integer_quant.tflite', 'wb') as w:
    w.write(tflite_quant_model)
print("Integer Quantization complete! - posenet_resnet50_16_225_integer_quant.tflite")

# # Integer Quantization - Input/Output=float32
# converter = tf.lite.TFLiteConverter.from_saved_model('saved_model_posenet_resnet50_16_257')
# converter.optimizations = [tf.lite.Optimize.DEFAULT]
# converter.representative_dataset = representative_dataset_gen
# tflite_quant_model = converter.convert()
# with open('posenet_resnet50_16_257_integer_quant.tflite', 'wb') as w:
#     w.write(tflite_quant_model)
# print("Integer Quantization complete! - posenet_resnet50_16_257_integer_quant.tflite")

# # Integer Quantization - Input/Output=float32
# converter = tf.lite.TFLiteConverter.from_saved_model('saved_model_posenet_resnet50_16_321')
# converter.optimizations = [tf.lite.Optimize.DEFAULT]
# converter.representative_dataset = representative_dataset_gen
# tflite_quant_model = converter.convert()
# with open('posenet_resnet50_16_321_integer_quant.tflite', 'wb') as w:
#     w.write(tflite_quant_model)
# print("Integer Quantization complete! - posenet_resnet50_16_321_integer_quant.tflite")

# # Integer Quantization - Input/Output=float32
# converter = tf.lite.TFLiteConverter.from_saved_model('saved_model_posenet_resnet50_16_385')
# converter.optimizations = [tf.lite.Optimize.DEFAULT]
# converter.representative_dataset = representative_dataset_gen
# tflite_quant_model = converter.convert()
# with open('posenet_resnet50_16_385_integer_quant.tflite', 'wb') as w:
#     w.write(tflite_quant_model)
# print("Integer Quantization complete! - posenet_resnet50_16_385_integer_quant.tflite")

# # Integer Quantization - Input/Output=float32
# converter = tf.lite.TFLiteConverter.from_saved_model('saved_model_posenet_resnet50_16_513')
# converter.optimizations = [tf.lite.Optimize.DEFAULT]
# converter.representative_dataset = representative_dataset_gen
# tflite_quant_model = converter.convert()
# with open('posenet_resnet50_16_513_integer_quant.tflite', 'wb') as w:
#     w.write(tflite_quant_model)
# print("Integer Quantization complete! - posenet_resnet50_16_513_integer_quant.tflite")



# Integer Quantization - Input/Output=float32
converter = tf.lite.TFLiteConverter.from_saved_model('saved_model_posenet_resnet50_32_225')
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = representative_dataset_gen
tflite_quant_model = converter.convert()
with open('posenet_resnet50_32_225_integer_quant.tflite', 'wb') as w:
    w.write(tflite_quant_model)
print("Integer Quantization complete! - posenet_resnet50_32_225_integer_quant.tflite")

# # Integer Quantization - Input/Output=float32
# converter = tf.lite.TFLiteConverter.from_saved_model('saved_model_posenet_resnet50_32_257')
# converter.optimizations = [tf.lite.Optimize.DEFAULT]
# converter.representative_dataset = representative_dataset_gen
# tflite_quant_model = converter.convert()
# with open('posenet_resnet50_32_257_integer_quant.tflite', 'wb') as w:
#     w.write(tflite_quant_model)
# print("Integer Quantization complete! - posenet_resnet50_32_257_integer_quant.tflite")

# # Integer Quantization - Input/Output=float32
# converter = tf.lite.TFLiteConverter.from_saved_model('saved_model_posenet_resnet50_32_321')
# converter.optimizations = [tf.lite.Optimize.DEFAULT]
# converter.representative_dataset = representative_dataset_gen
# tflite_quant_model = converter.convert()
# with open('posenet_resnet50_32_321_integer_quant.tflite', 'wb') as w:
#     w.write(tflite_quant_model)
# print("Integer Quantization complete! - posenet_resnet50_32_321_integer_quant.tflite")

# # Integer Quantization - Input/Output=float32
# converter = tf.lite.TFLiteConverter.from_saved_model('saved_model_posenet_resnet50_32_385')
# converter.optimizations = [tf.lite.Optimize.DEFAULT]
# converter.representative_dataset = representative_dataset_gen
# tflite_quant_model = converter.convert()
# with open('posenet_resnet50_32_385_integer_quant.tflite', 'wb') as w:
#     w.write(tflite_quant_model)
# print("Integer Quantization complete! - posenet_resnet50_32_385_integer_quant.tflite")

# # Integer Quantization - Input/Output=float32
# converter = tf.lite.TFLiteConverter.from_saved_model('saved_model_posenet_resnet50_32_513')
# converter.optimizations = [tf.lite.Optimize.DEFAULT]
# converter.representative_dataset = representative_dataset_gen
# tflite_quant_model = converter.convert()
# with open('posenet_resnet50_32_513_integer_quant.tflite', 'wb') as w:
#     w.write(tflite_quant_model)
# print("Integer Quantization complete! - posenet_resnet50_32_513_integer_quant.tflite")

- Go to Table of contents -

4-2-6-7. Full Integer Quantization from saved_model (All 8-bit integer quantization)

The method of Full Integer Quantization is the same as before. The images used for the calibration data set are 100 images that contain only people. As above, image_size in representative_dataset_gen must match the input resolution of the model being converted.

full_integer_quantization_resnet.py
import tensorflow as tf
import tensorflow_datasets as tfds
import numpy as np
from PIL import Image
import os
import glob

## Generating a calibration data set
def representative_dataset_gen():
    folder = ["images"]
    image_size = 225
    raw_test_data = []
    for name in folder:
        dir = "./" + name
        files = glob.glob(dir + "/*.jpg")
        for file in files:
            image = Image.open(file)
            image = image.convert("RGB")
            image = image.resize((image_size, image_size))
            image = np.asarray(image).astype(np.float32)
            image = image[np.newaxis,:,:,:]
            raw_test_data.append(image)

    for data in raw_test_data:
        yield [data]

# Integer Quantization - Input/Output=uint8
converter = tf.lite.TFLiteConverter.from_saved_model('saved_model_posenet_resnet50_16_225')
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = representative_dataset_gen
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
converter.inference_input_type = tf.uint8
converter.inference_output_type = tf.uint8
tflite_quant_model = converter.convert()
with open('posenet_resnet50_16_225_full_integer_quant.tflite', 'wb') as w:
    w.write(tflite_quant_model)
print("Integer Quantization complete! - posenet_resnet50_16_225_full_integer_quant.tflite")

# # Integer Quantization - Input/Output=uint8
# converter = tf.lite.TFLiteConverter.from_saved_model('saved_model_posenet_resnet50_16_257')
# converter.optimizations = [tf.lite.Optimize.DEFAULT]
# converter.representative_dataset = representative_dataset_gen
# converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
# converter.inference_input_type = tf.uint8
# converter.inference_output_type = tf.uint8
# tflite_quant_model = converter.convert()
# with open('posenet_resnet50_16_257_full_integer_quant.tflite', 'wb') as w:
#     w.write(tflite_quant_model)
# print("Integer Quantization complete! - posenet_resnet50_16_257_full_integer_quant.tflite")

# # Integer Quantization - Input/Output=uint8
# converter = tf.lite.TFLiteConverter.from_saved_model('saved_model_posenet_resnet50_16_321')
# converter.optimizations = [tf.lite.Optimize.DEFAULT]
# converter.representative_dataset = representative_dataset_gen
# converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
# converter.inference_input_type = tf.uint8
# converter.inference_output_type = tf.uint8
# tflite_quant_model = converter.convert()
# with open('posenet_resnet50_16_321_full_integer_quant.tflite', 'wb') as w:
#     w.write(tflite_quant_model)
# print("Integer Quantization complete! - posenet_resnet50_16_321_full_integer_quant.tflite")

# # Integer Quantization - Input/Output=uint8
# converter = tf.lite.TFLiteConverter.from_saved_model('saved_model_posenet_resnet50_16_385')
# converter.optimizations = [tf.lite.Optimize.DEFAULT]
# converter.representative_dataset = representative_dataset_gen
# converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
# converter.inference_input_type = tf.uint8
# converter.inference_output_type = tf.uint8
# tflite_quant_model = converter.convert()
# with open('posenet_resnet50_16_385_full_integer_quant.tflite', 'wb') as w:
#     w.write(tflite_quant_model)
# print("Integer Quantization complete! - posenet_resnet50_16_385_full_integer_quant.tflite")

# # Integer Quantization - Input/Output=uint8
# converter = tf.lite.TFLiteConverter.from_saved_model('saved_model_posenet_resnet50_16_513')
# converter.optimizations = [tf.lite.Optimize.DEFAULT]
# converter.representative_dataset = representative_dataset_gen
# converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
# converter.inference_input_type = tf.uint8
# converter.inference_output_type = tf.uint8
# tflite_quant_model = converter.convert()
# with open('posenet_resnet50_16_513_full_integer_quant.tflite', 'wb') as w:
#     w.write(tflite_quant_model)
# print("Integer Quantization complete! - posenet_resnet50_16_513_full_integer_quant.tflite")



# Integer Quantization - Input/Output=uint8
converter = tf.lite.TFLiteConverter.from_saved_model('saved_model_posenet_resnet50_32_225')
converter.experimental_new_converter = True
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = representative_dataset_gen
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
converter.inference_input_type = tf.uint8
converter.inference_output_type = tf.uint8
tflite_quant_model = converter.convert()
with open('posenet_resnet50_32_225_full_integer_quant.tflite', 'wb') as w:
    w.write(tflite_quant_model)
print("Integer Quantization complete! - posenet_resnet50_32_225_full_integer_quant.tflite")

# # Integer Quantization - Input/Output=uint8
# converter = tf.lite.TFLiteConverter.from_saved_model('saved_model_posenet_resnet50_32_257')
# converter.optimizations = [tf.lite.Optimize.DEFAULT]
# converter.representative_dataset = representative_dataset_gen
# converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
# converter.inference_input_type = tf.uint8
# converter.inference_output_type = tf.uint8
# tflite_quant_model = converter.convert()
# with open('posenet_resnet50_32_257_full_integer_quant.tflite', 'wb') as w:
#     w.write(tflite_quant_model)
# print("Integer Quantization complete! - posenet_resnet50_32_257_full_integer_quant.tflite")

# # Integer Quantization - Input/Output=uint8
# converter = tf.lite.TFLiteConverter.from_saved_model('saved_model_posenet_resnet50_32_321')
# converter.optimizations = [tf.lite.Optimize.DEFAULT]
# converter.representative_dataset = representative_dataset_gen
# converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
# converter.inference_input_type = tf.uint8
# converter.inference_output_type = tf.uint8
# tflite_quant_model = converter.convert()
# with open('posenet_resnet50_32_321_full_integer_quant.tflite', 'wb') as w:
#     w.write(tflite_quant_model)
# print("Integer Quantization complete! - posenet_resnet50_32_321_full_integer_quant.tflite")

# # Integer Quantization - Input/Output=uint8
# converter = tf.lite.TFLiteConverter.from_saved_model('saved_model_posenet_resnet50_32_385')
# converter.optimizations = [tf.lite.Optimize.DEFAULT]
# converter.representative_dataset = representative_dataset_gen
# converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
# converter.inference_input_type = tf.uint8
# converter.inference_output_type = tf.uint8
# tflite_quant_model = converter.convert()
# with open('posenet_resnet50_32_385_full_integer_quant.tflite', 'wb') as w:
#     w.write(tflite_quant_model)
# print("Integer Quantization complete! - posenet_resnet50_32_385_full_integer_quant.tflite")

# # Integer Quantization - Input/Output=uint8
# converter = tf.lite.TFLiteConverter.from_saved_model('saved_model_posenet_resnet50_32_513')
# converter.optimizations = [tf.lite.Optimize.DEFAULT]
# converter.representative_dataset = representative_dataset_gen
# converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
# converter.inference_input_type = tf.uint8
# converter.inference_output_type = tf.uint8
# tflite_quant_model = converter.convert()
# with open('posenet_resnet50_32_513_full_integer_quant.tflite', 'wb') as w:
#     w.write(tflite_quant_model)
# print("Integer Quantization complete! - posenet_resnet50_32_513_full_integer_quant.tflite")

- Go to Table of contents -

4-2-6-8. Float16 Quantization from saved_model (Float16 quantization)

The method of Float16 Quantization is the same as before.

float16_quantization_resnet.py
import tensorflow as tf
import tensorflow_datasets as tfds
import numpy as np
from PIL import Image
import os
import glob


# Float16 Quantization - Input/Output=float32
converter = tf.lite.TFLiteConverter.from_saved_model('saved_model_posenet_resnet50_16_225')
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.target_spec.supported_types = [tf.float16]
tflite_quant_model = converter.convert()
with open('posenet_resnet50_16_225_float16_quant.tflite', 'wb') as w:
    w.write(tflite_quant_model)
print("Float16 Quantization complete! - posenet_resnet50_16_225_float16_quant.tflite")

# Float16 Quantization - Input/Output=float32
converter = tf.lite.TFLiteConverter.from_saved_model('saved_model_posenet_resnet50_16_257')
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.target_spec.supported_types = [tf.float16]
tflite_quant_model = converter.convert()
with open('posenet_resnet50_16_257_float16_quant.tflite', 'wb') as w:
    w.write(tflite_quant_model)
print("Float16 Quantization complete! - posenet_resnet50_16_257_float16_quant.tflite")

# Float16 Quantization - Input/Output=float32
converter = tf.lite.TFLiteConverter.from_saved_model('saved_model_posenet_resnet50_16_321')
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.target_spec.supported_types = [tf.float16]
tflite_quant_model = converter.convert()
with open('posenet_resnet50_16_321_float16_quant.tflite', 'wb') as w:
    w.write(tflite_quant_model)
print("Float16 Quantization complete! - posenet_resnet50_16_321_float16_quant.tflite")

# Float16 Quantization - Input/Output=float32
converter = tf.lite.TFLiteConverter.from_saved_model('saved_model_posenet_resnet50_16_385')
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.target_spec.supported_types = [tf.float16]
tflite_quant_model = converter.convert()
with open('posenet_resnet50_16_385_float16_quant.tflite', 'wb') as w:
    w.write(tflite_quant_model)
print("Float16 Quantization complete! - posenet_resnet50_16_385_float16_quant.tflite")

# Float16 Quantization - Input/Output=float32
converter = tf.lite.TFLiteConverter.from_saved_model('saved_model_posenet_resnet50_16_513')
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.target_spec.supported_types = [tf.float16]
tflite_quant_model = converter.convert()
with open('posenet_resnet50_16_513_float16_quant.tflite', 'wb') as w:
    w.write(tflite_quant_model)
print("Float16 Quantization complete! - posenet_resnet50_16_513_float16_quant.tflite")



# Float16 Quantization - Input/Output=float32
converter = tf.lite.TFLiteConverter.from_saved_model('saved_model_posenet_resnet50_32_225')
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.target_spec.supported_types = [tf.float16]
tflite_quant_model = converter.convert()
with open('posenet_resnet50_32_225_float16_quant.tflite', 'wb') as w:
    w.write(tflite_quant_model)
print("Float16 Quantization complete! - posenet_resnet50_32_225_float16_quant.tflite")

# Float16 Quantization - Input/Output=float32
converter = tf.lite.TFLiteConverter.from_saved_model('saved_model_posenet_resnet50_32_257')
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.target_spec.supported_types = [tf.float16]
tflite_quant_model = converter.convert()
with open('posenet_resnet50_32_257_float16_quant.tflite', 'wb') as w:
    w.write(tflite_quant_model)
print("Float16 Quantization complete! - posenet_resnet50_32_257_float16_quant.tflite")

# Float16 Quantization - Input/Output=float32
converter = tf.lite.TFLiteConverter.from_saved_model('saved_model_posenet_resnet50_32_321')
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.target_spec.supported_types = [tf.float16]
tflite_quant_model = converter.convert()
with open('posenet_resnet50_32_321_float16_quant.tflite', 'wb') as w:
    w.write(tflite_quant_model)
print("Float16 Quantization complete! - posenet_resnet50_32_321_float16_quant.tflite")

# Float16 Quantization - Input/Output=float32
converter = tf.lite.TFLiteConverter.from_saved_model('saved_model_posenet_resnet50_32_385')
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.target_spec.supported_types = [tf.float16]
tflite_quant_model = converter.convert()
with open('posenet_resnet50_32_385_float16_quant.tflite', 'wb') as w:
    w.write(tflite_quant_model)
print("Float16 Quantization complete! - posenet_resnet50_32_385_float16_quant.tflite")

# Float16 Quantization - Input/Output=float32
converter = tf.lite.TFLiteConverter.from_saved_model('saved_model_posenet_resnet50_32_513')
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.target_spec.supported_types = [tf.float16]
tflite_quant_model = converter.convert()
with open('posenet_resnet50_32_513_float16_quant.tflite', 'wb') as w:
    w.write(tflite_quant_model)
print("Float16 Quantization complete! - posenet_resnet50_32_513_float16_quant.tflite")

- Go to Table of contents -

4-2-6-9. Full Integer Quantization to EdgeTPU convert

The method of compiling for the EdgeTPU is the same as before.

EdgeTPU_compilation_of_Full_Integer_Quantization_model
$ edgetpu_compiler -s posenet_resnet50_16_225_full_integer_quant.tflite
$ edgetpu_compiler -s posenet_resnet50_16_257_full_integer_quant.tflite
$ edgetpu_compiler -s posenet_resnet50_16_321_full_integer_quant.tflite
$ edgetpu_compiler -s posenet_resnet50_16_385_full_integer_quant.tflite
$ edgetpu_compiler -s posenet_resnet50_16_513_full_integer_quant.tflite
$ edgetpu_compiler -s posenet_resnet50_32_225_full_integer_quant.tflite
$ edgetpu_compiler -s posenet_resnet50_32_257_full_integer_quant.tflite
$ edgetpu_compiler -s posenet_resnet50_32_321_full_integer_quant.tflite
$ edgetpu_compiler -s posenet_resnet50_32_385_full_integer_quant.tflite
$ edgetpu_compiler -s posenet_resnet50_32_513_full_integer_quant.tflite

- Go to Table of contents -

4-2-7. Quantize the model generated by the TensorFlow Object Detection API

What is the Tensorflow Object Detection API? If you want to know more, please see the following article and repository: Object Detection Tools and Object Detection API library of TensorFlow - Karaage San. They are very helpful.

The method of training with the Object Detection API is explained well in the above article, so I won't go into it here. Instead, I will explain how to quantize the generated model after adding Post-Process to it to improve performance. Use Tensorflow v1.15.2.

Assume that you have cloned https://github.com/tensorflow/models.git and have at hand a MobileNetV2-SSDLite checkpoint trained for 44,548 steps using the Object Detection API.
Screenshot 2020-05-05 20:10:03.png

- Go to Table of contents -

4-2-7-1. Generating a .pb file with Post-Process

Execute the following command to output a Freeze_Graph with post-processing added.

Generate_a_.pb_file_with_Post-Process_added_using_checkpoint
$ cd ${HOME}/models/research
$ export PYTHONPATH=`pwd`:`pwd`/slim:$PYTHONPATH
$ mkdir -p export

$ python3 object_detection/export_tflite_ssd_graph.py \
    --pipeline_config_path=pipeline.config \
    --trained_checkpoint_prefix=model.ckpt-44548 \
    --output_directory=export \
    --add_postprocessing_op=True

Screenshot 2020-05-05 20:07:32.png
Screenshot 2020-05-05 20:14:54.png
Screenshot 2020-05-05 20:16:14.png
TFLite_Detection_PostProcess is a custom operation.

- Go to Table of contents -

4-2-7-2. Weight Quantization from Freeze_Graph (Weight-only quantization)

I won't do anything special, but here are the points to note when quantizing.
 1. Use the Tensorflow v1.x API from_frozen_graph
 2. For .pb files containing custom operations, set converter.allow_custom_ops = True

weight_quantization.py
import tensorflow as tf

tf.compat.v1.enable_eager_execution()

# Weight Quantization - Input/Output=float32
graph_def_file="tflite_graph_with_postprocess.pb"
input_arrays=["normalized_input_image_tensor"]
output_arrays=['TFLite_Detection_PostProcess','TFLite_Detection_PostProcess:1', 
               'TFLite_Detection_PostProcess:2','TFLite_Detection_PostProcess:3']
input_tensor={"normalized_input_image_tensor":[1,300,300,3]}

converter = tf.lite.TFLiteConverter.from_frozen_graph(graph_def_file, input_arrays, 
                                                      output_arrays,input_tensor)
converter.optimizations = [tf.lite.Optimize.OPTIMIZE_FOR_SIZE]
converter.allow_custom_ops = True
tflite_quant_model = converter.convert()
with open('./ssdlite_mobilenet_v2_voc_300_weight_quant.tflite', 'wb') as w:
    w.write(tflite_quant_model)
print("Weight Quantization complete! - ssdlite_mobilenet_v2_voc_300_weight_quant.tflite")

Let's run it.

Executing_Weight_Quantization
$ python3 weight_quantization.py

It seems to have been generated safely.
Screenshot 2020-05-05 20:36:41.png
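
As a quick sanity check, the quantized model, including the custom post-process operation, can be run with the Python Interpreter. This is only a minimal sketch and assumes a test image ./test.jpg exists; the post-process emits detection boxes, classes, scores and the number of detections.

inference_check_weight_quant.py (a minimal sketch)
### tensorflow==1.15.2 (the version used in this section)

import numpy as np
import tensorflow as tf
from PIL import Image

interpreter = tf.lite.Interpreter(model_path='ssdlite_mobilenet_v2_voc_300_weight_quant.tflite')
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()

# Same preprocessing as the calibration below: resize to 300x300 and scale to roughly [-1, 1]
image = Image.open('test.jpg').convert('RGB').resize((300, 300))
image = (np.asarray(image).astype(np.float32) - 127.5) * 0.007843
interpreter.set_tensor(input_details[0]['index'], image[np.newaxis, :, :, :])
interpreter.invoke()

for detail in interpreter.get_output_details():
    print(detail['name'], interpreter.get_tensor(detail['index']).shape)
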
- Go to Table of contents -

4-2-7-3. Integer Quantization from Freeze_Graph (8-bit integer quantization)

It is almost identical to Weight Quantization.

integer_quantization_with_postprocess.py
import tensorflow as tf
import tensorflow_datasets as tfds
import numpy as np

def representative_dataset_gen():
  for data in raw_test_data.take(100):
    image = data['image'].numpy()
    image = tf.image.resize(image, (300, 300))
    image = image[np.newaxis,:,:,:]
    image = image - 127.5
    image = image * 0.007843
    yield [image]

tf.compat.v1.enable_eager_execution()

raw_test_data, info = tfds.load(name="voc/2007", with_info=True, 
                                split="validation", data_dir="~/TFDS", download=False)

# Integer Quantization - Input/Output=float32
graph_def_file="tflite_graph_with_postprocess.pb"
input_arrays=["normalized_input_image_tensor"]
output_arrays=['TFLite_Detection_PostProcess','TFLite_Detection_PostProcess:1', 
               'TFLite_Detection_PostProcess:2','TFLite_Detection_PostProcess:3']
input_tensor={"normalized_input_image_tensor":[1,300,300,3]}

converter = tf.lite.TFLiteConverter.from_frozen_graph(graph_def_file, input_arrays, 
                                                      output_arrays,input_tensor)
converter.allow_custom_ops=True
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = representative_dataset_gen
tflite_quant_model = converter.convert()
with open('./ssdlite_mobilenet_v2_voc_300_integer_quant_with_postprocess.tflite', 'wb') as w:
    w.write(tflite_quant_model)
print("Integer Quantization complete! - ssdlite_mobilenet_v2_voc_300_integer_quant_with_postprocess.tflite")

- Go to Table of contents -

4-2-7-4. Full Integer Quantization from Freeze_Graph (All 8-bit integer quantization)

full_integer_quantization_with_postprocess.py
import tensorflow as tf
import tensorflow_datasets as tfds
import numpy as np

def representative_dataset_gen():
  for data in raw_test_data.take(100):
    image = data['image'].numpy()
    image = tf.image.resize(image, (300, 300))
    image = image[np.newaxis,:,:,:]
    image = image - 127.5
    image = image * 0.007843
    yield [image]

tf.compat.v1.enable_eager_execution()

raw_test_data, info = tfds.load(name="voc/2007", with_info=True, 
                                split="validation", data_dir="~/TFDS", download=False)

# Full Integer Quantization - Input/Output=float32
graph_def_file="tflite_graph_with_postprocess.pb"
input_arrays=["normalized_input_image_tensor"]
output_arrays=['TFLite_Detection_PostProcess','TFLite_Detection_PostProcess:1', 
               'TFLite_Detection_PostProcess:2','TFLite_Detection_PostProcess:3']
input_tensor={"normalized_input_image_tensor":[1,300,300,3]}

converter = tf.lite.TFLiteConverter.from_frozen_graph(graph_def_file, input_arrays, 
                                                      output_arrays,input_tensor)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.allow_custom_ops=True
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8,tf.lite.OpsSet.SELECT_TF_OPS]
converter.representative_dataset = representative_dataset_gen
tflite_quant_model = converter.convert()
with open('./ssdlite_mobilenet_v2_voc_300_full_integer_quant_with_postprocess.tflite', 'wb') as w:
    w.write(tflite_quant_model)
print("Full Integer Quantization complete! - ssdlite_mobilenet_v2_voc_300_full_integer_quant_with_postprocess.tflite")

- Go to Table of contents -

4-2-7-5. Float16 Quantization from Freeze_Graph (Float16 quantization)

float16_quantization.py
import tensorflow as tf

tf.compat.v1.enable_eager_execution()

# Float16 Quantization - Input/Output=float32
graph_def_file="tflite_graph_with_postprocess.pb"
input_arrays=["normalized_input_image_tensor"]
output_arrays=['TFLite_Detection_PostProcess','TFLite_Detection_PostProcess:1', 
               'TFLite_Detection_PostProcess:2','TFLite_Detection_PostProcess:3']
input_tensor={"normalized_input_image_tensor":[1,300,300,3]}

converter = tf.lite.TFLiteConverter.from_frozen_graph(graph_def_file, input_arrays, 
                                                      output_arrays,input_tensor)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.target_spec.supported_types = [tf.float16]
converter.allow_custom_ops = True
tflite_quant_model = converter.convert()
with open('./ssdlite_mobilenet_v2_voc_300_float16_quant.tflite', 'wb') as w:
    w.write(tflite_quant_model)
print("Float16 Quantization complete! - ssdlite_mobilenet_v2_voc_300_float16_quant.tflite")

- Go to Table of contents -

4-2-7-6. Full Integer Quantization to EdgeTPU convert

$ edgetpu_compiler -s ssdlite_mobilenet_v2_voc_300_full_integer_quant_with_postprocess.tflite

Screenshot 2020-05-05 21:43:24.png
- Go to Table of contents -

4-2-8. Quantize models containing operations that are not supported by Tensorflow Lite but are supported by Tensorflow

The operators implemented in Tensorflow Lite are not exactly the same as those in Tensorflow itself, and a significant number of them remain unimplemented. Unfortunately, this has made it difficult to convert many models that have long been implemented in Tensorflow to Tensorflow Lite. However, with the Flex Delegate feature implemented around the end of last year, it is now possible to offload processing to Tensorflow itself when an operator is not yet implemented in Tensorflow Lite. The implementation does not seem to be complete yet: some models do not allow Integer Quantization, and there is only a C++ API, no Python API. It is not widely known, but I think it is a convenient feature that can casually expand the range of usable models for engineers who can implement in C++.

Here is an example of Mask-RCNN Inception V2. Unfortunately, Mask-RCNN Inception V2 does not support Integer Quantization or Full Integer Quantization at this time. You also need to install Tensorflow v2.2.0 or tf-nightly to perform this procedure.


1. https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/detection_model_zoo.md#coco-trained-models
2. http://download.tensorflow.org/models/object_detection/mask_rcnn_inception_v2_coco_2018_01_28.tar.gz

4-2-8-1. Generate Mask-RCNN Inception V2 .pb file

I will proceed on the assumption that the Tensorflow Object Detection API is in place. First, download the Mask-RCNN Inception V2 checkpoint from the official website and convert it to Freeze_Graph using the Object Detection API script.

Commands_for_generating_.pb_files
https://github.com/matterport/Mask_RCNN/issues/563
https://github.com/PINTO0309/TensorflowLite-flexdelegate

$ cd ~/Downloads
$ wget http://download.tensorflow.org/models/object_detection/mask_rcnn_inception_v2_coco_2018_01_28.tar.gz
$ tar -zxvf mask_rcnn_inception_v2_coco_2018_01_28.tar.gz
$ cd ${HOME}/models/research
$ export PYTHONPATH=`pwd`:`pwd`/slim:$PYTHONPATH
$ mkdir -p ${HOME}/Downloads/mask_rcnn_inception_v2_coco_2018_01_28/export

$ sudo pip3 uninstall tensorboard-plugin-wit tb-nightly \
                      tf-estimator-nightly tensorflow-gpu \
                      tensorflow tf-nightly tensorflow_estimator
$ sudo pip3 install tensorflow==2.2.0

$ python3 object_detection/export_inference_graph.py \
  --input_type=image_tensor \
  --pipeline_config_path=${HOME}/Downloads/mask_rcnn_inception_v2_coco_2018_01_28/pipeline.config \
  --trained_checkpoint_prefix=${HOME}/Downloads/mask_rcnn_inception_v2_coco_2018_01_28/model.ckpt \
  --output_directory=${HOME}/Downloads/mask_rcnn_inception_v2_coco_2018_01_28/test \
  --input_shape=1,256,256,3 \
  --write_inference_graph=True

$ python3 object_detection/export_inference_graph.py \
  --input_type=image_tensor \
  --pipeline_config_path=${HOME}/Downloads/mask_rcnn_inception_v2_coco_2018_01_28/pipeline.config \
  --trained_checkpoint_prefix=${HOME}/Downloads/mask_rcnn_inception_v2_coco_2018_01_28/model.ckpt \
  --output_directory=${HOME}/Downloads/mask_rcnn_inception_v2_coco_2018_01_28/test \
  --input_shape=1,512,512,3 \
  --write_inference_graph=True

- Go to Table of contents -

4-2-8-2. Weight Quantization of Mask-RCNN Inception V2 (Weight-only quantization)

The points of this work are as follows.
 1. Tensorflow v2.2.0-rc0 or higher is installed
 2. tf.lite.OpsSet.SELECT_TF_OPS is specified in converter.target_spec.supported_ops

weight_quantization.py
### tensorflow==2.2.0

import tensorflow as tf

# Weight Quantization - Input/Output=float32
converter = tf.lite.TFLiteConverter.from_saved_model('./saved_model')
converter.experimental_new_converter = True
converter.optimizations = [tf.lite.Optimize.OPTIMIZE_FOR_SIZE]
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS,tf.lite.OpsSet.SELECT_TF_OPS]
tflite_quant_model = converter.convert()
with open('./mask_rcnn_inception_v2_coco_weight_quant.tflite', 'wb') as w:
    w.write(tflite_quant_model)
print("Weight Quantization complete! - mask_rcnn_inception_v2_coco_weight_quant.tflite")

- Go to Table of contents -

4-2-8-3. Float16 Quantization of Mask-RCNN Inception V2 (Float16 quantization)

The point of the work is the same as Weight Quantization.

float16_quantization.py
### tensorflow==2.2.0

import tensorflow as tf

# Float16 Quantization - Input/Output=float32
converter = tf.lite.TFLiteConverter.from_saved_model('./saved_model')
converter.experimental_new_converter = True
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS,tf.lite.OpsSet.SELECT_TF_OPS]
converter.target_spec.supported_types = [tf.float16]
tflite_quant_model = converter.convert()
with open('./mask_rcnn_inception_v2_coco_float16_quant.tflite', 'wb') as w:
    w.write(tflite_quant_model)
print("Float16 Quantization complete! - mask_rcnn_inception_v2_coco_float16_quant.tflite")

- Go to Table of contents -

4-2-8-4. Running a model with Flex Delegate (Tensorflow Select Ops) enabled

Unfortunately, my C++ implementation skills are limited. Even so, I made an effort to implement it and got as far as running ENet, which contains operations not supported by Tensorflow Lite. The contents are a mess, but you can find the remains in this repository: TensorflowLite-flexdelegate. You need to build Tensorflow Lite with the Flex feature enabled. Please see the above repository for more information.

- Go to Table of contents -

4-2-9. Quantization from a model for PyTorch

Lately, I feel that there are more and more interesting models and projects implemented in PyTorch. Here is an example of how to convert a PyTorch model into a Tensorflow Lite quantized model. I will try to convert the PyTorch model of 3D Multi-Person Pose Estimation into a quantized Tensorflow Lite model. The original article is 3D PoseEstimation (Multi-Person) by OpenVINO + Corei7 CPU only [14 FPS-18 FPS] - Qiita - PINTO. You need Tensorflow v2.2.0 to perform this work.


Let me make use of the models and converters published below.


- Go to Table of contents -

4-2-9-1. Advance preparation (PyTorch->ONNX)

An overview of PyTorch's quantization workflow is provided below.
 1. Clone open_model_zoo
 2. Download a public model using downloader.py of open_model_zoo
 3. Convert the PyTorch model to an ONNX model using converter.py of open_model_zoo
   (the standard PyTorch feature torch.onnx._export(...) would probably also work; see the sketch after this list)
 4. Convert the ONNX model to a Keras model
 5. Convert the Keras model to saved_model
 6. Quantize the saved_model
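
As a reference for step 3, the standard PyTorch export API looks like the following. This is only a minimal sketch with a tiny placeholder module standing in for the real model; the input name 'data' and the 1 x 3 x 256 x 448 shape follow the scripts used later in this section, and the output name is just an example. The actual conversion in this article is done with converter.py of open_model_zoo.

torch_onnx_export_sketch.py (a minimal sketch)
import torch

# A tiny placeholder module; the real human-pose-estimation-3d-0001 model
# object loaded from open_model_zoo would take its place
net = torch.nn.Conv2d(3, 8, kernel_size=3, padding=1)
net.eval()

# The 3D pose model expects an N x C x H x W input named 'data' (1 x 3 x 256 x 448)
dummy_input = torch.randn(1, 3, 256, 448)
torch.onnx.export(net, dummy_input, 'model.onnx', input_names=['data'], output_names=['output'])
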

Model_download_and_PyTorch->ONNX_conversion
$ git clone https://github.com/opencv/open_model_zoo.git
$ cd open_model_zoo/tools/downloader
$ ./downloader.py --name human-pose-estimation-3d-0001
$ ./converter.py --name human-pose-estimation-3d-0001

https___qiita-image-store.s3.ap-northeast-1.amazonaws.com_0_194769_e304f5cb-4714-e5b2-af1e-ecf28a57b947.png
A part of the structure of the ONNX model is shown in the figure below.
Screenshot 2020-05-06 00:02:48.png
- Go to Table of contents -

4-2-9-2. ONNX->Keras conversion by onnx2keras

First, install Tensorflow v2.2.0 and onnx2keras.

Installation_of_onnx2keras
$ sudo pip3 uninstall tensorboard-plugin-wit tb-nightly \
                      tf-estimator-nightly tensorflow-gpu \
                      tensorflow tf-nightly tensorflow_estimator
$ sudo pip3 install tensorflow==2.2.0
$ sudo pip3 install onnx2keras

Next, here is the program to convert the ONNX model to a Keras model, and then to saved_model.

onnx_to_keras.py
import onnx
from onnx2keras import onnx_to_keras
import tensorflow as tf
import shutil

onnx_model = onnx.load('human-pose-estimation-3d-0001.onnx')
k_model = onnx_to_keras(onnx_model=onnx_model, input_names=['data'], change_ordering=True)

shutil.rmtree('saved_model', ignore_errors=True)
tf.saved_model.save(k_model, 'saved_model')

Let's run it.

ONNX->Keras->saved_model
$ python3 onnx_to_keras.py

It seems to have been generated safely. Once you've come this far, it's no different from the quantization procedure you've been working on so far.
Screenshot 2020-05-06 00:09:01.png
Screenshot 2020-05-06 00:09:13.png
- Go to Table of contents -

4-2-9-3. Weight Quantization from saved_model (Weight-only quantization)

weight_quantization.py
### tensorflow=2.2.0

import tensorflow as tf

# Weight Quantization - Input/Output=float32
converter = tf.lite.TFLiteConverter.from_saved_model('saved_model')
converter.optimizations = [tf.lite.Optimize.OPTIMIZE_FOR_SIZE]
#converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS,tf.lite.OpsSet.SELECT_TF_OPS]
tflite_quant_model = converter.convert()
with open('human_pose_estimation_3d_0001_256x448_weight_quant.tflite', 'wb') as w:
    w.write(tflite_quant_model)
print("Weight Quantization complete! - human_pose_estimation_3d_0001_256x448_weight_quant.tflite")

- Go to Table of contents -

4-2-9-4. Integer Quantization from saved_model (8-bit integer quantization)

The image data set for calibration is 100 images extracted from Pascal-VOC 2007 that contain only people.
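
The .npy file loaded below is simply an array of person images. If you do not have it, something like the following minimal sketch can create one from a folder of .jpg images. The folder name ./images is an assumption here, and the actual calibration file in PINTO_model_zoo may have been built differently.

make_calibration_data_img.py (a minimal sketch)
import glob
import numpy as np
from PIL import Image

raw_test_data = []
for path in glob.glob('./images/*.jpg'):
    # Resize to the model input size (H=256, W=448); PIL takes (width, height)
    image = Image.open(path).convert('RGB').resize((448, 256))
    raw_test_data.append(np.asarray(image, dtype=np.uint8))

np.save('calibration_data_img.npy', np.asarray(raw_test_data))
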

integer_quantization.py
### tensorflow==2.2.0

import tensorflow as tf
import numpy as np

def representative_dataset_gen():
    for image in raw_test_data:
        image = tf.image.resize(image, (256, 448))
        image = image[np.newaxis,:,:,:]
        image = image - 127.5
        image = image * 0.007843
        yield [image]

raw_test_data = np.load('calibration_data_img.npy', allow_pickle=True)

# Integer Quantization - Input/Output=float32
converter = tf.lite.TFLiteConverter.from_saved_model('saved_model')
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = representative_dataset_gen
tflite_quant_model = converter.convert()
with open('human_pose_estimation_3d_0001_256x448_integer_quant.tflite', 'wb') as w:
    w.write(tflite_quant_model)
print("Integer Quantization complete! - human_pose_estimation_3d_0001_256x448_integer_quant.tflite")

- Go to Table of contents -

4-2-9-5. Full Integer Quantization from saved_model (All 8-bit integer quantization)

full_integer_quantization.py
### tensorflow==2.2.0

import tensorflow as tf
import numpy as np

def representative_dataset_gen():
    for image in raw_test_data:
        image = tf.image.resize(image, (256, 448))
        image = image[np.newaxis,:,:,:]
        image = image - 127.5
        image = image * 0.007843
        yield [image]

raw_test_data = np.load('calibration_data_img.npy', allow_pickle=True)

# Full Integer Quantization - Input/Output=float32
converter = tf.lite.TFLiteConverter.from_saved_model('saved_model')
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.inference_input_type = tf.uint8
converter.inference_output_type = tf.uint8
converter.representative_dataset = representative_dataset_gen
tflite_quant_model = converter.convert()
with open('human_pose_estimation_3d_0001_256x448_full_integer_quant.tflite', 'wb') as w:
    w.write(tflite_quant_model)
print("Full Integer Quantization complete! - human_pose_estimation_3d_0001_256x448_full_integer_quant.tflite")

- Go to Table of contents -

4-2-9-6. Float16 Quantization from saved_model (Float16 quantization)

float16_quantization.py
### tensorflow==2.2.0

import tensorflow as tf

# Float16 Quantization - Input/Output=float32
converter = tf.lite.TFLiteConverter.from_saved_model('saved_model')
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.target_spec.supported_types = [tf.float16]
tflite_quant_model = converter.convert()
with open('human_pose_estimation_3d_0001_256x448_float16_quant.tflite', 'wb') as w:
    w.write(tflite_quant_model)
print("Float16 Quantization complete! - human_pose_estimation_3d_0001_256x448_float16_quant.tflite")

- Go to Table of contents -

4-2-9-7. Full Integer Quantization to EdgeTPU convert

Compile to the EdgeTPU model using the Full Integer Quantization model.

$ edgetpu_compiler -s human_pose_estimation_3d_0001_256x448_full_integer_quant.tflite

The operator Elu, which appears early in the model, does not support quantization, resulting in a rather disappointing EdgeTPU model. Even so, some of the early stages were successfully converted to run on the EdgeTPU.
Screenshot 2020-05-06 00:22:30.png
- Go to Table of contents -

4-2-10. Quantization of MediaPipe's model BlazeFace(.tflite)

Here I quantize a model called BlazeFace, published by Google in a project called MediaPipe. This is the most challenging pattern in the quantization workflow so far. In the first place, no checkpoint, Freeze_Graph, or saved_model is provided; only the .tflite file is available. The conversion procedure is as follows.

 1. Build flatc
 2. Download schema.fbs
 3. Download the BlazeFace model face_detection_front.tflite
 4. Parse the .tflite into a .json using flatc
 5. Generate a network based on the model structure read from the .json in step 4, while extracting the weights from the .tflite (see the sketch after this list for a quick way to peek at the tensors in a .tflite)
 6. Convert to saved_model using the weights and network extracted in step 5
 7. Perform various types of quantization
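
Before going through the flatc/JSON route, it may help to see what is actually stored inside the .tflite. The following minimal sketch only lists the tensors recorded in the file with the standard Python Interpreter, which is one way to see what has to be extracted in steps 4 and 5; it is not the conversion script itself.

peek_tflite_tensors.py (a minimal sketch)
import tensorflow as tf

interpreter = tf.lite.Interpreter(model_path='face_detection_front.tflite')
interpreter.allocate_tensors()

# List every tensor recorded in the .tflite (weights, biases and activations)
for detail in interpreter.get_tensor_details():
    print(detail['index'], detail['name'], detail['shape'], detail['dtype'])

# The values of a constant tensor (e.g. a convolution kernel) can be read
# back with interpreter.get_tensor(index) once its index is known.
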

4-2-10-1. Build flatc and download schema.fbs

Build_flatc
$ cd ~
$ git clone https://github.com/google/flatbuffers.git
$ cd flatbuffers
$ cmake -G "Unix Makefiles"

-- The C compiler identification is GNU 7.5.0
-- The CXX compiler identification is GNU 7.5.0
-- Check for working C compiler: /usr/bin/cc
-- Check for working C compiler: /usr/bin/cc -- works
-- Detecting C compiler ABI info
-- Detecting C compiler ABI info - done
-- Detecting C compile features
-- Detecting C compile features - done
-- Check for working CXX compiler: /usr/bin/c++
-- Check for working CXX compiler: /usr/bin/c++ -- works
-- Detecting CXX compiler ABI info
-- Detecting CXX compiler ABI info - done
-- Detecting CXX compile features
-- Detecting CXX compile features - done
-- Looking for strtof_l
-- Looking for strtof_l - found
-- Looking for strtoull_l
-- Looking for strtoull_l - found
-- `tests/monster_test.fbs`: add generation of C++ code with '--no-includes;--gen-compare'
-- `tests/monster_test.fbs`: add generation of binary (.bfbs) schema
-- `tests/namespace_test/namespace_test1.fbs`: add generation of C++ code with '--no-includes;--gen-compare'
-- `tests/namespace_test/namespace_test2.fbs`: add generation of C++ code with '--no-includes;--gen-compare'
-- `tests/union_vector/union_vector.fbs`: add generation of C++ code with '--no-includes;--gen-compare'
-- `tests/native_type_test.fbs`: add generation of C++ code with ''
-- `tests/arrays_test.fbs`: add generation of C++ code with '--scoped-enums;--gen-compare'
-- `tests/arrays_test.fbs`: add generation of binary (.bfbs) schema
-- `tests/monster_test.fbs`: add generation of C++ embedded binary schema code with '--no-includes;--gen-compare'
-- `tests/monster_extra.fbs`: add generation of C++ code with '--no-includes;--gen-compare'
-- `samples/monster.fbs`: add generation of C++ code with '--no-includes;--gen-compare'
-- `samples/monster.fbs`: add generation of binary (.bfbs) schema
Proceeding with version: 1.12.0.42
-- Configuring done
-- Generating done
-- Build files have been written to: /home/b920405/git/flatbuffers

$ make

Scanning dependencies of target flatc
[  1%] Building CXX object CMakeFiles/flatc.dir/src/idl_parser.cpp.o
[  2%] Building CXX object CMakeFiles/flatc.dir/src/idl_gen_text.cpp.o
[  3%] Building CXX object CMakeFiles/flatc.dir/src/reflection.cpp.o
[  4%] Building CXX object CMakeFiles/flatc.dir/src/util.cpp.o
[  5%] Building CXX object CMakeFiles/flatc.dir/src/idl_gen_cpp.cpp.o
[  7%] Building CXX object CMakeFiles/flatc.dir/src/idl_gen_csharp.cpp.o
[  8%] Building CXX object CMakeFiles/flatc.dir/src/idl_gen_dart.cpp.o
[  9%] Building CXX object CMakeFiles/flatc.dir/src/idl_gen_kotlin.cpp.o
[ 10%] Building CXX object CMakeFiles/flatc.dir/src/idl_gen_go.cpp.o
[ 11%] Building CXX object CMakeFiles/flatc.dir/src/idl_gen_java.cpp.o
[ 12%] Building CXX object CMakeFiles/flatc.dir/src/idl_gen_js_ts.cpp.o
[ 14%] Building CXX object CMakeFiles/flatc.dir/src/idl_gen_php.cpp.o
[ 15%] Building CXX object CMakeFiles/flatc.dir/src/idl_gen_python.cpp.o
[ 16%] Building CXX object CMakeFiles/flatc.dir/src/idl_gen_lobster.cpp.o
[ 17%] Building CXX object CMakeFiles/flatc.dir/src/idl_gen_lua.cpp.o
[ 18%] Building CXX object CMakeFiles/flatc.dir/src/idl_gen_rust.cpp.o
[ 20%] Building CXX object CMakeFiles/flatc.dir/src/idl_gen_fbs.cpp.o
[ 21%] Building CXX object CMakeFiles/flatc.dir/src/idl_gen_grpc.cpp.o
[ 22%] Building CXX object CMakeFiles/flatc.dir/src/idl_gen_json_schema.cpp.o
[ 23%] Building CXX object CMakeFiles/flatc.dir/src/idl_gen_swift.cpp.o
[ 24%] Building CXX object CMakeFiles/flatc.dir/src/flatc.cpp.o
[ 25%] Building CXX object CMakeFiles/flatc.dir/src/flatc_main.cpp.o
[ 27%] Building CXX object CMakeFiles/flatc.dir/src/code_generators.cpp.o
[ 28%] Building CXX object CMakeFiles/flatc.dir/grpc/src/compiler/cpp_generator.cc.o
[ 29%] Building CXX object CMakeFiles/flatc.dir/grpc/src/compiler/go_generator.cc.o
[ 30%] Building CXX object CMakeFiles/flatc.dir/grpc/src/compiler/java_generator.cc.o
[ 31%] Building CXX object CMakeFiles/flatc.dir/grpc/src/compiler/python_generator.cc.o
[ 32%] Building CXX object CMakeFiles/flatc.dir/grpc/src/compiler/swift_generator.cc.o
[ 34%] Linking CXX executable flatc
[ 34%] Built target flatc
Scanning dependencies of target flathash
[ 35%] Building CXX object CMakeFiles/flathash.dir/src/flathash.cpp.o
[ 36%] Linking CXX executable flathash
[ 36%] Built target flathash
Scanning dependencies of target flatbuffers
[ 37%] Building CXX object CMakeFiles/flatbuffers.dir/src/idl_parser.cpp.o
[ 38%] Building CXX object CMakeFiles/flatbuffers.dir/src/idl_gen_text.cpp.o
[ 40%] Building CXX object CMakeFiles/flatbuffers.dir/src/reflection.cpp.o
[ 41%] Building CXX object CMakeFiles/flatbuffers.dir/src/util.cpp.o
[ 42%] Linking CXX static library libflatbuffers.a
[ 42%] Built target flatbuffers
Scanning dependencies of target generated_code
[ 43%] Run generation: 'samples/monster.bfbs'
[ 44%] Run generation: 'tests/monster_test_generated.h'
[ 45%] Run generation: 'tests/monster_test.bfbs'
[ 47%] Run generation: 'tests/namespace_test/namespace_test1_generated.h'
[ 48%] Run generation: 'tests/namespace_test/namespace_test2_generated.h'
[ 49%] Run generation: 'tests/union_vector/union_vector_generated.h'
[ 50%] Run generation: 'tests/native_type_test_generated.h'
[ 51%] Run generation: 'tests/arrays_test_generated.h'
[ 52%] Run generation: 'tests/arrays_test.bfbs'
[ 54%] Run generation: 'tests/monster_test_bfbs_generated.h'
[ 55%] Run generation: 'tests/monster_extra_generated.h'
[ 56%] Run generation: 'samples/monster_generated.h'
[ 57%] All generated files were updated.
[ 57%] Built target generated_code
Scanning dependencies of target flatsamplebfbs
[ 58%] Building CXX object CMakeFiles/flatsamplebfbs.dir/src/idl_parser.cpp.o
[ 60%] Building CXX object CMakeFiles/flatsamplebfbs.dir/src/idl_gen_text.cpp.o
[ 61%] Building CXX object CMakeFiles/flatsamplebfbs.dir/src/reflection.cpp.o
[ 62%] Building CXX object CMakeFiles/flatsamplebfbs.dir/src/util.cpp.o
[ 63%] Building CXX object CMakeFiles/flatsamplebfbs.dir/samples/sample_bfbs.cpp.o
[ 64%] Linking CXX executable flatsamplebfbs
[ 65%] Built target flatsamplebfbs
Scanning dependencies of target flatsamplebinary
[ 67%] Building CXX object CMakeFiles/flatsamplebinary.dir/samples/sample_binary.cpp.o
[ 68%] Linking CXX executable flatsamplebinary
[ 69%] Built target flatsamplebinary
Scanning dependencies of target flattests
[ 70%] Building CXX object CMakeFiles/flattests.dir/src/idl_parser.cpp.o
[ 71%] Building CXX object CMakeFiles/flattests.dir/src/idl_gen_text.cpp.o
[ 72%] Building CXX object CMakeFiles/flattests.dir/src/reflection.cpp.o
[ 74%] Building CXX object CMakeFiles/flattests.dir/src/util.cpp.o
[ 75%] Building CXX object CMakeFiles/flattests.dir/src/idl_gen_fbs.cpp.o
[ 76%] Building CXX object CMakeFiles/flattests.dir/tests/test.cpp.o
[ 77%] Building CXX object CMakeFiles/flattests.dir/tests/test_assert.cpp.o
[ 78%] Building CXX object CMakeFiles/flattests.dir/tests/test_builder.cpp.o
[ 80%] Building CXX object CMakeFiles/flattests.dir/tests/native_type_test_impl.cpp.o
[ 81%] Building CXX object CMakeFiles/flattests.dir/src/code_generators.cpp.o
[ 82%] Linking CXX executable flattests
[ 91%] Built target flattests
Scanning dependencies of target flatsampletext
[ 92%] Building CXX object CMakeFiles/flatsampletext.dir/src/idl_parser.cpp.o
[ 94%] Building CXX object CMakeFiles/flatsampletext.dir/src/idl_gen_text.cpp.o
[ 95%] Building CXX object CMakeFiles/flatsampletext.dir/src/reflection.cpp.o
[ 96%] Building CXX object CMakeFiles/flatsampletext.dir/src/util.cpp.o
[ 97%] Building CXX object CMakeFiles/flatsampletext.dir/samples/sample_text.cpp.o
[ 98%] Linking CXX executable flatsampletext
[100%] Built target flatsampletext

$ cp flatc ~ && cd ~
## Use the raw file URL so that wget fetches schema.fbs itself, not the GitHub HTML page
$ wget https://raw.githubusercontent.com/tensorflow/tensorflow/master/tensorflow/lite/schema/schema.fbs

- Go to Table of contents -

4-2-10-2. Download MediaPipe's BlazeFace model (.tflite)

Download_face_detection_front.tflite
## Use the raw file URL so that wget fetches the .tflite itself, not the GitHub HTML page
$ wget https://raw.githubusercontent.com/google/mediapipe/master/mediapipe/models/face_detection_front.tflite
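
For reference, the flatc command that the conversion script in 4-2-10-3 assembles internally can also be run by hand; it is a quick way to confirm that the flatc binary and schema.fbs prepared above work. A hedged example, assuming flatc and schema.fbs are in the home directory as set up earlier and the model was downloaded to the current directory:

Dump_tflite_to_json
$ ~/flatc -t --strict-json --defaults-json -o . schema.fbs -- face_detection_front.tflite

This writes face_detection_front.json, the same file that gen_model_json() in the next step generates.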

- Go to Table of contents -

4-2-10-3. Converting BlazeFace(.tflite) to saved_model(.pb)

blazeface_tflite_to_pb.py
### tensorflow-gpu==1.15.2

#!/usr/bin/env python
# coding: utf-8

import os
import numpy as np
import json
import tensorflow as tf
import shutil
from pathlib import Path
home = str(Path.home())

os.environ['CUDA_VISIBLE_DEVICES'] = '0'
schema = "schema.fbs"
binary = home + "/flatc"
model_path = "face_detection_front.tflite"
output_pb_path = "face_detection_front.pb"
output_savedmodel_path = "saved_model"
model_json_path = "face_detection_front.json"
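# Number of tensor details to print for debugging in main()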
num_tensors = 176
output_node_names = ['classificators', 'regressors']

def gen_model_json():
    if not os.path.exists(model_json_path):
        cmd = (binary + " -t --strict-json --defaults-json -o . {schema} -- {input}".format(input=model_path, schema=schema))
        print("output json command =", cmd)
        os.system(cmd)


def parse_json():
    j = json.load(open(model_json_path))
    op_types = [v['builtin_code'] for v in j['operator_codes']]
    # print('op types:', op_types)
    ops = j['subgraphs'][0]['operators']
    # print('num of ops:', len(ops))
    return ops, op_types


def make_graph(ops, op_types, interpreter):
    tensors = {}
    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()
    # print(input_details)
    for input_detail in input_details:
        tensors[input_detail['index']] = tf.compat.v1.placeholder(
            dtype=input_detail['dtype'],
            shape=input_detail['shape'],
            name=input_detail['name'])

    for index, op in enumerate(ops):
        print('op: ', op)
        op_type = op_types[op['opcode_index']]
        if op_type == 'CONV_2D':
            input_tensor = tensors[op['inputs'][0]]
            weights_detail = interpreter._get_tensor_details(op['inputs'][1])
            bias_detail = interpreter._get_tensor_details(op['inputs'][2])
            output_detail = interpreter._get_tensor_details(op['outputs'][0])
            # print('weights_detail: ', weights_detail)
            # print('bias_detail: ', bias_detail)
            # print('output_detail: ', output_detail)
            weights_array = interpreter.get_tensor(weights_detail['index'])
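            # TFLite stores CONV_2D filters as [out_ch, H, W, in_ch];
            # transpose to [H, W, in_ch, out_ch] as expected by tf.nn.conv2d.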
            weights_array = np.transpose(weights_array, (1, 2, 3, 0))
            bias_array = interpreter.get_tensor(bias_detail['index'])
            weights = tf.Variable(weights_array, name=weights_detail['name'])
            bias = tf.Variable(bias_array, name=bias_detail['name'])
            options = op['builtin_options']
            output_tensor = tf.nn.conv2d(
                input_tensor,
                weights,
                strides=[1, options['stride_h'], options['stride_w'], 1],
                padding=options['padding'],
                dilations=[
                    1, options['dilation_h_factor'],
                    options['dilation_w_factor'], 1
                ],
                name=output_detail['name'] + '/conv2d')
            output_tensor = tf.add(
                output_tensor, bias, name=output_detail['name'])
            tensors[output_detail['index']] = output_tensor
        elif op_type == 'DEPTHWISE_CONV_2D':
            input_tensor = tensors[op['inputs'][0]]
            weights_detail = interpreter._get_tensor_details(op['inputs'][1])
            bias_detail = interpreter._get_tensor_details(op['inputs'][2])
            output_detail = interpreter._get_tensor_details(op['outputs'][0])
            # print('weights_detail: ', weights_detail)
            # print('bias_detail: ', bias_detail)
            # print('output_detail: ', output_detail)
            weights_array = interpreter.get_tensor(weights_detail['index'])
            weights_array = np.transpose(weights_array, (1, 2, 3, 0))
            bias_array = interpreter.get_tensor(bias_detail['index'])
            weights = tf.Variable(weights_array, name=weights_detail['name'])
            bias = tf.Variable(bias_array, name=bias_detail['name'])
            options = op['builtin_options']
            output_tensor = tf.nn.depthwise_conv2d(
                input_tensor,
                weights,
                strides=[1, options['stride_h'], options['stride_w'], 1],
                padding=options['padding'],
                # dilations=[
                #     1, options['dilation_h_factor'],
                #     options['dilation_w_factor'], 1
                # ],
                name=output_detail['name'] + '/depthwise_conv2d')
            output_tensor = tf.add(
                output_tensor, bias, name=output_detail['name'])
            tensors[output_detail['index']] = output_tensor
        elif op_type == 'MAX_POOL_2D':
            input_tensor = tensors[op['inputs'][0]]
            output_detail = interpreter._get_tensor_details(op['outputs'][0])
            options = op['builtin_options']
            output_tensor = tf.nn.max_pool(
                input_tensor,
                ksize=[
                    1, options['filter_height'], options['filter_width'], 1
                ],
                strides=[1, options['stride_h'], options['stride_w'], 1],
                padding=options['padding'],
                name=output_detail['name'])
            tensors[output_detail['index']] = output_tensor
        elif op_type == 'PAD':
            input_tensor = tensors[op['inputs'][0]]
            output_detail = interpreter._get_tensor_details(op['outputs'][0])
            paddings_detail = interpreter._get_tensor_details(op['inputs'][1])
            # print('output_detail:', output_detail)
            # print('paddings_detail:', paddings_detail)
            paddings_array = interpreter.get_tensor(paddings_detail['index'])
            paddings = tf.Variable(
                paddings_array, name=paddings_detail['name'])
            output_tensor = tf.pad(
                input_tensor, paddings, name=output_detail['name'])
            tensors[output_detail['index']] = output_tensor
        elif op_type == 'RELU':
            output_detail = interpreter._get_tensor_details(op['outputs'][0])
            input_tensor = tensors[op['inputs'][0]]
            output_tensor = tf.nn.relu(
                input_tensor, name=output_detail['name'])
            tensors[output_detail['index']] = output_tensor
        elif op_type == 'RESHAPE':
            input_tensor = tensors[op['inputs'][0]]
            output_detail = interpreter._get_tensor_details(op['outputs'][0])
            options = op['builtin_options']
            output_tensor = tf.reshape(
                input_tensor, options['new_shape'], name=output_detail['name'])
            tensors[output_detail['index']] = output_tensor
        elif op_type == 'ADD':
            output_detail = interpreter._get_tensor_details(op['outputs'][0])
            input_tensor_0 = tensors[op['inputs'][0]]
            input_tensor_1 = tensors[op['inputs'][1]]
            output_tensor = tf.add(input_tensor_0, input_tensor_1, name=output_detail['name'])
            tensors[output_detail['index']] = output_tensor
        elif op_type == 'CONCATENATION':
            output_detail = interpreter._get_tensor_details(op['outputs'][0])
            input_tensor_0 = tensors[op['inputs'][0]]
            input_tensor_1 = tensors[op['inputs'][1]]
            options = op['builtin_options']
            output_tensor = tf.concat([input_tensor_0, input_tensor_1],
                                      options['axis'],
                                      name=output_detail['name'])
            tensors[output_detail['index']] = output_tensor
        else:
            raise ValueError(op_type)


def main():

    tf.compat.v1.disable_eager_execution()

    gen_model_json()
    ops, op_types = parse_json()

    interpreter = tf.lite.Interpreter(model_path)
    interpreter.allocate_tensors()
    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()
    print(input_details)
    print(output_details)
    for i in range(num_tensors):
        detail = interpreter._get_tensor_details(i)
        print(detail)

    make_graph(ops, op_types, interpreter)

    config = tf.compat.v1.ConfigProto()
    config.gpu_options.allow_growth = True
    graph = tf.compat.v1.get_default_graph()
    # writer = tf.summary.FileWriter(os.path.splitext(output_pb_path)[0])
    # writer.add_graph(graph)
    # writer.flush()
    # writer.close()
    with tf.compat.v1.Session(config=config, graph=graph) as sess:
        sess.run(tf.compat.v1.global_variables_initializer())
        graph_def = tf.compat.v1.graph_util.convert_variables_to_constants(
            sess=sess,
            input_graph_def=graph.as_graph_def(),
            output_node_names=output_node_names)

        with tf.io.gfile.GFile(output_pb_path, 'wb') as f:
            f.write(graph_def.SerializeToString())

        shutil.rmtree('saved_model', ignore_errors=True)
        tf.compat.v1.saved_model.simple_save(
            sess,
            output_savedmodel_path,
            inputs={'input': graph.get_tensor_by_name('input:0')},
            outputs={
                'classificators': graph.get_tensor_by_name('classificators:0'),
                'regressors': graph.get_tensor_by_name('regressors:0')
            })

if __name__ == '__main__':
    main()

"""
$ saved_model_cli show --dir saved_model --all

MetaGraphDef with tag-set: 'serve' contains the following SignatureDefs:

signature_def['serving_default']:
  The given SavedModel SignatureDef contains the following input(s):
    inputs['input'] tensor_info:
        dtype: DT_FLOAT
        shape: (1, 128, 128, 3)
        name: input:0
  The given SavedModel SignatureDef contains the following output(s):
    outputs['classificators'] tensor_info:
        dtype: DT_FLOAT
        shape: (1, -1, 1)
        name: classificators:0
    outputs['regressors'] tensor_info:
        dtype: DT_FLOAT
        shape: (1, -1, 16)
        name: regressors:0
  Method name is: tensorflow/serving/predict
"""
$ python3 blazeface_tflite_to_pb.py
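
As a quick sanity check of my own (not part of the original workflow), the generated saved_model can be loaded back in Tensorflow v1.x and run once with a dummy input to confirm that the INPUT/OUTPUT names and shapes match the saved_model_cli output above. The file name check_saved_model.py is just an example.

check_saved_model.py
### tensorflow-gpu==1.15.2

import numpy as np
import tensorflow as tf

with tf.compat.v1.Session(graph=tf.Graph()) as sess:
    # 'serve' is the tag written by simple_save
    tf.compat.v1.saved_model.loader.load(sess, ['serve'], 'saved_model')
    dummy = np.zeros((1, 128, 128, 3), dtype=np.float32)
    outs = sess.run(['classificators:0', 'regressors:0'],
                    feed_dict={'input:0': dummy})
    print([o.shape for o in outs])
$ python3 check_saved_model.py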

- Go to Table of contents -

4-2-10-4. Weight Quantization from saved_model (weight-only quantization)

weight_quantization.py
### tensorflow==2.2.0

import tensorflow as tf
import numpy as np

# Weight Quantization - Input/Output=float32
converter = tf.lite.TFLiteConverter.from_saved_model('saved_model')
converter.optimizations = [tf.lite.Optimize.OPTIMIZE_FOR_SIZE]
tflite_quant_model = converter.convert()
with open('face_detection_front_128_weight_quant.tflite', 'wb') as w:
    w.write(tflite_quant_model)
print("Weight Quantization complete! - face_detection_front_128_weight_quant.tflite")
$ python3 weight_quantization.py
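
To confirm that the weight-quantized model loads and runs, the standard tf.lite.Interpreter can be fed a dummy float32 input (Input/Output remain float32 for Weight Quantization). This is a minimal sketch of my own; check_weight_quant.py is just an example file name.

check_weight_quant.py
### tensorflow==2.2.0

import numpy as np
import tensorflow as tf

interpreter = tf.lite.Interpreter(model_path='face_detection_front_128_weight_quant.tflite')
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

# Input/Output stay float32 for Weight Quantization
dummy = np.zeros(input_details[0]['shape'], dtype=np.float32)
interpreter.set_tensor(input_details[0]['index'], dummy)
interpreter.invoke()
for out in output_details:
    print(out['name'], interpreter.get_tensor(out['index']).shape)
$ python3 check_weight_quant.py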

- Go to Table of contents -

4-2-10-5. Integer Quantization from saved_model (8-bit integer quantization)

integer_quantization.py
### tensorflow==2.2.0

import tensorflow as tf
import tensorflow_datasets as tfds
import numpy as np
from PIL import Image
import os
import glob

def representative_dataset_gen():
  for data in raw_test_data.take(100):
    image = data['image'].numpy()
    image = tf.image.resize(image, (128, 128))
    image = image[np.newaxis,:,:,:]
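    # (x - 127.5) * 0.007843 ≈ x / 127.5 - 1, i.e. normalize pixels to roughly [-1, 1]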
    image = image - 127.5
    image = image * 0.007843
    yield [image]

raw_test_data, info = tfds.load(name="the300w_lp", with_info=True, split="train", data_dir="~/TFDS", download=True)

# Integer Quantization - Input/Output=float32
converter = tf.lite.TFLiteConverter.from_saved_model('saved_model')
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = representative_dataset_gen
tflite_quant_model = converter.convert()
with open('face_detection_front_128_integer_quant.tflite', 'wb') as w:
    w.write(tflite_quant_model)
print("Integer Quantization complete! - face_detection_front_128_integer_quant.tflite")
$ python3 integer_quantization.py

- Go to Table of contents -

4-2-10-6. Full Integer Quantization from saved_model (All 8-bit integer quantization)

full_integer_quantization.py
### tensorflow==2.2.0

import tensorflow as tf
import tensorflow_datasets as tfds
import numpy as np
from PIL import Image
import os
import glob

def representative_dataset_gen():
  for data in raw_test_data.take(100):
    image = data['image'].numpy()
    image = tf.image.resize(image, (128, 128))
    image = image[np.newaxis,:,:,:]
    image = image - 127.5
    image = image * 0.007843
    yield [image]

raw_test_data, info = tfds.load(name="the300w_lp", with_info=True, split="train", data_dir="~/TFDS", download=False)

# Full Integer Quantization - Input/Output=uint8
converter = tf.lite.TFLiteConverter.from_saved_model('saved_model')
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = representative_dataset_gen
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
converter.inference_input_type = tf.uint8
converter.inference_output_type = tf.uint8
tflite_quant_model = converter.convert()
with open('face_detection_front_128_full_integer_quant.tflite', 'wb') as w:
    w.write(tflite_quant_model)
print("Full Integer Quantization complete! - face_detection_front_128_full_integer_quant.tflite")
$ python3 full_integer_quantization.py
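
Because inference_input_type and inference_output_type are set to uint8, the Full Integer Quantization model takes raw uint8 pixels and returns uint8 tensors; the scale and zero point needed to convert the outputs back to float can be read from the interpreter. A minimal sketch of my own (check_full_integer_quant.py is just an example file name):

check_full_integer_quant.py
### tensorflow==2.2.0

import numpy as np
import tensorflow as tf

interpreter = tf.lite.Interpreter(model_path='face_detection_front_128_full_integer_quant.tflite')
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

# Input is uint8: feed raw 0-255 pixels, no float preprocessing
dummy = np.zeros(input_details[0]['shape'], dtype=np.uint8)
interpreter.set_tensor(input_details[0]['index'], dummy)
interpreter.invoke()

for out in output_details:
    scale, zero_point = out['quantization']
    q = interpreter.get_tensor(out['index'])
    # Dequantize the uint8 outputs back to float32
    deq = (q.astype(np.float32) - zero_point) * scale
    print(out['name'], q.dtype, deq.shape)
$ python3 check_full_integer_quant.py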

- Go to Table of contents -

4-2-10-7. Float16 Quantization from saved_model (Float16 quantization)

float16_quantization.py
### tensorflow==2.2.0

import tensorflow as tf
import tensorflow_datasets as tfds
import numpy as np
from PIL import Image
import os
import glob

# Float16 Quantization - Input/Output=float32
converter = tf.lite.TFLiteConverter.from_saved_model('saved_model')
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.target_spec.supported_types = [tf.float16]
tflite_quant_model = converter.convert()
with open('face_detection_front_128_float16_quant.tflite', 'wb') as w:
    w.write(tflite_quant_model)
print("Float16 Quantization complete! - face_detection_front_128_float16_quant.tflite")
$ python3 float16_quantization.py

- Go to Table of contents -

4-2-10-8. Full Integer Quantization to EdgeTPU convert

$ edgetpu_compiler -s face_detection_front_128_full_integer_quant.tflite
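
The compiler writes an EdgeTPU-compiled model (face_detection_front_128_full_integer_quant_edgetpu.tflite) next to the input file. To run it on a Coral device, load it through the EdgeTPU delegate. A minimal sketch of my own using tflite_runtime, assuming the runtime and libedgetpu are already installed:

run_edgetpu.py
### tflite_runtime + libedgetpu

import numpy as np
import tflite_runtime.interpreter as tflite

interpreter = tflite.Interpreter(
    model_path='face_detection_front_128_full_integer_quant_edgetpu.tflite',
    experimental_delegates=[tflite.load_delegate('libedgetpu.so.1')])
interpreter.allocate_tensors()

input_details = interpreter.get_input_details()
dummy = np.zeros(input_details[0]['shape'], dtype=np.uint8)
interpreter.set_tensor(input_details[0]['index'], dummy)
interpreter.invoke()
$ python3 run_edgetpu.py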

- Go to Table of contents -

4-3. Performance benchmarks for the quantization model (.tflite)

When benchmarking the performance of a generated .tflite file, it is tedious to write a dedicated validation program around each model's characteristics every time. Instead, I compile and use the Benchmark Tool published in the official Tensorflow repository. It lets you adjust the number of inference threads and enable XNNPACK or the GPU Delegate through launch options, so you can benchmark against various environments quite easily. Here, I explain how to build and use it.
https://github.com/PINTO0309/PINTO_model_zoo#3-tflite-model-benchmark

- Go to Table of contents -

4-3-1. Building the TFLite Model Benchmark Tool

Below are the build steps for the three environments I can prepare immediately at my fingertips (Ubuntu 18.04 x86_64 and RaspberryPi3/4 armhf/aarch64). Setting up any other environment is left to the reader.

Build_TFLite_Model_Benchmark_Tool
$ sudo apt-get install python-future

## Bazel for Ubuntu18.04 x86_64 install
$ wget https://github.com/bazelbuild/bazel/releases/download/2.0.0/bazel-2.0.0-installer-linux-x86_64.sh
$ sudo chmod +x bazel-2.0.0-installer-linux-x86_64.sh
$ ./bazel-2.0.0-installer-linux-x86_64.sh
$ sudo apt-get install -y openjdk-8-jdk

## Bazel for RaspberryPi3/4 Raspbian/Debian Buster armhf install
$ wget https://github.com/PINTO0309/Bazel_bin/raw/master/2.0.0/Raspbian_Debian_Buster_armhf/openjdk-8-jdk/install.sh
$ ./install.sh
$ curl -sc /tmp/cookie \
  "https://drive.google.com/uc?export=download&id=1LQUSal55R6fmawZS9zZuk6-5ZFOdUqRK" > /dev/null
$ CODE="$(awk '/_warning_/ {print $NF}' /tmp/cookie)"
$ curl -Lb /tmp/cookie \
  "https://drive.google.com/uc?export=download&confirm=${CODE}&id=1LQUSal55R6fmawZS9zZuk6-5ZFOdUqRK" \
  -o adoptopenjdk-8-hotspot_8u222-b10-2_armhf.deb
$ sudo apt-get install -y ./adoptopenjdk-8-hotspot_8u222-b10-2_armhf.deb

## Bazel for RaspberryPi3/4 Raspbian/Debian Buster aarch64 install
$ wget https://github.com/PINTO0309/Bazel_bin/raw/master/2.0.0/Raspbian_Debian_Buster_aarch64/openjdk-8-jdk/install.sh
$ ./install.sh
$ curl -sc /tmp/cookie \
  "https://drive.google.com/uc?export=download&id=1VwLxzT3EOTbhSzwvRF2H4ChTQyTQBt3x" > /dev/null
$ CODE="$(awk '/_warning_/ {print $NF}' /tmp/cookie)"
$ curl -Lb /tmp/cookie \
  "https://drive.google.com/uc?export=download&confirm=${CODE}&id=1VwLxzT3EOTbhSzwvRF2H4ChTQyTQBt3x" \
  -o adoptopenjdk-8-hotspot_8u222-b10-2_arm64.deb
$ sudo apt-get install -y ./adoptopenjdk-8-hotspot_8u222-b10-2_arm64.deb

## Clone Tensorflow v2.1.0+
$ git clone --depth 1 https://github.com/tensorflow/tensorflow.git
$ cd tensorflow

## Build and run TFLite Model Benchmark Tool

## Flex Delegate disabled version, it only takes a very short time to build.
$ bazel build \
  -c opt \
  tensorflow/lite/tools/benchmark:benchmark_model

## Flex Delegate enabled version; it takes a long time to build.
$ bazel build \
  -c opt \
  --config=noaws \
  --config=nohdfs \
  --config=nonccl \
  tensorflow/lite/tools/benchmark:benchmark_model_plus_flex
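
The built binaries are written under bazel-bin/tensorflow/lite/tools/benchmark/, so besides bazel run as in the examples below, you can also copy benchmark_model (or benchmark_model_plus_flex) to a target device and execute it directly, for example:

$ ./bazel-bin/tensorflow/lite/tools/benchmark/benchmark_model --help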

- Go to Table of contents -

4-3-2. Options for the TFLite Model Benchmark Tool

$ bazel run -c opt tensorflow/lite/tools/benchmark:benchmark_model -- --help

Flags:
    --input_layer_value_files=          string  optional    A map-like string representing value file. Each item is separated by ',', and the item value consists of input layer name and value file path separated by ':', e.g. input1:file_path1,input2:file_path2. If the input_name appears both in input_layer_value_range and input_layer_value_files, input_layer_value_range of the input_name will be ignored.
    --use_xnnpack=false                 bool    optional    use XNNPack
    --disable_nnapi_cpu=false           bool    optional    Disable the NNAPI CPU device
    --nnapi_accelerator_name=           string  optional    the name of the nnapi accelerator to use (requires Android Q+)
    --nnapi_execution_preference=       string  optional    execution preference for nnapi delegate. Should be one of the following: fast_single_answer, sustained_speed, low_power, undefined
    --use_nnapi=false                   bool    optional    use nnapi delegate api
    --use_gpu=false                     bool    optional    use gpu
    --max_delegated_partitions=0        int32   optional    Max partitions to be delegated.
    --profiling_output_csv_file=        string  optional    File path to export profile data as CSV, if not set prints to stdout.
    --max_profiling_buffer_entries=1024 int32   optional    max profiling buffer entries
    --enable_op_profiling=false         bool    optional    enable op profiling
    --require_full_delegation=false     bool    optional    require delegate to run the entire graph
    --allow_fp16=false                  bool    optional    allow fp16
    --use_legacy_nnapi=false            bool    optional    use legacy nnapi api
    --num_runs=50                       int32   optional    expected number of runs, see also min_secs, max_secs
    --input_layer_value_range=          string  optional    A map-like string representing value range for *integer* input layers. Each item is separated by ':', and the item value consists of input layer name and integer-only range values (both low and high are inclusive) separated by ',', e.g. input1,1,2:input2,0,254
    --input_layer_shape=                string  optional    input layer shape
    --input_layer=                      string  optional    input layer names
    --graph=                            string  optional    graph file name
    --warmup_min_secs=0.5               float   optional    minimum number of seconds to rerun for, potentially making the actual number of warm-up runs to be greater than warmup_runs
    --warmup_runs=1                     int32   optional    minimum number of runs performed on initialization, to allow performance characteristics to settle, see also warmup_min_secs
    --output_prefix=                    string  optional    benchmark output prefix
    --benchmark_name=                   string  optional    benchmark name
    --num_threads=1                     int32   optional    number of threads
    --run_delay=-1                      float   optional    delay between runs in seconds
    --max_secs=150                      float   optional    maximum number of seconds to rerun for, potentially making the actual number of runs to be less than num_runs. Note if --max-secs is exceeded in the middle of a run, the benchmark will continue to the end of the run but will not start the next run.
    --min_secs=1                        float   optional    minimum number of seconds to rerun for, potentially making the actual number of runs to be greater than num_runs

- Go to Table of contents -

4-3-3. Benchmark example of a model that includes only standard Tensorflow Lite operations (No XNNPACK, 4 Threads)

$ bazel run -c opt tensorflow/lite/tools/benchmark:benchmark_model -- \
  --graph=${HOME}/work/tensorflow/head_pose_estimator_integer_quant.tflite \
  --num_threads=4 \
  --warmup_runs=1 \
  --enable_op_profiling=true

- Go to Table of contents -

4-3-4. Benchmark example of a model that includes only standard Tensorflow Lite operations (XNNPACK available, 4 Threads)

$ bazel run -c opt tensorflow/lite/tools/benchmark:benchmark_model -- \
  --graph=${HOME}/work/tensorflow/head_pose_estimator_integer_quant.tflite \
  --num_threads=4 \
  --warmup_runs=1 \
  --use_xnnpack=true \
  --enable_op_profiling=true

- Go to Table of contents -

4-3-5. Benchmark examples of models with non-standard Tensorflow Lite operations (Flex enabled, no XNNPACK, 4 Threads)

$ bazel run \
  -c opt \
  --config=noaws \
  --config=nohdfs \
  --config=nonccl \
  tensorflow/lite/tools/benchmark:benchmark_model_plus_flex -- \
  --graph=${HOME}/git/tf-monodepth2/monodepth2_flexdelegate_weight_quant.tflite \
  --num_threads=4 \
  --warmup_runs=1 \
  --enable_op_profiling=true

- Go to Table of contents -

4-3-6. Benchmark examples of models with non-standard Tensorflow Lite operations (Flex enabled, with XNNPACK, 4 Threads)

$ bazel run \
  -c opt \
  --config=noaws \
  --config=nohdfs \
  --config=nonccl \
  tensorflow/lite/tools/benchmark:benchmark_model_plus_flex -- \
  --graph=${HOME}/git/tf-monodepth2/monodepth2_flexdelegate_weight_quant.tflite \
  --num_threads=4 \
  --warmup_runs=1 \
  --use_xnnpack=true \
  --enable_op_profiling=true

- Go to Table of contents -

4-3-7. Execution log sample of Benchmark_Tool

Execution_log_sample_of_Benchmark_Tool
STARTING!
Min num runs: [50]
Min runs duration (seconds): [1]
Max runs duration (seconds): [150]
Inter-run delay (seconds): [-1]
Num threads: [4]
Benchmark name: []
Output prefix: []
Min warmup runs: [1]
Min warmup runs duration (seconds): [0.5]
Graph: [/home/b920405/work/tensorflow/head_pose_estimator_integer_quant.tflite]
Input layers: []
Input shapes: []
Input value ranges: []
Input layer values files: []
Allow fp16 : [0]
Require full delegation : [0]
Enable op profiling: [1]
Max profiling buffer entries: [1024]
CSV File to export profiling data to: []
Max number of delegated partitions : [0]
Use gpu : [0]
Use xnnpack : [0]
Loaded model /home/b920405/work/tensorflow/head_pose_estimator_integer_quant.tflite
The input model file size (MB): 7.37157
Initialized session in 0.39ms.
Running benchmark for at least 1 iterations and at least 0.5 seconds but terminate if exceeding 150 seconds.
count=3 first=182671 curr=171990 min=171990 max=182671 avg=176216 std=4636

Running benchmark for at least 50 iterations and at least 1 seconds but terminate if exceeding 150 seconds.
count=50 first=174371 curr=183952 min=173778 max=203173 avg=181234 std=6641

Average inference timings in us: Warmup: 176216, Init: 390, Inference: 181234
Profiling Info for Benchmark Initialization:
============================== Run Order ==============================
                 [node type]              [start]     [first]    [avg ms]        [%]      [cdf%]      [mem KB]  [times called]  [Name]
             AllocateTensors                0.000       0.058       0.058   100.000%    100.000%         0.000          1   AllocateTensors/0

============================== Top by Computation Time ==============================
                 [node type]              [start]     [first]    [avg ms]        [%]      [cdf%]      [mem KB]  [times called]  [Name]
             AllocateTensors                0.000       0.058       0.058   100.000%    100.000%         0.000          1   AllocateTensors/0

Number of nodes executed: 1
============================== Summary by node type ==============================
                 [Node type]      [count]     [avg ms]      [avg %]     [cdf %]   [mem KB]  [times called]
             AllocateTensors            1        0.058     100.000%    100.000%      0.000          1

Timings (microseconds): count=1 curr=58
Memory (bytes): count=0
1 nodes observed

Operator-wise Profiling Info for Regular Benchmark Runs:
============================== Run Order ==============================
                 [node type]              [start]     [first]    [avg ms]        [%]      [cdf%]      [mem KB]  [times called]  [Name]
                    QUANTIZE                0.000       0.164       0.166     0.092%      0.092%         0.000          1   [input_image_tensor_int8]:0
                     CONV_2D                0.166       9.293       9.710     5.358%      5.449%         0.000          1   [conv2d/Relu]:1
                 MAX_POOL_2D                9.876       0.523       0.547     0.302%      5.751%         0.000          1   [max_pooling2d/MaxPool]:2
                     CONV_2D               10.423      40.758      41.859    23.097%     28.848%         0.000          1   [conv2d_2/Relu]:3
                     CONV_2D               52.282      73.752      76.566    42.248%     71.095%         0.000          1   [conv2d_3/Relu]:4
                 MAX_POOL_2D              128.848       0.259       0.261     0.144%     71.240%         0.000          1   [max_pooling2d_2/MaxPool]:5
                     CONV_2D              129.109      15.460      16.203     8.940%     80.180%         0.000          1   [conv2d_4/Relu]:6
                     CONV_2D              145.312      13.194      13.908     7.674%     87.854%         0.000          1   [conv2d_5/Relu]:7
                 MAX_POOL_2D              159.220       0.043       0.046     0.026%     87.880%         0.000          1   [max_pooling2d_3/MaxPool]:8
                     CONV_2D              159.266       4.272       4.473     2.468%     90.348%         0.000          1   [conv2d_6/Relu]:9
                     CONV_2D              163.740       5.437       5.745     3.170%     93.518%         0.000          1   [conv2d_7/Relu]:10
                 MAX_POOL_2D              169.485       0.029       0.031     0.017%     93.535%         0.000          1   [max_pooling2d_4/MaxPool]:11
                     CONV_2D              169.516       4.356       4.558     2.515%     96.050%         0.000          1   [conv2d_8/Relu]:12
             FULLY_CONNECTED              174.074       6.666       6.992     3.858%     99.908%         0.000          1   [dense/Relu]:13
             FULLY_CONNECTED              181.066       0.160       0.167     0.092%    100.000%         0.000          1   [logits/BiasAdd_int8]:14
                  DEQUANTIZE              181.232       0.001       0.001     0.000%    100.000%         0.000          1   [logits/BiasAdd]:15

============================== Top by Computation Time ==============================
                 [node type]              [start]     [first]    [avg ms]        [%]      [cdf%]      [mem KB]  [times called]  [Name]
                     CONV_2D               52.282      73.752      76.566    42.248%     42.248%         0.000          1   [conv2d_3/Relu]:4
                     CONV_2D               10.423      40.758      41.859    23.097%     65.344%         0.000          1   [conv2d_2/Relu]:3
                     CONV_2D              129.109      15.460      16.203     8.940%     74.285%         0.000          1   [conv2d_4/Relu]:6
                     CONV_2D              145.312      13.194      13.908     7.674%     81.959%         0.000          1   [conv2d_5/Relu]:7
                     CONV_2D                0.166       9.293       9.710     5.358%     87.316%         0.000          1   [conv2d/Relu]:1
             FULLY_CONNECTED              174.074       6.666       6.992     3.858%     91.174%         0.000          1   [dense/Relu]:13
                     CONV_2D              163.740       5.437       5.745     3.170%     94.344%         0.000          1   [conv2d_7/Relu]:10
                     CONV_2D              169.516       4.356       4.558     2.515%     96.859%         0.000          1   [conv2d_8/Relu]:12
                     CONV_2D              159.266       4.272       4.473     2.468%     99.327%         0.000          1   [conv2d_6/Relu]:9
                 MAX_POOL_2D                9.876       0.523       0.547     0.302%     99.629%         0.000          1   [max_pooling2d/MaxPool]:2

Number of nodes executed: 16
============================== Summary by node type ==============================
                 [Node type]      [count]     [avg ms]      [avg %]     [cdf %]   [mem KB]  [times called]
                     CONV_2D            8      173.016      95.471%     95.471%      0.000          8
             FULLY_CONNECTED            2        7.157       3.949%     99.421%      0.000          2
                 MAX_POOL_2D            4        0.884       0.488%     99.908%      0.000          4
                    QUANTIZE            1        0.166       0.092%    100.000%      0.000          1
                  DEQUANTIZE            1        0.000       0.000%    100.000%      0.000          1

Timings (microseconds): count=50 first=174367 curr=183949 min=173776 max=203169 avg=181231 std=6640
Memory (bytes): count=0
16 nodes observed

Note: as the benchmark tool itself affects memory footprint, the following is only APPROXIMATE to the actual memory footprint of the model at runtime. Take the information at your discretion.
Peak memory footprint (MB): init=0 overall=14.7656

- Go to Table of contents -

5. Finally

I'm not good at English, so it was very hard for me to write an article in English. Please point out any oddities in the text.

- Go to Table of contents -

6. Reference articles

  1. Tensorflow Lite, Frequently Asked Questions - Github
  2. DockerHub Tensorflow
  3. https://github.com/PINTO0309/PINTO_model_zoo.git
  4. https://github.com/PINTO0309/Tensorflow-bin.git
  5. https://github.com/PINTO0309/TensorflowLite-bin.git

- Go to Table of contents -
