Running DETR's Backbone on Raspberry Pi AI Camera and Transformer on Raspberry Pi

Posted at 2025-07-07

Overview of this Article

This is Matsuoka from Semiconductor Solutions Corporation's customer support for AITRIOS.

In the previous article, we separated the Backbone and the Transformer of the DETR network and verified end-to-end object detection on a PC.

Specifically, we replaced the Backbone with MobileNetV2, trained the network, and separated it into two parts: the Backbone and the Transformer.
We then added a Classifier to the Backbone and trained it for classification with transfer learning.
Finally, we used Python to verify object detection with two models: the Backbone with the added Classifier and the separated Transformer.

In this article, we'll verify the operation using an actual Raspberry Pi AI Camera.

  1. We will quantize the backbone with class classification on a PC.
  2. We will set up the Raspberry Pi environment, including copying the model to the Raspberry Pi.
  3. We will execute the Converter and Packager on the Raspberry Pi for the quantized classification-enabled backbone.
  4. We will implement Python code to retrieve the output tensor from the Raspberry Pi AI Camera and perform object detection when the probability of any class is high.
  5. We will capture images using the Raspberry Pi AI Camera and verify the end-to-end operation.

deployed_on_raspi.png

If you're reading this article, please also read the previous one.
To try the content of this article, you need to have completed at least up to the "Adding Classifier Layer to Backbone and Training for Classification" section in the previous article.

If you're not familiar with the Raspberry Pi AI Camera, please check this website as well.
Note that the Raspberry Pi AI Camera is a product intended for business use.

  • If you notice any errors or omissions in this article, or have any questions, please leave a comment on the article.
    Please understand that it may take some time to respond to comments, and in some cases, we may not be able to respond.
  • This article is intended to introduce an application example and does not guarantee the performance or quality when actually implemented.
  • We have not conducted any third-party patent searches.
  • For any issues related to AITRIOS, please refer to the AITRIOS support page.
  • For questions or comments regarding the Raspberry Pi AI Camera, please post them in the Raspberry Pi AI Camera forum:

Quantization of Backbone with Class Classification

About Quantization

To deploy PyTorch AI models to the Raspberry Pi AI Camera, you need to convert the floating-point model to an 8-bit integer model using the Model Compression Toolkit (MCT).
MCT is an Apache-2.0 licensed open-source software that runs on Python and provides quantization techniques such as Post-training quantization (PTQ).

PTQ uses actual dataset inputs to determine the range of values (clipping range) that each tensor in a trained model can take.
It then converts the model to an integer model by representing each clipping range with 8-bit integers.

The process of determining these clipping ranges is called calibration.
For proper quantization, the calibration dataset needs to have a certain level of coverage for actual inputs.
Therefore, we perform calibration using the same dataset used for model training.
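To make the idea concrete, here is a minimal NumPy sketch of calibration followed by 8-bit quantization. It assumes a simple min/max clipping range and an asymmetric unsigned (0..255) scheme purely for illustration; MCT chooses thresholds with its own error methods (quantization.py below uses MSE) and IMX500-specific quantizers, so this is not MCT's algorithm. The helper names are hypothetical.

```python
import numpy as np

def calibrate_range(samples):
    """Determine a clipping range [lo, hi] from calibration activations (simple min/max)."""
    stacked = np.concatenate([np.ravel(s) for s in samples])
    return float(stacked.min()), float(stacked.max())

def quantize_uint8(x, lo, hi):
    """Map float values in [lo, hi] to integers 0..255; values outside the range are clipped."""
    scale = (hi - lo) / 255.0
    q = np.round((np.clip(x, lo, hi) - lo) / scale)
    return q.astype(np.uint8), scale

def dequantize_uint8(q, lo, scale):
    """Recover approximate float values from the 8-bit integers."""
    return q.astype(np.float32) * scale + lo

# Calibration: observe activations produced by a few "representative" batches.
rng = np.random.default_rng(0)
calibration_batches = [rng.normal(size=(4, 16)) for _ in range(8)]
lo, hi = calibrate_range(calibration_batches)

# Quantize new data with the calibrated range and measure the rounding error.
x = rng.normal(size=(4, 16))
q, scale = quantize_uint8(x, lo, hi)
x_hat = dequantize_uint8(q, lo, scale)
print("clipping range:", (lo, hi))
print("max abs error inside the range:", np.abs(np.clip(x, lo, hi) - x_hat).max())
```

Inside the clipping range the error is bounded by half a quantization step, while values outside it are clipped; this is why the calibration data needs to cover the inputs the model will actually see.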

For more details, please refer to the MCT GitHub repository.
Note that after model quantization, MCT saves the model in ONNX format.

Quantization Methods Actually Used

Instead of PTQ, I used Gradient-based post-training quantization provided by MCT.

Initially, I applied PTQ for quantization.
However, when I evaluated the cosine similarity between the output tensor generated by PyTorch and the output tensor generated by the Raspberry Pi AI Camera, the results were not satisfactory. After switching to Gradient-based post-training quantization, the cosine similarity improved, so I decided to use this quantization method.
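The article does not show how the comparison was computed, so the following is only a minimal sketch of a cosine similarity check, assuming both outputs are available as NumPy arrays with the same element order (for example the (256, 7, 7) feature tensor). The random "camera-like" tensor is a hypothetical stand-in for real camera output.

```python
import numpy as np

def cosine_similarity(a, b):
    """Cosine similarity between two tensors, compared as flattened vectors."""
    a = np.asarray(a, dtype=np.float64).ravel()
    b = np.asarray(b, dtype=np.float64).ravel()
    denom = np.linalg.norm(a) * np.linalg.norm(b)
    if denom == 0.0:
        raise ValueError("cosine similarity is undefined for an all-zero tensor")
    return float(np.dot(a, b) / denom)

# Hypothetical example: a float reference output and a slightly perturbed copy
# standing in for the tensor returned by the Raspberry Pi AI Camera.
rng = np.random.default_rng(0)
reference = rng.normal(size=(256, 7, 7)).astype(np.float32)
camera_like = reference + rng.normal(scale=0.05, size=reference.shape).astype(np.float32)
print(f"cosine similarity: {cosine_similarity(reference, camera_like):.4f}")
```

A value close to 1.0 means the quantized output points in almost the same direction as the float output; it is insensitive to a uniform scale difference, so it complements rather than replaces an accuracy check.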

If the inference accuracy of the Raspberry Pi AI Camera significantly differs from the pre-quantization accuracy, it might be worth considering changing the quantization method.

Executing Quantization

For the implementation in this article, we will continue to use the feature input implementation folder, which is a copy of the cloned detr folder, as we did in the previous article.
From here on, we'll refer to this folder as the "Separated DETR Verification Folder".
We won't be using the folder cloned for training purposes.

We will quantize the backbone with class classification created in the previous article using the Model Compression Toolkit (MCT).
Place the quantization.py file, which is provided at the end of this section, directly under the Separated DETR Verification Folder and execute it.

python quantization.py

While the Dockerfile in the previous article includes the MCT installation, if you're using your own environment, please install model-compression-toolkit==2.0.0.
For the operating conditions of MCT version 2.0.0, refer to the Readme.md in the Version 2.0.0 release.
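If you are unsure which MCT version your own environment has, a quick check like the following can help before running quantization.py. It uses only the standard library's importlib.metadata and assumes the distribution name model-compression-toolkit from the pip requirement above; installed_version is a hypothetical helper.

```python
from importlib.metadata import PackageNotFoundError, version

def installed_version(dist_name):
    """Return the installed version string of a distribution, or None if it is missing."""
    try:
        return version(dist_name)
    except PackageNotFoundError:
        return None

required = "2.0.0"
found = installed_version("model-compression-toolkit")
if found is None:
    print("model-compression-toolkit is not installed; try: pip install model-compression-toolkit==" + required)
elif found != required:
    print(f"model-compression-toolkit {found} is installed, but this article uses {required}")
else:
    print("model-compression-toolkit", found)
```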

Quantization Code Flow

  1. Load the floating-point PyTorch model.

  2. Load the calibration dataset.

    As mentioned earlier, we simply use the same dataset used for training.
    For calibration images, instead of standardizing RGB values to mean 0 and variance 1, we normalize the range to [0..1], just as we did during training.

  3. Create a generator to feed the calibration dataset into MCT.

    Note that the data provided to MCT has the overall shape [batch count, batch size, channels, height, width]: the generator yields one batch of shape [batch size, channels, height, width] per iteration.

  4. Set up the configuration for MCT.

  5. Perform quantization using MCT.

  6. Save the quantized model.

    [!TIP]
    To perform inference with the saved ONNX model, create the ONNX Runtime InferenceSession with the session options returned by mctq.get_ort_session_options(), where mctq is the mct_quantizers package.
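Steps 2 and 3 can be sketched with NumPy alone. The snippet below uses random stand-in images and assumes the transform in quantization.py (Resize, then ToTensor), so pixel values end up in [0, 1] instead of being standardized. Like quantization.py, it hands MCT a function that can be called to obtain a fresh generator. to_unit_range_chw and make_representative_data_gen are hypothetical helpers, not MCT APIs.

```python
import numpy as np

def to_unit_range_chw(image_hwc_uint8):
    """Scale an 8-bit HWC image to float32 CHW in [0, 1], as ToTensor does."""
    return np.transpose(image_hwc_uint8.astype(np.float32) / 255.0, (2, 0, 1))

def make_representative_data_gen(images_chw, batch_size, n_iter):
    """Return a function whose generator yields n_iter lists, each holding one [B, C, H, W] batch."""
    def representative_data_gen():
        for i in range(n_iter):
            yield [images_chw[i * batch_size:(i + 1) * batch_size]]
    return representative_data_gen

# Hypothetical data: 64 random 224 x 224 RGB images (already resized).
rng = np.random.default_rng(0)
raw = rng.integers(0, 256, size=(64, 224, 224, 3), dtype=np.uint8)
images = np.stack([to_unit_range_chw(img) for img in raw])

gen = make_representative_data_gen(images, batch_size=32, n_iter=2)
batches = [item[0] for item in gen()]
# Stacked, the batches form [batch count, batch size, channels, height, width].
print(np.stack(batches).shape, float(images.min()), float(images.max()))
```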

Quantization Code

quantization.py

import argparse

import torch
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

import model_compression_toolkit as mct
from model_compression_toolkit.core import QuantizationErrorMethod

from for_separation.mobilenet import mobilenet_with_feature_output
from for_separation.my_coco import CocoClassificationDataset


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_path', type=str, default='./backbone_with_classifier_weight.pth', help='The path to the floating-point PyTorch model weights')
    parser.add_argument('--annotation_file', type=str, default='/data/image/coco/annotations/instances_train2017.json', help='The path to the annotation file')
    parser.add_argument('--image_folder', type=str, default='/data/image/coco/images/train2017', help='The path to the image folder')
    parser.add_argument('--quantized_model_path', type=str, default='separated_mobilenet_quantized.onnx', help='The path to the quantized model')
    parser.add_argument('--num_of_classes', default=91, type=int,  help='the number of classes')
    args = parser.parse_args()

    batch_size = 32

    #<1>  Load a floating-point PyTorch model.
    model = mobilenet_with_feature_output(num_of_classes = args.num_of_classes)
    model.load_state_dict(torch.load(args.model_path, map_location=torch.device('cpu')))

    #<2> Load a calibration dataset for quantization.
    #    The calibration dataset is normalized to match the normalization used during training.
    train_dataset = CocoClassificationDataset(
        annotation_file = args.annotation_file,
        image_folder = args.image_folder,
        num_of_classes = args.num_of_classes,
        transform = transforms.Compose([
            transforms.Resize(size=(224,224)),
            transforms.ToTensor(),
            transforms.Lambda(lambda x: x.repeat(3, 1, 1) if x.shape[0] == 1 else x)
        ])
    )

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=False)

    #<3> Create a representative dataset generator.
    #    Each iteration yields a list holding one image batch of shape [batch_size, 3, 224, 224].
    n_iter = len(train_loader)
    def representative_data_gen():
        ds_iter = iter(train_loader)
        for _ in range(n_iter):
            yield [next(ds_iter)[0]]

    #<4> Set a configuration.
    q_config = mct.core.QuantizationConfig(activation_error_method=QuantizationErrorMethod.MSE,
                                           weights_error_method=QuantizationErrorMethod.MSE,
                                           weights_bias_correction=True,
                                           shift_negative_activation_correction=True,
                                           z_threshold=16)
    tpc = mct.get_target_platform_capabilities("pytorch", 'imx500', target_platform_version='v1')
    ptq_config = mct.core.CoreConfig(quantization_config=q_config)

    #<5> Quantize the floating-point PyTorch model to the 8-bit integer PyTorch model.
    quantized_model, quantization_info = mct.gptq.pytorch_gradient_post_training_quantization(model=model,
        representative_data_gen=representative_data_gen,
        core_config=ptq_config,
        target_platform_capabilities=tpc)

    #<6> Save the integer model as an ONNX model.
    mct.exporter.pytorch_export_model(model=quantized_model,
                                      save_model_path=args.quantized_model_path,
                                      repr_dataset=representative_data_gen)
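Following the tip in the flow above, the exported ONNX file can be checked on the PC before moving to the Raspberry Pi. This is a sketch, assuming mct_quantizers (normally installed as a dependency of MCT) and onnxruntime are available and that the model has a single input; run_quantized_onnx is a hypothetical helper name.

```python
import numpy as np

def run_quantized_onnx(model_path, images):
    """Run the MCT-exported ONNX model on a float32 batch shaped [N, 3, 224, 224]."""
    # Imported lazily so that this helper can be defined without the packages installed.
    import mct_quantizers as mctq
    import onnxruntime as ort

    # The exported model contains MCT's custom quantizer ops, so the session options
    # from mct_quantizers are needed for ONNX Runtime to load it.
    session = ort.InferenceSession(model_path, mctq.get_ort_session_options(),
                                   providers=["CPUExecutionProvider"])
    input_name = session.get_inputs()[0].name
    outputs = session.run(None, {input_name: np.asarray(images, dtype=np.float32)})
    # Map each output name to its array so the output order does not have to be guessed.
    return {meta.name: out for meta, out in zip(session.get_outputs(), outputs)}
```

For example, run_quantized_onnx("separated_mobilenet_quantized.onnx", batch) with a batch from the calibration loader returns the probabilities and the feature map, which can then be compared against the PyTorch outputs.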

Raspberry Pi Setup

Connect the Raspberry Pi AI Camera to the Raspberry Pi and install the necessary software.
Following Picamera2's recommendations, we'll install Python libraries without using a virtual environment.

Connecting the Raspberry Pi AI Camera

Lift the slider on the Raspberry Pi socket, insert the Raspberry Pi AI Camera's flexible flat cable into the socket, then push the slider back.

The images below show the Raspberry Pi 4 Model B.

connecting_pi_ai_camera.jpg

Software Installation

Verify that /home/username/.local/bin is in your PATH. If not, add the following to your .bashrc file to include it:

export PATH=$PATH:/home/username/.local/bin
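The same check can be done from Python with only the standard library. This sketch assumes that ~/.local/bin resolves to /home/username/.local/bin as in the export line above; local_bin_on_path is a hypothetical helper.

```python
import os
from pathlib import Path

def local_bin_on_path(path_value=None):
    """Return True if ~/.local/bin is one of the entries in PATH."""
    if path_value is None:
        path_value = os.environ.get("PATH", "")
    local_bin = Path.home() / ".local" / "bin"
    entries = [Path(p).expanduser() for p in path_value.split(os.pathsep) if p]
    return local_bin in entries

print("~/.local/bin on PATH:", local_bin_on_path())
```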

Copying Files to Raspberry Pi

Copy the Separated DETR Verification Folder entirely to the Raspberry Pi.
We'll use the created model along with some modified code on the Raspberry Pi.

From this point on, we'll be operating on the Raspberry Pi.

Executing Converter and Packager on the Quantized Model

Execute the Converter on the quantized model.

imxconv-pt -i separated_mobilenet_quantized.onnx -o ./converted

The output file (packerOut.zip) will be saved in the converted output folder.
For more details, please refer to Executing model conversion.

Note that an error will occur if the output folder already exists.
If the folder already exists, rename or delete it before proceeding.
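If you rerun the conversion often, renaming the previous output folder can be scripted. The helper below is a hypothetical sketch that moves an existing folder aside with a timestamp suffix instead of deleting it, so earlier results are kept.

```python
import shutil
import time
from pathlib import Path

def move_aside_existing_dir(path):
    """If `path` exists, rename it to `<name>.bak-<timestamp>` so the tool can create it fresh.

    Returns the backup path, or None if nothing had to be moved.
    """
    target = Path(path)
    if not target.exists():
        return None
    backup = target.with_name(f"{target.name}.bak-{time.strftime('%Y%m%d-%H%M%S')}")
    shutil.move(str(target), str(backup))
    return backup
```

Calling move_aside_existing_dir("./converted") just before running imxconv-pt -i separated_mobilenet_quantized.onnx -o ./converted avoids the error; the same applies to ./packaged if the Packager is rerun.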

Next, execute the Packager on the output file from the Converter.

imx500-package -i ./converted/packerOut.zip -o ./packaged

The result file (network.rpk) will be saved in the packaged output folder.
For more details, please refer to How to use IMX500 Packager.

Implementation of Verification Code

Here's an overview of the code:

  • It retrieves the Output tensor from the Raspberry Pi AI Camera. If the probability of class classification is high for any class, it executes object detection using the Transformer and saves an image with the detection results drawn on it.
  • The program terminates after executing object detection with the Transformer a specified number of times.
  • The Output tensor is retrieved using a callback function.

Processing Flow of the Verification Code

  1. Create the Transformer model network and load its weights.
  2. Load the packaged AI model (Backbone with Classifier) and deploy it to the Raspberry Pi AI Camera.
  3. Retrieve probabilities and features from the output tensor obtained by the callback function.
  4. Evaluate the probabilities of class classification. If any are high, perform object detection using the Transformer.
  5. Save the resulting image with bounding boxes drawn.
  6. If the number of object detection executions has not reached the specified count, return to step 3.

Separately from this flow, the callback function update_output_tensor stores the output tensor and a copy of the camera image (main stream) received from the Raspberry Pi AI Camera in global variables.
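The decision logic of steps 3 to 6 can be exercised without a camera. The sketch below replays pre-recorded (probabilities, features) pairs; should_run_transformer, run_detection_loop and detect_fn are hypothetical stand-ins, with detect_fn taking the place of the Transformer call. Like the verification script, it ignores index 0 of the 91 probabilities.

```python
import numpy as np

NUM_CLASSES = 91  # size of the Classifier output; the script checks indices 1..90

def should_run_transformer(probabilities, threshold=0.7):
    """Return True if any class in indices 1..90 exceeds the threshold."""
    probs = np.asarray(probabilities).reshape(-1)
    return bool(np.amax(probs[1:NUM_CLASSES]) > threshold)

def run_detection_loop(frames, threshold, num_executions, detect_fn):
    """Camera-free version of steps 3 to 6: stop after num_executions successful detections."""
    count = 0
    for probabilities, features in frames:
        if count >= num_executions:
            break
        if should_run_transformer(probabilities, threshold):
            scores = detect_fn(features.reshape(1, 256, 7, 7))
            if len(scores) > 0:  # only frames with at least one detection are counted
                count += 1
    return count
```

Running the Transformer only when the Classifier is confident keeps the Raspberry Pi CPU idle for frames that probably contain no object.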

validate_with_ai_camera.py

Place this new file directly under the copied Separated DETR verification folder.

import argparse
import time
from pathlib import Path

import cv2
import numpy as np
import torch

from picamera2 import CompletedRequest, MappedArray, Picamera2
from picamera2.devices import IMX500

from models import build_model
from main import get_args_parser
from detect import detect, draw_boxes

# Updated by the callback function for every completed camera request.
output_tensor = None  # list of output arrays from the IMX500
overlay = None        # copy of the current "main" stream image

def update_output_tensor(request: CompletedRequest):
    """Picamera2 pre-callback: store the latest output tensors and camera image."""
    global output_tensor
    global overlay

    np_outputs = imx500.get_outputs(request.get_metadata())
    if np_outputs is None:
        return

    output_tensor = np_outputs

    with MappedArray(request, stream="main") as m:
        overlay = m.array.copy()


def get_args():
    parser = argparse.ArgumentParser('DETR verification with the Raspberry Pi AI Camera', parents=[get_args_parser()])

    parser.add_argument("--model", type=str, help="Path of the model",
                        default="./packaged/network.rpk")
    parser.add_argument("--fps", type=int, default=1, help="Frames per second")
    parser.add_argument('--result_image_path', type=str, default='./images', help='The folder to save result images in')
    parser.add_argument('--transformer_path', type=str, default='transformer.pth', help='The path to the transformer model')
    parser.add_argument('--num_executions', default=10, type=int, help='Number of executions')
    parser.add_argument('--probability_thres', default=0.7, type=float, help='Probability threshold value to execute transformer')
    return parser.parse_args()


if __name__ == "__main__":
    args = get_args()

    if args.result_image_path:
        Path(args.result_image_path).mkdir(parents=True, exist_ok=True)

    #<1> Create the transformer model and load weights for the model.
    device = torch.device(args.device)

    model, criterion, postprocessors = build_model(args)
    model.to(device)
    checkpoint = torch.load(args.transformer_path, map_location='cpu')
    model.load_state_dict(checkpoint)
    model.eval()

    #<2> Load a packaged model and deploy it to the Raspberry Pi AI Camera.
    imx500 = IMX500(args.model)

    picam2 = Picamera2(imx500.camera_num)
    config = picam2.create_preview_configuration(controls={"FrameRate": args.fps}, buffer_count=3)
    imx500.show_network_fw_progress_bar()

    #<3> Start Raspberry Pi AI Camera operation and set callback function.
    picam2.start(config, show_preview=True)
    # Register the callback function that stores the output tensor and the camera image.
    picam2.pre_callback = update_output_tensor

    # Wait for the first output tensor from the Raspberry Pi AI Camera.
    while output_tensor is None:
        time.sleep(0.1)

    execution_count = 0
    while execution_count < args.num_executions:

        #<4> Obtain classification probabilities and features from the output tensor.
        #    The output tensor is updated by the callback function.
        probabilities = output_tensor[0]
        features = output_tensor[1]

        if np.amax(probabilities[1:91]) > args.probability_thres:
            #<5> Execute transformer to detect objects.
            features = features.reshape(1, 256, 7, 7)
            # Tensor.to() is not in-place, so keep the returned tensor.
            features = torch.from_numpy(features.astype(np.float32)).to(device)
            size = imx500.get_input_size()  # Get the size of the input tensor.
            scores, boxes = detect(features, model, device, size)

            if scores.shape[0] > 0:
                #<6> Draw bounding boxes on the camera image and save it.
                im = np.array(overlay, dtype=np.uint8)
                im = cv2.cvtColor(im, cv2.COLOR_RGB2BGR)
                mat_img = draw_boxes(im, scores, boxes, args.result_image_path)
                cv2.imwrite(f"{args.result_image_path}/draw_{execution_count:04d}.jpg", mat_img)

                print('Detect')
                execution_count += 1

detect.py

We will modify the draw_boxes function in the detect.py file within the Separated DETR verification folder.
The input tensor size is 224 x 224, but since the image obtained on the Raspberry Pi is a 640 x 480 image, we will adjust the coordinates of the bounding boxes to be drawn.

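def draw_boxes(image, scores, boxes, image_path):
    """Draw boxes predicted on the 224 x 224 input tensor onto the camera image.

    image_path is unused; it is kept so that existing callers continue to work.
    """
    input_w, input_h = 224, 224          # size of the input tensor
    image_h, image_w = image.shape[:2]   # 480 x 640 on the Raspberry Pi
    sx = image_w / input_w
    sy = image_h / input_h

    for score, box in zip(scores, boxes.tolist()):
        label = '{} : {:.3f}'.format(torch.argmax(score).item(), torch.amax(score).item())
        x0, y0 = int(box[0] * sx), int(box[1] * sy)
        x1, y1 = int(box[2] * sx), int(box[3] * sy)
        # Bounding box in red.
        cv2.rectangle(image, (x0, y0), (x1, y1), (0, 0, 255), 1)
        # Filled yellow label background, about 9 pixels per character.
        cv2.rectangle(image, (x0, y0), (x0 + 9 * len(label), y0 + 14), (0, 255, 255), thickness=-1)
        cv2.putText(image, label, (x0 + 2, y0 + 12), cv2.FONT_HERSHEY_PLAIN, 1, (0, 0, 0), 1, cv2.LINE_AA)

    # The boxes are drawn in place, so the annotated image can be returned directly.
    return image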
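The coordinate mapping used for drawing is plain arithmetic. This hypothetical helper isolates it, assuming (as draw_boxes does) that the whole 640 x 480 frame maps onto the 224 x 224 input, so x and y get different scale factors (640/224 and 480/224). It multiplies before dividing, as the original code does, so exact pixel values do not suffer from rounding of the ratio.

```python
def scale_box(box, src_size=(224, 224), dst_size=(640, 480)):
    """Scale an (x0, y0, x1, y1) box from src_size=(w, h) to dst_size=(w, h) pixel coordinates."""
    src_w, src_h = src_size
    dst_w, dst_h = dst_size
    x0, y0, x1, y1 = box
    return (int(x0 * dst_w / src_w), int(y0 * dst_h / src_h),
            int(x1 * dst_w / src_w), int(y1 * dst_h / src_h))

print(scale_box((0, 0, 224, 224)))     # (0, 0, 640, 480)
print(scale_box((56, 112, 112, 168)))  # (160, 240, 320, 360)
```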

Regarding the Verification of Output Tensor Array

The layout of the output tensor arrays received from the Raspberry Pi AI Camera may differ from the layout defined in the original PyTorch or Keras model.
The validate_with_ai_camera.py script above was written after checking this layout, so it should generally work as is.
However, if time allows, it's recommended to check the layout yourself just to be sure.

You can check the layout in the dnnParams.xml file located in the converted folder.
In this case, the content should be as follows:

<?xml version="1.0" encoding="UTF-8" standalone="yes"?>
<dnnParams>
    <networks>
        <network name="separated_moblienet_quantized" ordinal="0" type="">
            <inputTensors>
                <inputTensor persistency="1" ordinal="0" name="Placeholder.input.uid1:0" l2Offset="3900192" numOfDimensions="3" bitsPerElement="8" shift="0" scale="0.00390625" format="unsigned">
                    <dimensions>
                        <dimension size="3" serializationOrder="2" ordinal="0" padding="0"/>
                        <dimension size="224" serializationOrder="1" ordinal="1" padding="0"/>
                        <dimension size="224" serializationOrder="0" ordinal="2" padding="0"/>
                    </dimensions>
                </inputTensor>
            </inputTensors>
            <outputTensors>
                <outputTensor ordinal="1" name="transform-12-0-/resize/layer/Conv:0" l2Offset="4054816" numOfDimensions="3" bitsPerElement="8" shift="0" scale="0.0625" format="signed">
                    <dimensions>
                        <dimension size="256" serializationOrder="0" ordinal="0" padding="0"/>
                        <dimension size="7" serializationOrder="1" ordinal="1" padding="0"/>
                        <dimension size="7" serializationOrder="2" ordinal="2" padding="0"/>
                    </dimensions>
                </outputTensor>
                <outputTensor ordinal="0" name="transform-2-/backbone_classifier_1/layer/Gemm:0" l2Offset="4073248" numOfDimensions="1" bitsPerElement="8" shift="0" scale="0.00390625" format="unsigned">
                    <dimensions>
                        <dimension size="91" serializationOrder="0" ordinal="0" padding="0"/>
                    </dimensions>
                </outputTensor>
            </outputTensors>
        </network>
    </networks>
    <l2memory totalSize="8388480" coefficientsSize="2814048" reservedMemorySize="1024" networksRuntimeSize="1674240"/>
</dnnParams>

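From the ordinal and name attributes in the following lines of dnnParams.xml, we can determine that the Classifier output (probabilities) has ordinal 0 and the Convolution output (features) of the backbone with Classifier has ordinal 1, even though the features are listed first in the file. Therefore, output_tensor[0] holds the probabilities and output_tensor[1] holds the features.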

<outputTensor ordinal="1" name="transform-12-0-/resize/layer/Conv:0">
<outputTensor ordinal="0" name="transform-2-/backbone_classifier_1/layer/Gemm:0">

Additionally, the lines below show that the feature array keeps its original dimension order (256, 7, 7), because each dimension's serializationOrder equals its ordinal.

<dimension size="256" serializationOrder="0" ordinal="0" padding="0"/>
<dimension size="7" serializationOrder="1" ordinal="1" padding="0"/>
<dimension size="7" serializationOrder="2" ordinal="2" padding="0"/>

If the dnnParams.xml differs from the content shown above, please modify the following code in validate_with_ai_camera.py.

probabilities = output_tensor[0]
features = output_tensor[1]
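Instead of reading dnnParams.xml by eye, the relevant attributes can be extracted with the standard library. This is a sketch; describe_output_tensors is a hypothetical helper, and it only looks at the outputTensor, dimension, ordinal, size and serializationOrder elements and attributes shown above.

```python
import xml.etree.ElementTree as ET

def describe_output_tensors(root):
    """List output tensors of a parsed dnnParams.xml, sorted by ordinal.

    Each entry is (ordinal, name, shape, same_order), where same_order is True when
    every dimension's serializationOrder equals its ordinal.
    """
    result = []
    for tensor in root.iter("outputTensor"):
        dims = sorted(tensor.iter("dimension"), key=lambda d: int(d.get("ordinal")))
        shape = tuple(int(d.get("size")) for d in dims)
        same_order = all(d.get("serializationOrder") == d.get("ordinal") for d in dims)
        result.append((int(tensor.get("ordinal")), tensor.get("name"), shape, same_order))
    return sorted(result)
```

For example, describe_output_tensors(ET.parse("./converted/dnnParams.xml").getroot()) (adjust the path to where the file is in your converted folder) should list the (91,) probabilities at ordinal 0 and the (256, 7, 7) features at ordinal 1, both with same_order set to True.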

Executing Inference

When you run validate_with_ai_camera.py, the resulting images with bounding boxes, class IDs, and probabilities will be saved in the folder specified by the --result_image_path option.
Set the --device option to 'cpu'.

If you have modified Transformer parameters such as dim_feedforward or hidden_dim for experiments, please set these values using options.

python validate_with_ai_camera.py --device cpu

Compared with the results on the PC, the detection accuracy appears to be slightly lower, but overall the system works as expected.
Although the test images mostly contain large objects, I was impressed by how robust the detection is to disturbances such as blur and noise.

Detection results (COCO images):

  • Correct detection 1: 000000002280.jpg
  • Correct detection 2: 000000007713.jpg
  • Missed detection: 000000024877.jpg
  • False detection: 000000003623.jpg

In consideration of copyright, the COCO images are presented in a concealed manner.

Closing Remarks

I sincerely appreciate your patience in following along with this lengthy article.
I hope that this journey has provided you with some valuable insights, even if just a small amount.

When you need assistance

If you encounter any issues while navigating through this article, please feel free to leave a comment or check the Raspberry Pi Forum for additional support.
I appreciate your understanding that it may take some time to respond to comments.

Additionally, if you have any issues regarding AITRIOS that are not covered in this article, please contact us via the link below:
