はじめに
pytorchで学習したモデルをROSで使うとき、
- モデル学習時の環境がpython3だとpython2環境のROSでの実装が面倒
- pythonだとどうしても推論の速度が出ない
などで困っていたので、libtorchを使ってモデルをC++で動かしてみたときのメモです。
(ROS2だったらpython3の環境でも使えたけどやっぱり遅い...)
環境
Ubuntu18.04, Ubuntu16.04
ROS Melodic, ROS Kinetic
libtorchのビルド
libtorchはこちらに書いてあるようにダウンロードすることが可能ですが、2019/8/3現在、旧型のABIでビルドされています( https://discuss.pytorch.org/t/sos-libtorch-conflict-with-ros/48977/6 )。
これを用いると、ROSの新型ABIとの競合ができないのでソースからビルドする必要がありました。
ビルドは以下のようにしたらうまくいきました。(https://github.com/pytorch/pytorch/blob/master/docs/libtorch.rst )
$ git clone https://github.com/pytorch/pytorch
$ cd pytorch
$ git submodule update --init --recursive
$ python setup.py build
$ mkdir build_libtorch && cd build_libtorch
$ python ../tools/build_libtorch.py
学習済みモデルの変換
学習済みモデルの変換には、tracingとannotationによる方法がありますが、ここではtracingによる手法の実装例を紹介します。
annotationによる手法はこちらを参照してください。
# 実装例
import torch
model = Model() # nn.Module
model_path = "weight.pkl" # 学習済みの重み
model.load_state_dict(torch.load(model_path))
expample = torch.rand(1, 3, 64, 64) # 入力のshapeに合わせる
traced_net = torch.jit.trace(model, example)
traced_net.save("model.pt")
ROSで学習済みモデルでの推論
rosのパッケージのCMakeList.txtの書き方の例はこんな感じです。
find_package(Torch REQUIRED)
の前にcaffe2とtorchのDIRを設定したらmake通りました。
# CMakeLists.txt
cmake_minimum_required(VERSIONS 2.8.3)
project(hoge)
add_compile_options(-std=c++11)
find_package(catkin REQUIRED COMPONENTS
roscpp
geometry_msgs
)
set(Caffe2_DIR "$ENV{HOME}/pytorch/torch/share/cmake/Caffe2")
set(Torch_DIR "$ENV{HOME}/pytorch/torch/share/cmake/Torch")
find_package(Torch REQUIRED)
include_directories(
${catkin_INCLUDE_DIRS}
)
add_executable(hoge src/hoge.cpp)
target_link_libraries(hope ${catkin_LIBRARIES} ${TORCH_LIBRARIES))
プログラム自体は、単純に先ほど保存したモデルをloadして、forwardにinputを与えるだけで推論してくれます。
C++でのtorchの実装はこちらを参考にしました。
// hoge.cpp
#include "ros/ros.h"
#include <geometry_msgs/Twist.h>
#include <torch/script.h>
class Hoge{
public:
Hoge();
void process();
private:
ros::NodeHnadle nh;
ros::Publisher vel_pub;
torch::jit::script::Module module;
};
Hoge::Hoge(){
vel_pub = nh.advertise<geometry_msgs::Twist>("/cmd_vel", 1);
module = torch::jit::load("model.pt");
}
void Hoge::process(){
ros::Rate loop_rate(10);
while(ros::ok()){
torch::Tensor input = torch::ones({1, 3, 64, 64}));
auto output = module.foward({input}).toTensor();
vel.linear.x = output[0].item<float>();
vel.angular.z = output[1].item<float>();
vel_pub.publish(vel);
loop_rate.sleep();
ros::spinOnce();
}
}
int main(int args, char **argv){
ros::init(args, args, "hoge");
Hoge hoge;
hoge.process();
return 0;
}
参考URL
PYTORCH C++ API
LOADING A PYTORCH MODEL IN C++
PyTorchで学習したモデルをC++から使う @cashiwamochi
pythonで学習したDNNモデルをC++から利用する(PyTorch & libtorch版) @nmatsui