
[English] Converting PyTorch, ONNX, Caffe, and OpenVINO (NCHW) models to Tensorflow / TensorflowLite (NHWC) in a snap

Posted at 2020-11-14


1. Introduction

By following the steps in this article, you will finally be able to convert U^2-Net, a highly accurate semantic segmentation model implemented in PyTorch, to TensorFlow Lite. The result looks like the animation below.
[Animation: U^2-Net semantic segmentation running after conversion to TensorFlow Lite]

TensorFlow is insanely unwieldy. Practically all of the interesting new models released every day are implemented in PyTorch, and I regularly catch myself thinking, "Why won't you implement it in TensorFlow!" It's quite tedious to reimplement and retrain every single model in TensorFlow, even when the paper's benchmark results already tell us the model will be accurate and perform well. What you find hard to handle about TensorFlow may differ from what I do, but in general I think the following are the bottlenecks.

  1. Whereas the other frameworks use the NCHW layout, TensorFlow alone defaults to NHWC
  2. Because NHWC is the default, TensorFlow models are difficult to convert to other frameworks
  3. Operations that do the same thing, with the same name but different interfaces, are scattered all over the place, which is super chaotic
  4. Interesting new implementations in TensorFlow feel few and far between
  5. The syntax is difficult

That said, I think TensorFlow has its good points too. Here are the ones I consider to be true.

  1. Models can be tuned for the target device for a high-performance experience
  2. Wide range of optimizations for execution environments such as TensorFlow.js, TensorRT, TF-TRT, TensorFlow Lite and MediaPipe
  3. The structure of the model is very beautiful when properly optimized.

Yes, #3 is purely my personal taste. In this article, I'd like to share some of my know-how on exploiting TensorFlow's strengths while working around its weaknesses. The goal: train in the framework of your choice, and optimize inference in TensorFlow.

For the gritty model-conversion know-how I wrote up in my previous article, see [Tensorflow Lite] Various Neural Network Model quantization methods for Tensorflow Lite (Weight Quantization, Integer Quantization, Full Integer Quantization, Float16 Quantization, EdgeTPU). As of May 05, 2020. The title is in English, but the content is in Japanese. An English version is also available: [English ver.] [Tensorflow Lite] Various Neural Network Model quantization methods for Tensorflow Lite (Weight Quantization, Integer Quantization, Full Integer Quantization, Float16 Quantization, EdgeTPU). As of May 05, 2020.

2. Various tools for model transformation

Wouldn't it be nice to convert models between frameworks and run interesting models on the framework of your choice? Several tools are currently available for model conversion. For the purposes of this article, ONNX is used only as a temporary intermediate format to freeze the PyTorch model. By the way, the main difference between my crude conversion tool (openvino2tensorflow) and the other major tools is that it converts the NCHW format straight to the NHWC format, and can even perform quantization.

3. The pain of converting from NCHW-formatted models to NHWC-formatted models

As you'll see when you try it, none of the tools other than my own, mentioned in the previous section, converts the NCHW format to the NHWC format very well. Even when a conversion succeeds, the result is often riddled with garbage Transpose layers. Conversion in the other direction, from NHWC to NCHW, is generally supported. To convert NCHW to NHWC cleanly, you have to extract the weights recorded in the model as NumPy arrays and transpose every one of them, which is very time-consuming. I have actually converted a number of models by hand, and I still make plenty of mistakes. Anyone who converts a huge model, or a model with a complex structure (e.g., many branches), by hand is bound to make mistakes. Anyway, it's hard. It's crazy hard.
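To make the "transpose every weight" step concrete, here is a minimal sketch in plain Python (no NumPy, so it stays self-contained) of the two permutations involved: activations go from NCHW to NHWC with the axis order (0, 2, 3, 1), and PyTorch Conv2D kernels, stored as (out, in, kH, kW), go to TensorFlow's (kH, kW, in, out) with (2, 3, 1, 0). With NumPy the first one is simply `arr.transpose(0, 2, 3, 1)`. The helper name `transpose4d` is illustrative only, and other cases (depthwise convolutions, fully connected layers that follow a flatten) need their own handling, which this sketch does not cover.

```python
from itertools import product

# Axis permutations in numpy.transpose convention: output axis i = input axis perm[i].
NCHW_TO_NHWC = (0, 2, 3, 1)  # activations: (N, C, H, W) -> (N, H, W, C)
OIHW_TO_HWIO = (2, 3, 1, 0)  # Conv2D kernels: PyTorch (O, I, kH, kW) -> TensorFlow (kH, kW, I, O)


def transpose4d(flat, shape, perm):
    """Transpose a 4-D tensor stored as a flat, row-major Python list."""
    new_shape = tuple(shape[p] for p in perm)
    # Row-major strides of the source tensor (last axis varies fastest).
    strides = [1, 1, 1, 1]
    for axis in range(2, -1, -1):
        strides[axis] = strides[axis + 1] * shape[axis + 1]
    out = []
    # Visit every destination element in row-major order.
    for dst_idx in product(*(range(n) for n in new_shape)):
        src_idx = [0, 0, 0, 0]
        for dst_axis, src_axis in enumerate(perm):
            src_idx[src_axis] = dst_idx[dst_axis]
        out.append(flat[sum(i * s for i, s in zip(src_idx, strides))])
    return out, new_shape


if __name__ == "__main__":
    # One image with two 2x2 channels: channel 0 holds 0..3, channel 1 holds 4..7.
    data, shape = transpose4d(list(range(8)), (1, 2, 2, 2), NCHW_TO_NHWC)
    print(shape, data)  # (1, 2, 2, 2) [0, 4, 1, 5, 2, 6, 3, 7]
```

After the transpose, the values of one pixel across all channels sit next to each other in memory, which is exactly what the NHWC layout means. Doing this by hand for every layer of a large, branchy model is where the mistakes creep in.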

"Yes! After two days, I finally converted it!"
 ↓
"Ah... what's this? I've just produced garbage..."

It makes me sick at heart. I did manage to convert PyTorch's EfficientDet-D0 to TensorFlow Lite by hand, but the model is so fragmented by nature that I felt like throwing up.

On NCHW to NHWC conversion, Alexey, the author of the YOLOv4 paper, and I have actually discussed it directly. It's a long-running issue with 37 back-and-forth comments so far; if you're curious how it came about, it may be worth a peek.
GitHub issue: Is there an easy way to convert ONNX or PB from (NCHW) to (NHWC)?
[Screenshot: the GitHub issue discussion]

4. Why convert via OpenVINO?

In this article, I will perform the NCHW to NHWC conversion, optimizing the model along the way, in the following sequence: PyTorch -> ONNX -> OpenVINO -> TensorFlow / TensorFlow Lite. I do not convert ONNX or any other NCHW model to TensorFlow's NHWC format in one go. Simply put, OpenVINO's Model Optimizer is so good that I deliberately insert one intermediate conversion to OpenVINO into my workflow.

In my opinion, going through an intermediate conversion to OpenVINO has five advantages:

  1. OpenVINO's Model Optimizer optimizes the model by itself during conversion
  2. As a dedicated inference framework, OpenVINO specializes in inference and inter-framework conversion, so its IR is well refined as a common intermediate format
  3. It supports conversion from many frameworks, including PyTorch (via ONNX), TensorFlow, Caffe, MXNet, and more
  4. All operations and the connections between them are written to a simple, human-readable XML file, so the structure of a trained model can easily be rewritten later in a text editor
  5. It is integrated into OpenCV

As for #5, it's hard to call it a benefit.
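Advantage #4 is easy to see in practice: an IR .xml file lists `<layer>` elements and the `<edge>` elements connecting them, so even the standard library can walk it. The sketch below uses Python's `xml.etree.ElementTree` on a tiny, hypothetical IR fragment written for illustration; the layer names `conv1` and `relu1` are made up, real files generated by mo.py carry many more attributes, ports, and layers, and this is not how openvino2tensorflow itself parses the file.

```python
import xml.etree.ElementTree as ET

# Hypothetical, heavily trimmed fragment following the layers/edges layout of an IR .xml file.
SAMPLE_IR = """<net name="u2netp_320x320" version="10">
  <layers>
    <layer id="0" name="x" type="Parameter" version="opset1">
      <data shape="1,3,320,320" element_type="f32"/>
    </layer>
    <layer id="1" name="conv1" type="Convolution" version="opset1">
      <data strides="1,1" pads_begin="1,1" pads_end="1,1" dilations="1,1"/>
    </layer>
    <layer id="2" name="relu1" type="ReLU" version="opset1"/>
  </layers>
  <edges>
    <edge from-layer="0" from-port="0" to-layer="1" to-port="0"/>
    <edge from-layer="1" from-port="2" to-layer="2" to-port="0"/>
  </edges>
</net>"""


def summarize_ir(xml_text):
    """Return {layer id: (name, type)} and a list of (producer, consumer) layer names."""
    root = ET.fromstring(xml_text)
    layers = {
        layer.get("id"): (layer.get("name"), layer.get("type"))
        for layer in root.iter("layer")
    }
    edges = [
        (layers[edge.get("from-layer")][0], layers[edge.get("to-layer")][0])
        for edge in root.iter("edge")
    ]
    return layers, edges


if __name__ == "__main__":
    layers, edges = summarize_ir(SAMPLE_IR)
    for layer_id, (name, op_type) in layers.items():
        print(layer_id, name, op_type)
    for src, dst in edges:
        print(f"{src} -> {dst}")
```

Because the graph is just elements and attributes, renaming a layer or inspecting a connection is a matter of editing text, which is what makes the IR convenient as a relay format.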

By the way, Alexey explains the usefulness of openvino2tensorflow in the comments section of an article in Russian called How to pack your neurons into a coffee machine.
Как запихать нейронку в кофеварку - How to pack your neurons into a coffee machine

5. The beauty of the model structure

The important factors in building a deep learning model are:
  1. Size
  2. Accuracy
  3. The beauty of the structure
Sorry, I'm probably the only one who counts beauty as a deciding factor. I collect models, and as I kept at it, before long I found myself pursuing beauty as well. So what do I mean by beauty? The comparison below may give you a better idea. Just by looking at it, you can see that on conversion to TensorFlow Lite the activation functions and BatchNormalization are merged into the Convolutions, and the model is neatly packed into about two-thirds the size of the original ONNX model.
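The BatchNormalization merge mentioned above is standard algebra rather than anything specific to one converter. For each output channel, with scale s = gamma / sqrt(var + eps), the folded kernel is w * s and the folded bias is (b - mean) * s + beta. ReLU-style activations are typically not folded into the weights; in TensorFlow Lite they are attached to the convolution as a fused-activation attribute. Below is a minimal pure-Python sketch of the folding that checks one output position numerically; the function names and numbers are illustrative only, not a description of how any particular converter implements it.

```python
import math


def fold_batchnorm(weights, bias, gamma, beta, mean, var, eps=1e-5):
    """Fold inference-mode BatchNorm into per-output-channel conv weights and biases."""
    folded_w, folded_b = [], []
    for w_c, b_c, g, bt, m, v in zip(weights, bias, gamma, beta, mean, var):
        scale = g / math.sqrt(v + eps)
        folded_w.append([w * scale for w in w_c])
        folded_b.append((b_c - m) * scale + bt)
    return folded_w, folded_b


def conv_at(weights, bias, patch):
    """One output position of a convolution: dot(kernel, input patch) + bias, per channel."""
    return [sum(w * x for w, x in zip(w_c, patch)) + b_c for w_c, b_c in zip(weights, bias)]


def batchnorm(y, gamma, beta, mean, var, eps=1e-5):
    """Inference-mode BatchNorm applied to one value per channel."""
    return [
        g * (y_c - m) / math.sqrt(v + eps) + bt
        for y_c, g, bt, m, v in zip(y, gamma, beta, mean, var)
    ]


if __name__ == "__main__":
    # Two output channels, each with a flattened 3-element kernel (illustrative values).
    W = [[0.5, -1.0, 2.0], [1.5, 0.25, -0.75]]
    b = [0.1, -0.2]
    gamma, beta = [1.2, 0.8], [0.05, -0.3]
    mean, var = [0.4, -0.1], [0.9, 2.5]
    patch = [1.0, 2.0, -0.5]

    reference = batchnorm(conv_at(W, b, patch), gamma, beta, mean, var)
    fW, fb = fold_batchnorm(W, b, gamma, beta, mean, var)
    folded = conv_at(fW, fb, patch)
    # Conv followed by BN and the single folded conv agree up to rounding.
    print(all(math.isclose(r, f, rel_tol=1e-9, abs_tol=1e-12) for r, f in zip(reference, folded)))
```

Since the folded convolution does the same math with one layer instead of two, the graph gets both smaller and simpler, which is part of the "beauty" described above.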

The model used for the ONNX to TensorFlow Lite conversion test is the ONNX version of ThreeDPoseUnityBarracuda, a 3D pose (skeleton) estimation model by Digital-Standard Co. For hobby or research purposes it appears you may publish it as part of a free program, but commercial use is restricted, so please read the LICENSE carefully before using it.
https://digital-standard.com/threedpose/models/Resnet34_3inputs_448x448_20200609.onnx

[Figure: model structure comparison in three columns: ONNX (Resnet34_3inputs_448x448_20200609.onnx), OpenVINO (Resnet34_3inputs_448x448_20200609.xml), TFLite (model_float32.tflite)]

6. Conversion procedure

The model is converted in the following order: PyTorch -> ONNX -> OpenVINO -> TensorFlow / TensorFlow Lite. These are the steps I've built into my own workflow, based on advice I received through OSS.

6-1. Installing TensorFlow and OpenVINO

Install TensorFlow and OpenVINO. TensorFlow must be v2.3.1 or later, and OpenVINO must be 2021.1 or later.

Installing_TensorFlow
$ sudo pip3 install tensorflow==2.3.1 --upgrade

The installation procedure for OpenVINO differs by OS, so follow the official installation instructions for your OS. This procedure supports OpenVINO 2021.1 or later.
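Because the steps below assume TensorFlow 2.3.1 or later, a quick version check can save some head-scratching. This is a small, hypothetical helper: it reads the installed version of a pip package through `importlib.metadata` (Python 3.8+) and compares the leading numeric components as integers, so "2.10.0" correctly counts as newer than "2.3.1". It assumes TensorFlow was installed as the `tensorflow` pip package; OpenVINO installed through its OS-specific installer may not be visible to pip, in which case compare its version string with `meets_minimum` directly.

```python
from importlib import metadata


def parse_version(text):
    """Leading numeric components of a version string: "1.7.0+cu101" -> (1, 7, 0)."""
    parts = []
    for piece in text.split("."):
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        if not digits:
            break
        parts.append(int(digits))
    return tuple(parts)


def meets_minimum(installed, minimum):
    # Integer tuples compare component by component, unlike plain strings.
    return parse_version(installed) >= parse_version(minimum)


def check_package(dist_name, minimum):
    """True/False if the pip distribution is installed, None if it cannot be found."""
    try:
        installed = metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        return None
    return meets_minimum(installed, minimum)


if __name__ == "__main__":
    print("tensorflow >= 2.3.1:", check_package("tensorflow", "2.3.1"))
```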

6-2. Installing openvino2tensorflow

Install my own tool, openvino2tensorflow, which automatically converts OpenVINO IR models to TensorFlow saved_model, .pb, .tflite, and .h5 files. You can install the latest package simply by running the following command. I fix bugs and add support for more layers almost daily, so I recommend running it every time you start working.

Installing_openvino2tensorflow
$ sudo pip3 install openvino2tensorflow --upgrade

【Reference】 GitHub: https://github.com/PINTO0309/openvino2tensorflow

6-3. Installing PyTorch

Select the installer that matches your environment on the website below. The configuration I used is shown in the screenshot. The only thing to watch out for is selecting the CUDA version that matches your environment.

https://pytorch.org/
[Screenshot: the PyTorch installation selector on pytorch.org with the configuration I used]

Installing_PyTorch
$ sudo pip3 install torch==1.7.0+cu101 \
  torchvision==0.8.1+cu101 torchaudio==0.7.0 \
  -f https://download.pytorch.org/whl/torch_stable.html

6-4. Installing the ONNX Runtime

Installing_onnxruntime
$ sudo pip3 install onnxruntime --upgrade

【Reference】 GitHub https://github.com/microsoft/onnxruntime

6-5. Installing ONNX Simplifier

As you may have noticed if you've ever exported an ONNX model, ONNX model structures are quite redundant. For example, this simple structure:
[Figure: simple_reshape.png, a simple Reshape structure]
turns into the following when exported to ONNX. Ahhh... it's not beautiful.
[Figure: complicated_reshape.png, the same structure after ONNX export]
Feeding the ONNX model to ONNX Simplifier, introduced here, optimizes the size and structure of the whole model at the same time, as the following figure shows. You can see at a glance how well it has been optimized.
[Figure: comparison.png, the model before and after ONNX Simplifier]

Installation_of_ONNX_Simplifier
$ sudo pip3 install onnx-simplifier --upgrade

【Reference】 GitHub: https://github.com/daquexian/onnx-simplifier

6-6. PyTorch -> ONNX conversion

Now, let's finally get to the real work of model conversion. This article uses a somewhat unconventional procedure to convert PyTorch to ONNX, demonstrated on U^2-Net (pronounced "U-square-Net"), a relatively simple yet accurate model. I won't cover TorchScript; if you're curious about it, the Qiita article "Using TorchScript to save PyTorch models" by hirune924 is very helpful.

6-6-1. Clone the sample repository used for the conversion procedure

Clone the repository containing the PyTorch implementation of U^2-Net, the highly accurate semantic segmentation model. Sample U^2-Net segmentation results are shown below. They're incredibly detailed and cool.
[Figure: u2netqual.png, sample U^2-Net segmentation results]

Clone_of_U^2-Net
$ git clone https://github.com/NathanUA/U-2-Net.git
$ cd U-2-Net

6-6-2. Generate onnx using pytorch_to_onnx.py, a backend module of OpenVINO's model_downloader

OpenVINO's bundled model_downloader tool downloads various models and, behind the scenes, calls a module that converts them to ONNX on the way to automatic conversion to OpenVINO IR. If you look into how it works, you'll find it's just a Python script. This procedure uses pytorch_to_onnx.py, the script called by model_downloader, to convert a PyTorch model to ONNX directly. The advantage is that, except for some very unusual PyTorch models, you can usually convert a .pth file to ONNX with a single command, without modifying the PyTorch program at all.
Now, let's try to convert the sample PyTorch model U^2-Net. The parameters of the model conversion script pytorch_to_onnx.py are as follows.

  1. --import-module: the name of the .py file containing the model structure, without the .py extension. If the file is several levels down the folder hierarchy, separate the folder names with ".", as in folder_name1.folder_name2.u2net. Example: for u2net.py in the current folder, specify u2net.
  2. --model-name: the class name of the model in the .py file from No.1. Example: U2NETP for class U2NETP(nn.Module):
  3. --input-shape: the input resolution of the model in NCHW format, as comma-separated values with no spaces. Example: 1,3,320,320
  4. --weights: the relative or absolute path to the .pth file containing the PyTorch weights.
  5. --output-file: the name of the exported ONNX file.
  6. --input-names: the name(s) of the model's input variables. In most cases, the argument name(s) of the forward function in No.1 will do. Separate multiple names with commas and no spaces. Example: "x,y,z"
  7. --output-names: the name(s) of the model's output variables. In most cases, the variable name(s) in the return statement of the forward function in No.1 will do. Separate multiple names with commas and no spaces. Example: "out1,out2,out3,out4"

Execute the following commands directly under the cloned repository folder.

ONNX_conversion_of_.pth_files_by_pytorch_to_onnx.py
$ python3 ${INTEL_OPENVINO_DIR}/deployment_tools/tools/model_downloader/pytorch_to_onnx.py \
  --import-module model.u2net \
  --model-name U2NETP \
  --input-shape 1,3,320,320 \
  --weights saved_models/u2netp/u2netp.pth \
  --output-file u2netp_320x320.onnx \
  --input-names "x" \
  --output-names "F.sigmoid(d0),F.sigmoid(d1),F.sigmoid(d2),F.sigmoid(d3),F.sigmoid(d4),F.sigmoid(d5),F.sigmoid(d6)"

The conversion succeeded if "ONNX check passed successfully." is displayed, as in the following figure.
[Screenshot: console output ending with "ONNX check passed successfully."]
u2netp_320x320.onnx has been generated properly.
[Screenshot: the generated u2netp_320x320.onnx file]
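For reference, the core of such a PyTorch to ONNX conversion boils down to a few calls: import the module, instantiate the class, load the .pth weights, and trace the model with `torch.onnx.export` using a dummy input of the given NCHW shape. The sketch below is my own approximation, not the actual source of pytorch_to_onnx.py: `export_to_onnx` and `model_dir` are hypothetical names, `opset_version=11` is an assumption, and it assumes the model class can be constructed without arguments and that the .pth file holds a plain state_dict. PyTorch and onnx are imported inside the function, so the argument-parsing helpers work without them.

```python
import importlib
import os
import sys


def parse_shape(text):
    """NCHW shape as comma-separated integers: "1,3,320,320" -> [1, 3, 320, 320]."""
    return [int(v) for v in text.split(",")]


def parse_names(text):
    """Comma-separated names without spaces: "x,y,z" -> ["x", "y", "z"]."""
    return text.split(",")


def export_to_onnx(import_module, model_name, input_shape, weights, output_file,
                   input_names, output_names, model_dir=".", opset_version=11):
    # Deferred imports: only needed when actually exporting.
    import onnx
    import torch

    # Make e.g. model/u2net.py importable as "model.u2net" from the repository root.
    sys.path.insert(0, os.path.abspath(model_dir))
    model_cls = getattr(importlib.import_module(import_module), model_name)
    model = model_cls()  # assumption: the constructor needs no arguments
    model.load_state_dict(torch.load(weights, map_location="cpu"))
    model.eval()

    dummy_input = torch.randn(parse_shape(input_shape))  # NCHW dummy input for tracing
    torch.onnx.export(
        model,
        dummy_input,
        output_file,
        input_names=parse_names(input_names),
        output_names=parse_names(output_names),
        opset_version=opset_version,
    )
    # Validate the exported file, mirroring the success message shown above.
    onnx.checker.check_model(onnx.load(output_file))
    print("ONNX check passed successfully.")
```

Run from the root of the cloned U-2-Net repository, `export_to_onnx("model.u2net", "U2NETP", "1,3,320,320", "saved_models/u2netp/u2netp.pth", "u2netp_320x320.onnx", "x", ...)`, with the same seven output names as the command above, would correspond to that command.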

6-7. Optimizing the ONNX Model

All you have to do is specify the ONNX file you just generated as a parameter and execute the following command.

Perform_optimization_of_ONNX_models
$ python3 -m onnxsim u2netp_320x320.onnx u2netp_320x320_opt.onnx

It's optimized like crazy. The more complex the model, the more effective it is.

[Figure: model structure before optimization (u2netp_320x320.onnx) and after optimization (u2netp_320x320_opt.onnx)]

【Reference】 GitHub https://github.com/daquexian/onnx-simplifier
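If you'd rather run this step from Python, onnx-simplifier also exposes a `simplify` function that returns the simplified model together with a flag saying whether its output check passed. A minimal sketch, assuming the onnx and onnx-simplifier packages installed above; `simplify_onnx` and `optimized_path` are hypothetical helper names, and the `_opt` suffix just mirrors the file naming used in this article.

```python
from pathlib import Path


def optimized_path(src):
    """u2netp_320x320.onnx -> u2netp_320x320_opt.onnx (same folder)."""
    p = Path(src)
    return str(p.with_name(p.stem + "_opt" + p.suffix))


def simplify_onnx(src, dst=None):
    # Deferred imports so optimized_path() works without onnx installed.
    import onnx
    from onnxsim import simplify

    dst = dst or optimized_path(src)
    model = onnx.load(src)
    model_simp, check = simplify(model)
    if not check:
        raise RuntimeError(f"onnx-simplifier could not validate the simplified model: {src}")
    onnx.save(model_simp, dst)
    return dst


if __name__ == "__main__":
    print(optimized_path("u2netp_320x320.onnx"))
```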

6-8. ONNX -> OpenVINO IR conversion

Now take u2netp_320x320_opt.onnx, the optimized model generated in the previous step, and convert it to IR format with OpenVINO's Model Optimizer by executing the following command. If you want to convert a Caffe model, start from this step.

Conversion_to_OpenVINO_IR_format
$ python3 ${INTEL_OPENVINO_DIR}/deployment_tools/model_optimizer/mo.py \
  --input_model u2netp_320x320_opt.onnx \
  --input_shape [1,3,320,320] \
  --output_dir openvino/320x320/FP32 \
  --data_type FP32 \
  --model_name u2netp_320x320

Three files are generated: .bin, .mapping, and .xml.
[Screenshot: the generated .bin, .mapping, and .xml files]
See below for the parameters you can specify for model_optimizer.

【Reference】 https://docs.openvinotoolkit.org/latest/openvino_docs_MO_DG_prepare_model_convert_model_Convert_Model_From_ONNX.html

【Reference】 https://docs.openvinotoolkit.org/latest/openvino_docs_MO_DG_prepare_model_convert_model_Converting_Model_General.html
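Before moving on, it can be worth confirming that the IR files exist and that the IR actually runs. The sketch below checks for the three files, then optionally pushes one all-zeros input through the OpenVINO 2021.x Inference Engine Python API (`IECore`, `read_network`, `load_network`, `infer`). The helper names are hypothetical, the "CPU" device is an assumption, and OpenVINO and NumPy are imported only inside `run_ir_once`.

```python
from pathlib import Path


def ir_files(output_dir, model_name):
    """Paths of the .xml, .bin and .mapping files that mo.py writes for one model."""
    base = Path(output_dir)
    return [base / f"{model_name}{ext}" for ext in (".xml", ".bin", ".mapping")]


def missing_ir_files(output_dir, model_name):
    return [str(p) for p in ir_files(output_dir, model_name) if not p.exists()]


def run_ir_once(xml_path, bin_path, device="CPU"):
    # OpenVINO 2021.x Inference Engine API; deferred imports.
    import numpy as np
    from openvino.inference_engine import IECore

    ie = IECore()
    net = ie.read_network(model=xml_path, weights=bin_path)
    input_name = next(iter(net.input_info))
    shape = net.input_info[input_name].input_data.shape  # still NCHW, e.g. [1, 3, 320, 320]
    exec_net = ie.load_network(network=net, device_name=device)
    return exec_net.infer(inputs={input_name: np.zeros(shape, dtype=np.float32)})


if __name__ == "__main__":
    print(missing_ir_files("openvino/320x320/FP32", "u2netp_320x320"))
```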

6-9. OpenVINO IR -> TensorFlow / TensorFlow Lite conversion

Now for the last step. Use my homemade tool, openvino2tensorflow, to generate NHWC-formatted TensorFlow / TensorFlow Lite saved_model, .pb, .h5, and .tflite files from the NCHW-formatted IR model. Its parameters are as follows, followed by the command to execute.

  1. --model_path: relative or absolute path to the .xml file of the IR model. The .xml and .bin files must be in the same folder.
  2. --model_output_path: relative or absolute path of the output folder for the converted NHWC model files.
  3. --output_saved_model: True or False. Specify True to output a saved_model.
  4. --output_h5: True or False. Specify True to output in .h5 format.
  5. --output_weight_and_json: True or False. Specify True to output the weights and JSON.
  6. --output_pb: True or False. Specify True to output in .pb format.
  7. --output_no_quant_float32_tflite: True or False. Specify True to output a non-quantized Float32 .tflite.
  8. --output_weight_quant_tflite: True or False. Specify True to output a weight-quantized .tflite.
  9. --output_float16_quant_tflite: True or False. Specify True to output a Float16-quantized .tflite.
  10. --replace_swish_and_hardswish: True or False. Specify True to swap the Swish and Hard-Swish activation functions. This is for verifying the performance of EfficientDet.
  11. --debug: enables debug mode, which prints the configuration of a specific layer during conversion.
  12. --debug_layer_number: the number of the layer whose shape you want to check in the debug print. Valid only if --debug is specified.

Generate_TensorFlow_format_(NHWC)_models_from_OpenVINO_IR_format_(NCHW)
$ openvino2tensorflow \
  --model_path openvino/320x320/FP32/u2netp_320x320.xml \
  --model_output_path saved_model_320x320 \
  --output_saved_model True \
  --output_h5 True \
  --output_pb True \
  --output_no_quant_float32_tflite True \
  --output_weight_quant_tflite True \
  --output_float16_quant_tflite True
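A quick way to check the result is to load the Float32 .tflite file with `tf.lite.Interpreter` and run one inference. Note that the converted model expects NHWC input, so data preprocessed for the PyTorch model in NCHW order must first be transposed with `np.transpose(x, (0, 2, 3, 1))`. A minimal sketch; the file location `saved_model_320x320/model_float32.tflite` is an assumption based on the output folder and the model_float32.tflite name that appear in this article, and the helper names are hypothetical.

```python
def nchw_to_nhwc_shape(shape):
    """[N, C, H, W] -> [N, H, W, C], e.g. [1, 3, 320, 320] -> [1, 320, 320, 3]."""
    n, c, h, w = shape
    return [n, h, w, c]


def run_tflite_once(model_path, nchw_input=None):
    # Deferred imports so the shape helper works without TensorFlow installed.
    import numpy as np
    import tensorflow as tf

    interpreter = tf.lite.Interpreter(model_path=model_path)
    interpreter.allocate_tensors()
    inp = interpreter.get_input_details()[0]
    if nchw_input is None:
        data = np.zeros(inp["shape"], dtype=inp["dtype"])
    else:
        # Data prepared for the PyTorch model is NCHW; the converted model wants NHWC.
        data = np.transpose(nchw_input, (0, 2, 3, 1)).astype(inp["dtype"])
    interpreter.set_tensor(inp["index"], data)
    interpreter.invoke()
    return [interpreter.get_tensor(o["index"]) for o in interpreter.get_output_details()]


if __name__ == "__main__":
    print(nchw_to_nhwc_shape([1, 3, 320, 320]))
    # e.g. outputs = run_tflite_once("saved_model_320x320/model_float32.tflite")
```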

The PyTorch (NCHW) -> ONNX (NCHW) -> OpenVINO (NCHW) -> TensorFlow Lite (NHWC) conversion is now complete. How was it, easy? Once you have a saved_model from this tool, you can convert it further to TFJS, TF-TRT, or CoreML with a single OSS tool. For more examples, see GitHub - PINTO0309/PINTO_model_zoo, where I've committed many sample scripts you can use as a reference if you're curious. As of 2020-11-14, 71 different models have been converted and committed.

Model structure of the TFLite model model_float32.tflite:
[Figure: structure of the converted u2netp_320x320 model_float32.tflite]

【Reference】 https://github.com/PINTO0309/openvino2tensorflow
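Since the whole procedure is four commands, it's easy to wrap in a script for repeated use. The sketch below replays this article's commands from Python with `subprocess`, defaulting to a dry run that only prints them. Two assumptions, also flagged in the code comments: it adds mo.py's `--model_name u2netp_320x320` option so the IR file name matches the path passed to openvino2tensorflow, and it requests only the saved_model and Float32 .tflite outputs for brevity. `build_commands` and `run_pipeline` are hypothetical names.

```python
import os
import shlex
import subprocess


def build_commands(ov_dir):
    """The four conversion steps from this article, as argument lists."""
    p2o = os.path.join(ov_dir, "deployment_tools/tools/model_downloader/pytorch_to_onnx.py")
    mo = os.path.join(ov_dir, "deployment_tools/model_optimizer/mo.py")
    outputs = ",".join(f"F.sigmoid(d{i})" for i in range(7))
    return [
        # 1. PyTorch -> ONNX
        ["python3", p2o, "--import-module", "model.u2net", "--model-name", "U2NETP",
         "--input-shape", "1,3,320,320", "--weights", "saved_models/u2netp/u2netp.pth",
         "--output-file", "u2netp_320x320.onnx", "--input-names", "x",
         "--output-names", outputs],
        # 2. ONNX -> simplified ONNX
        ["python3", "-m", "onnxsim", "u2netp_320x320.onnx", "u2netp_320x320_opt.onnx"],
        # 3. ONNX -> OpenVINO IR (assumption: --model_name so the .xml name matches step 4)
        ["python3", mo, "--input_model", "u2netp_320x320_opt.onnx",
         "--input_shape", "[1,3,320,320]", "--output_dir", "openvino/320x320/FP32",
         "--data_type", "FP32", "--model_name", "u2netp_320x320"],
        # 4. OpenVINO IR -> TensorFlow / TensorFlow Lite (assumption: subset of outputs)
        ["openvino2tensorflow", "--model_path", "openvino/320x320/FP32/u2netp_320x320.xml",
         "--model_output_path", "saved_model_320x320", "--output_saved_model", "True",
         "--output_no_quant_float32_tflite", "True"],
    ]


def run_pipeline(ov_dir, dry_run=True):
    for cmd in build_commands(ov_dir):
        print("$", shlex.join(cmd))
        if not dry_run:
            subprocess.run(cmd, check=True)  # stop at the first failing step


if __name__ == "__main__":
    # INTEL_OPENVINO_DIR is set by OpenVINO's setupvars script; printed literally if unset.
    run_pipeline(os.environ.get("INTEL_OPENVINO_DIR", "$INTEL_OPENVINO_DIR"), dry_run=True)
```

Passing argument lists instead of a single shell string avoids quoting problems with the parentheses in the output names.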

7. Finally

I'm not a TensorFlow devotee. I insist on converting to TensorFlow anyway because TensorFlow supports a rich set of runtime environments. I don't dislike PyTorch, given the wealth of interesting models it offers, but I'm not a fan of it at the moment because so many tricky models can't be exported to ONNX. You may say I should just use TorchScript, but I can't understand, or stand, how often errors occur in what is supposed to be the framework's own standard model export feature (for now, anyway). I expect it will become easier to use and more accessible as it matures.

8. Reference Articles

  1. ThreeDPoseUnityBarracuda: https://github.com/digital-standard/ThreeDPoseUnityBarracuda
  2. Saving PyTorch models using TorchScript - Qiita - hirune924
  3. Netron - model structure visualization tool
  4. PINTO_model_zoo: https://github.com/PINTO0309/PINTO_model_zoo