TotchScriptとは
TorchScriptは、Pythonコードからモデルを段階的に移行するためのツールを提供し、スタンドアロンの C++ プログラムなど、Python から独立して実行できる TorchScript プログラムへ置き換えることができる。つまり、使い慣れたPythonツールを使用して PyTorch でモデルを作成し、推論用にC++へエクスポートすることができ、高速化が図れる。
単純モデルのエクスポート
Pytorchで学習したモデルを元に、torch.jit.trace関数を用いてTorchScriptプログラムへ変換する。
※ ジャストインタイム(JIT)コンパイラー・・・実行時にコードをコンパイルするコンパイラの一種で、実行時に必要な部分のコードを即座にコンパイルして実行する。プラットフォーム(Windows、macOS、LinuxなどのOSやハードウェア)に依存しないソースコードや中間コードの状態でソフトウェアを配布できる。
ここに簡単な例を示します:
import torch
def my_function(a, b):
return a * b
traced_function = torch.jit.trace(my_function, (torch.randn(3), torch.randn(3)))
この例では、2つのテンソルを入力として、その要素ごとの積を出力とする簡単な関数「my_function」を定義しました。次に、「torch.jit.trace」を使ってこの関数をトレースし、結果を新しい関数「traced_function」に格納しました。
「torch.jit.trace」を呼び出す際には、トレースするための関数の例の入力を提供する必要があります。この例では、形状が「(3,)」のランダムなテンソルを入力として使用しました。
このようにトレースされた関数は、通常のPython関数と同じように呼び出すことができます。JITコンパイラーは実行時に計算を最適化することで、元のPython関数よりも優れたパフォーマンスが得られます。
画像処理モデルのエクスポート
transposeメソッドとunsqueezeメソッドで、HWC(高さ、幅、チャネル)形式のフレームをCHW(チャネル、高さ、幅)形式のテンソルに変換している。
while True:
ret, frame = cap.read()
if ret == False:
break
result = model(frame)
frame = frame.transpose((2, 0, 1))
tensor = torch.from_numpy(frame)
tensor = tensor.unsqueeze(0)
example_input = torch.randn(1, 3, 224, 224)
model.eval()
input_tensor = torch.zeros((1, 3, 640, 640), dtype=torch.float32)
traced_from_pytorch_model = torch.jit.trace(model, tensor)
traced_from_pytorch_model.save("traced_from_pytorch_model.pt")
物体検知の代表格であるYOLOのPythonライブラリでは、Torchscript用のモデルをエクスポートするメゾッドが存在する。
from ultralytics import YOLO
# Load a model
model = YOLO("yolov8n.pt") # load an official model
# Export the model
model.export(format="torchscript")
Visual Studioの設定
下記の記事を参照して、Visual Studio C++の拡張機能LibTorchを使用してビルドするための設定をする(割とややこしい)。
コード全文
ビルド後の実行ファイル(.exe)に引数として、torchscript形式で保存したPyTorchのモデルと解析対象の動画を渡すようにしている。
#if _DEBUG
#pragma comment(lib, "opencv_world470d.lib")
#else
#pragma comment(lib, "opencv_world470.lib")
#endif
#include <torch/torch.h>
#include <torch/script.h>
#include <iostream>
#include <thread>
#include <memory>
using namespace std;
#include <opencv2/opencv.hpp>
using namespace cv;
#include "opencv2/highgui/highgui.hpp" //動画を表示する際に必要
constexpr int TIME_TO_SLEEP = 1000;
int main(int argc, const char* argv[]) {
if (argc != 3)
{
cout << "" << endl;
cout << " Error" << endl;
cout << "" << endl;
cout << " How to use : cout.exe torchscript sample.mp4" << endl;
cout << "" << endl;
cout << " argv:" << endl;
cout << " cout.exe : this program name" << endl;
cout << " torchscript : torchscript file name" << endl;
cout << " sample.mp4 : input movie file path" << endl;
cout << "" << endl;
std::this_thread::sleep_for(std::chrono::milliseconds(TIME_TO_SLEEP));
return 0;
}
std::string TorchScriptFile = argv[1];
std::string VideoFile = argv[2];
//VideoCapture cap(0);
VideoCapture cap;
cap.open(VideoFile);
if (!cap.isOpened())
{
std::cerr << "cannot open this movie file!" << std::endl;
std::this_thread::sleep_for(std::chrono::milliseconds(TIME_TO_SLEEP));
return -1;
}
int fourcc, width, height;
double fps;
width = (int)cap.get(cv::CAP_PROP_FRAME_WIDTH); // フレーム横幅を取得
height = (int)cap.get(cv::CAP_PROP_FRAME_HEIGHT); // フレーム縦幅を取得
fps = cap.get(cv::CAP_PROP_FPS); // フレームレートを取得
double max_frame = cap.get(CAP_PROP_FRAME_COUNT);
fourcc = cv::VideoWriter::fourcc('m', 'p', '4', 'v'); // AVI形式を指定
VideoWriter videoWriter;
videoWriter.open("output.mp4", fourcc, fps*1.2, cv::Size(width, height), true);
// TorchScript
torch::Tensor tensor = torch::rand({ 2, 3 });
if (torch::cuda::is_available()) {
std::cout << "CUDA is available! Training on GPU" << std::endl;
auto tensor_cuda = tensor.cuda();
std::cout << tensor_cuda << std::endl;
}
else
{
std::cout << "CUDA is not available! Training on CPU" << std::endl;
std::cout << tensor << std::endl;
}
if (argc != 3) {
std::cerr << "usage: example-app <path-to-exported-script-module>\n";
return -1;
}
torch::jit::script::Module module;
std::cout << argv[1] << std::endl;
try {
// Deserialize the ScriptModule from a file using torch::jit::load().
module = torch::jit::load(TorchScriptFile);
}
catch (const c10::Error& e) {
std::cerr << "error loading the model\n";
return -1;
}
cv::Mat frame;
int index = 0;
//for (int i = 0; i < int(max_frame); i++)
while (cap.read(frame)) {
cap >> frame;
if (frame.empty() == true) {
continue;
}
int n = (int)cap.get(cv::CAP_PROP_POS_FRAMES);
cout << n << endl;
index++;
// Load image
cv::cvtColor(frame, frame, cv::COLOR_BGR2RGB);
cv::resize(frame, frame, cv::Size(640, 640));
// Convert the image from OpenCV format to a Torch tensor
torch::Tensor tensor_image = torch::from_blob(frame.data, { 1, frame.rows, frame.cols, 3}, torch::kByte);
//並び替え [C, H, W, Z] to [C, H, W, Z],
tensor_image = tensor_image.permute({ 0, 3, 1, 2 });
tensor_image = tensor_image.toType(torch::kFloat);
tensor_image = tensor_image.div(255);
std::cout << "The tensor image: " << tensor_image.sizes() << std::endl;
// Use the TorchScript model to make a prediction on the input
std::vector<torch::jit::IValue> inputs;
inputs.push_back(tensor_image);
at::Tensor output = module.forward({ tensor_image }).toTensor();
std::cout << "The tensor output: " << output.sizes() << std::endl;
// Process output tensor
auto detection = output[0][0];
auto num_detections = detection.size(0);
std::cout << "Number of detections: " << num_detections << std::endl;
for (int i = 0; i < num_detections; i++) {
cout << detection[i] << endl;
auto item = detection[i];
auto class_id = item[5].item<int>();
auto score = item[4].item<float>();
auto bbox = item.slice(0, 0, 4).tolist();
std::vector<float> bbox(item.slice(0, 0, 4).to(at::kCPU).data<float>(), item.slice(0, 0, 4).to(at::kCPU).data<float>() + item.slice(0, 0, 4).numel());
std::cout << "Class ID: " << class_id << ", Score: " << score << ", BBox: " << bbox[0] << "," << bbox[1] << "," << bbox[2] << "," << bbox[3] << std::endl;
std::cout << "Class ID: " << class_id << ", Score: " << score << std::endl;
}
cv::imshow("再生中", frame);
videoWriter << frame;
const int key = cv::waitKey(1);
if (key == 'q')
{
break;
}
}
cap.release();
videoWriter.release();
cv::destroyAllWindows();
std::cout << "END!\n";
std::this_thread::sleep_for(std::chrono::milliseconds(TIME_TO_SLEEP));
return 0;
}
まとめ
TotchScriptによるPytorchモデルのC++移管方法を紹介した。