LoginSignup
0
2

【C++】学習済みPyTorchモデルのC++(TotchScript)移管

Last updated at Posted at 2023-06-19

TotchScriptとは

 TorchScriptは、Pythonコードからモデルを段階的に移行するためのツールを提供し、スタンドアロンの C++ プログラムなど、Python から独立して実行できる TorchScript プログラムへ置き換えることができる。つまり、使い慣れたPythonツールを使用して PyTorch でモデルを作成し、推論用にC++へエクスポートすることができ、高速化が図れる。

単純モデルのエクスポート

Pytorchで学習したモデルを元に、torch.jit.trace関数を用いてTorchScriptプログラムへ変換する。
※ ジャストインタイム(JIT)コンパイラー・・・実行時にコードをコンパイルするコンパイラの一種で、実行時に必要な部分のコードを即座にコンパイルして実行する。プラットフォーム(Windows、macOS、LinuxなどのOSやハードウェア)に依存しないソースコードや中間コードの状態でソフトウェアを配布できる。

ここに簡単な例を示します:

import torch

def my_function(a, b):
    return a * b

traced_function = torch.jit.trace(my_function, (torch.randn(3), torch.randn(3)))

この例では、2つのテンソルを入力として、その要素ごとの積を出力とする簡単な関数「my_function」を定義しました。次に、「torch.jit.trace」を使ってこの関数をトレースし、結果を新しい関数「traced_function」に格納しました。

「torch.jit.trace」を呼び出す際には、トレースするための関数の例の入力を提供する必要があります。この例では、形状が「(3,)」のランダムなテンソルを入力として使用しました。

このようにトレースされた関数は、通常のPython関数と同じように呼び出すことができます。JITコンパイラーは実行時に計算を最適化することで、元のPython関数よりも優れたパフォーマンスが得られます。

画像処理モデルのエクスポート

transposeメソッドとunsqueezeメソッドで、HWC(高さ、幅、チャネル)形式のフレームをCHW(チャネル、高さ、幅)形式のテンソルに変換している。

while True:
    ret, frame = cap.read()
    if ret == False:
        break
    result = model(frame)
    frame = frame.transpose((2, 0, 1))
    tensor = torch.from_numpy(frame)
    tensor = tensor.unsqueeze(0)
    example_input = torch.randn(1, 3, 224, 224)
    model.eval()
    input_tensor = torch.zeros((1, 3, 640, 640), dtype=torch.float32)
    traced_from_pytorch_model = torch.jit.trace(model, tensor)
    traced_from_pytorch_model.save("traced_from_pytorch_model.pt")

物体検知の代表格であるYOLOのPythonライブラリでは、Torchscript用のモデルをエクスポートするメゾッドが存在する。

from ultralytics import YOLO

# Load a model
model = YOLO("yolov8n.pt")  # load an official model

# Export the model
model.export(format="torchscript")

Visual Studioの設定

下記の記事を参照して、Visual Studio C++の拡張機能LibTorchを使用してビルドするための設定をする(割とややこしい)。

コード全文

ビルド後の実行ファイル(.exe)に引数として、torchscript形式で保存したPyTorchのモデルと解析対象の動画を渡すようにしている。

C++
#if _DEBUG
#pragma comment(lib, "opencv_world470d.lib")
#else
#pragma comment(lib, "opencv_world470.lib")
#endif

#include <torch/torch.h>
#include <torch/script.h>
#include <iostream>
#include <thread>
#include <memory>
using namespace std;

#include <opencv2/opencv.hpp>
using namespace cv;

#include "opencv2/highgui/highgui.hpp" //動画を表示する際に必要

constexpr int TIME_TO_SLEEP = 1000;

int main(int argc, const char* argv[]) {

    if (argc != 3)
    {
        cout << "" << endl;
        cout << " Error" << endl;
        cout << "" << endl;
        cout << " How to use : cout.exe torchscript sample.mp4" << endl;
        cout << "" << endl;
        cout << " argv:" << endl;
        cout << " cout.exe      : this program name" << endl;
        cout << " torchscript   : torchscript file name" << endl;
        cout << " sample.mp4    : input movie file path" << endl;
        cout << "" << endl;
        std::this_thread::sleep_for(std::chrono::milliseconds(TIME_TO_SLEEP));
        return 0;
    }

    std::string TorchScriptFile = argv[1];
    std::string VideoFile = argv[2];

    //VideoCapture cap(0);
    VideoCapture cap;
    cap.open(VideoFile);

    if (!cap.isOpened())
    {
        std::cerr << "cannot open this movie file!" << std::endl;
        std::this_thread::sleep_for(std::chrono::milliseconds(TIME_TO_SLEEP));
        return -1;
    }

    int    fourcc, width, height;
    double fps;
    width = (int)cap.get(cv::CAP_PROP_FRAME_WIDTH);	// フレーム横幅を取得
    height = (int)cap.get(cv::CAP_PROP_FRAME_HEIGHT);	// フレーム縦幅を取得
    fps = cap.get(cv::CAP_PROP_FPS);					// フレームレートを取得
    double max_frame = cap.get(CAP_PROP_FRAME_COUNT);
    fourcc = cv::VideoWriter::fourcc('m', 'p', '4', 'v');	// AVI形式を指定
    VideoWriter videoWriter;
    videoWriter.open("output.mp4", fourcc, fps*1.2, cv::Size(width, height), true);

    // TorchScript
    torch::Tensor tensor = torch::rand({ 2, 3 });
    if (torch::cuda::is_available()) {
        std::cout << "CUDA is available! Training on GPU" << std::endl;
        auto tensor_cuda = tensor.cuda();
        std::cout << tensor_cuda << std::endl;
    }
    else
    {
        std::cout << "CUDA is not available! Training on CPU" << std::endl;
        std::cout << tensor << std::endl;
    }

    if (argc != 3) {
        std::cerr << "usage: example-app <path-to-exported-script-module>\n";
        return -1;
    }


    torch::jit::script::Module module;
    std::cout << argv[1] << std::endl;
    try {
        // Deserialize the ScriptModule from a file using torch::jit::load().
        module = torch::jit::load(TorchScriptFile);
    }
    catch (const c10::Error& e) {
        std::cerr << "error loading the model\n";
        return -1;
    }


    cv::Mat frame;
    int index = 0;
    //for (int i = 0; i < int(max_frame); i++)
    while (cap.read(frame)) {
        cap >> frame;

        if (frame.empty() == true) {
            continue;
        }

        int n = (int)cap.get(cv::CAP_PROP_POS_FRAMES);
        cout << n << endl;
        index++;
        
        // Load image
        cv::cvtColor(frame, frame, cv::COLOR_BGR2RGB);
        cv::resize(frame, frame, cv::Size(640, 640));

        // Convert the image from OpenCV format to a Torch tensor
        torch::Tensor tensor_image = torch::from_blob(frame.data, { 1, frame.rows, frame.cols, 3}, torch::kByte);
        //並び替え [C, H, W, Z] to [C, H, W, Z],
        tensor_image = tensor_image.permute({ 0, 3, 1, 2 });
        tensor_image = tensor_image.toType(torch::kFloat);
        tensor_image = tensor_image.div(255);
        std::cout << "The tensor image: " << tensor_image.sizes() << std::endl;

        // Use the TorchScript model to make a prediction on the input
        std::vector<torch::jit::IValue> inputs;
        inputs.push_back(tensor_image);
        at::Tensor output = module.forward({ tensor_image }).toTensor();
        std::cout << "The tensor output: " << output.sizes() << std::endl;

        // Process output tensor
        auto detection = output[0][0];
        auto num_detections = detection.size(0);
        std::cout << "Number of detections: " << num_detections << std::endl;

        for (int i = 0; i < num_detections; i++) {
            cout << detection[i] << endl;
            auto item = detection[i];
            auto class_id = item[5].item<int>();
            auto score = item[4].item<float>();
            auto bbox = item.slice(0, 0, 4).tolist();
            std::vector<float> bbox(item.slice(0, 0, 4).to(at::kCPU).data<float>(), item.slice(0, 0, 4).to(at::kCPU).data<float>() + item.slice(0, 0, 4).numel());
            std::cout << "Class ID: " << class_id << ", Score: " << score << ", BBox: " << bbox[0] << "," << bbox[1] << "," << bbox[2] << "," << bbox[3] << std::endl;
            std::cout << "Class ID: " << class_id << ", Score: " << score << std::endl;
        }

        cv::imshow("再生中", frame);
        videoWriter << frame;
        const int key = cv::waitKey(1);
        if (key == 'q')
        {
            break;
        }
    }

    cap.release();
    videoWriter.release();
    cv::destroyAllWindows();
    std::cout << "END!\n";

    std::this_thread::sleep_for(std::chrono::milliseconds(TIME_TO_SLEEP));
    return 0;

}

まとめ

TotchScriptによるPytorchモデルのC++移管方法を紹介した。

0
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
2