0
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

C++版PyTorch(LibTorch)試してみた

Posted at

はじめに

PyTorchで学習したモデルに推論処理を行わせたいとき、外側の処理も含めた性能を重視するとC++で動かしたいと考えることがある。

ONNXを経由してIR形式に変換してOpenVINOで実行することでより高速になるが、ネットワークレイヤが対応していない場合など、うまく変換できない場面は多い。

そこで、PyTorchのC++拡張であるLibTorchの利用が検討に上がったため、実際に使って見た記録を残す。

環境構築

公式からLibTorchをダウンロードする。C++17を使うので、「Download here (cxx11 ABI)」の方のURLからzipファイルをダウンロードする。

image.png

展開するだけでインストール完了。ビルド済みの共有ライブラリとヘッダファイルが格納されている。

wget https://download.pytorch.org/libtorch/cpu/libtorch-cxx11-abi-shared-with-deps-2.1.2%2Bcpu.zip
unzip libtorch-cxx11-abi-shared-with-deps-2.1.2+cpu.zip 

基本の使い方

Pythonで学習したモデルをエクスポート

試しに、入力に対して重みを乗算するだけのTestModelを作成した。重みは本来は学習させるが、今回はコンストラクタで与えることにする。これをC++で動作させることを目指す。

sample.py
from torch import nn, Tensor

class TestModel(nn.Module):

    def __init__(self, weight: Tensor) -> None:
        super().__init__()
        self.weight = weight

    def forward(self, x: Tensor) -> Tensor:
        return x * self.weight

torch.jit.trace関数にモデルと適当な入力データを与え、traced_modelを生成し、save関数でモデルと重みをファイルとして保存する。これで、重みが固定化されたモデルのエクスポートが完了する。

sample.py
import torch

def main() -> None:
    weight = Tensor([1, 2, 3])
    model = TestModel(weight)
    model.eval()

    x = Tensor([10, 20, 30])
    res = model(x)
    print(res)  # tensor([10., 40., 90.])

    traced_model = torch.jit.trace(model, x)
    traced_model.save("TestModel.pt")

if __name__ == "__main__":
    main()

C++で学習したモデルをインポート

次に、C++側のコードを書く。必要なヘッダはtorch/torch.htorch.script.hの2つのみ。torch::jit::load関数でエクスポートしたモデルを読み込んだら、forward関数で推論を実行する。

main.cpp
#include <iostream>
#include <string>
#include <torch/torch.h>
#include <torch/script.h>


int main(int argc, char** argv) {
    
    std::string model_path(argv[1]);
    auto model = torch::jit::load(model_path);

    auto input = torch::tensor(torch::ArrayRef<float>({100, 200, 300}));
    std::cout << "input: " << std::endl << input << std::endl;

    auto output = model.forward({input}).toTensor();
    std::cout << "output: " << std::endl << output << std::endl;

    return EXIT_SUCCESS;
}

CMakeLists.txtは以下のようにした。

CMakeLists.txt
cmake_minimum_required(VERSION 3.0.0)
project(torch_test CXX)

find_package(Torch REQUIRED)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}")
message("TORCH_INCLUDE_DIRS: ${TORCH_INCLUDE_DIRS}")
message("TORCH_LIBRARIES: ${TORCH_LIBRARIES}")

add_executable(${PROJECT_NAME} main.cpp)

target_include_directories(
    ${PROJECT_NAME} PRIVATE
    ${TORCH_INCLUDE_DIRS}
)

target_link_directories(
    ${PROJECT_NAME} PRIVATE
)

target_link_libraries(
    ${PROJECT_NAME} PRIVATE
    ${TORCH_LIBRARIES}
)

target_compile_options(
    ${PROJECT_NAME} PRIVATE
    -O0
    -g
)

target_compile_features(
    ${PROJECT_NAME} PRIVATE
    cxx_std_17
)

ビルド&実行結果。

mkdir build && cd build
cmake -DTorch_DIR=/home/vboxuser/develop/pytorch/libtorch/share/cmake/Torch ..
make -j4
./torch_test ./TestModel.pt
# 以下実行結果
input: 
 100
 200
 300
[ CPUFloatType{3} ]
output: 
 100
 400
 900
[ CPUFloatType{3} ]

おまけ

OpenVINOでは明らかに変換できないカスタムレイヤでもC++で動作可能か確かめて見る。ここではforward内でC++の関数を呼ぶレイヤを作成してみた。

pow.cpp
#include <boost/python.hpp>

float mypow(const float x, const float y) {
    return std::pow(x, y);
}

BOOST_PYTHON_MODULE(libpow) {
    boost::python::def("pow", &mypow);
}

引数xyを受け取り、x^yを計算して返す関数powを定義した。

sample.py
import libpow

・・・()・・・

    def forward(self, x: Tensor) -> Tensor:
            p = libpow.pow(2, 2)
            return x * self.weight * p

これをforward内で呼ぶ。実質結果が4倍になる関数になり、こちらも上記と同じ手順で動作した。

なお、実装によっては以下のような警告が出る場合がある。

 TracerWarning: torch.from_numpy results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.

この場合、メッセージに表示される変数が定数として登録されてしまう。それでも構わないなら問題ないが、入力の値によって本来変化すべき変数でこの警告が表示された場合、意図しない結果になる可能性があるので、警告が出た場合は注意して確認すること。なお、関数に@torch.jit.scriptデコレータをつけることで対策できるケースがあるらしい。

参考: https://tech-blog.optim.co.jp/entry/2020/08/17/090000

0
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?