LoginSignup
6
7

More than 3 years have passed since last update.

pytorchで学習済みのモデルをROS C++(libtorch)で使う

Last updated at Posted at 2019-08-25

はじめに

pytorchで学習したモデルをROSで使うとき、

  • モデル学習時の環境がpython3だとpython2環境のROSでの実装が面倒
  • pythonだとどうしても推論の速度が出ない

などで困っていたので、libtorchを使ってモデルをC++で動かしてみたときのメモです。
(ROS2だったらpython3の環境でも使えたけどやっぱり遅い...)

環境

Ubuntu18.04, Ubuntu16.04
ROS Melodic, ROS Kinetic

libtorchのビルド

libtorchはこちらに書いてあるようにダウンロードすることが可能ですが、2019/8/3現在、旧型のABIでビルドされています( https://discuss.pytorch.org/t/sos-libtorch-conflict-with-ros/48977/6 )。

これを用いると、ROSの新型ABIとの競合ができないのでソースからビルドする必要がありました。

ビルドは以下のようにしたらうまくいきました。(https://github.com/pytorch/pytorch/blob/master/docs/libtorch.rst )

$ git clone https://github.com/pytorch/pytorch
$ cd pytorch
$ git submodule update --init --recursive
$ python setup.py build
$ mkdir build_libtorch && cd build_libtorch
$ python ../tools/build_libtorch.py

学習済みモデルの変換

学習済みモデルの変換には、tracingとannotationによる方法がありますが、ここではtracingによる手法の実装例を紹介します。
annotationによる手法はこちらを参照してください。

# 実装例
import torch

model = Model() # nn.Module
model_path = "weight.pkl" # 学習済みの重み
model.load_state_dict(torch.load(model_path))
expample = torch.rand(1, 3, 64, 64) # 入力のshapeに合わせる
traced_net = torch.jit.trace(model, example)
traced_net.save("model.pt")

ROSで学習済みモデルでの推論

rosのパッケージのCMakeList.txtの書き方の例はこんな感じです。
find_package(Torch REQUIRED) の前にcaffe2とtorchのDIRを設定したらmake通りました。

# CMakeLists.txt
cmake_minimum_required(VERSIONS 2.8.3)
project(hoge)

add_compile_options(-std=c++11)

find_package(catkin REQUIRED COMPONENTS
  roscpp
  geometry_msgs
)

set(Caffe2_DIR "$ENV{HOME}/pytorch/torch/share/cmake/Caffe2")
set(Torch_DIR "$ENV{HOME}/pytorch/torch/share/cmake/Torch")
find_package(Torch REQUIRED)

include_directories(
  ${catkin_INCLUDE_DIRS}
)
add_executable(hoge src/hoge.cpp)
target_link_libraries(hope ${catkin_LIBRARIES} ${TORCH_LIBRARIES))

プログラム自体は、単純に先ほど保存したモデルをloadして、forwardにinputを与えるだけで推論してくれます。
C++でのtorchの実装はこちらを参考にしました。

// hoge.cpp
#include "ros/ros.h"
#include <geometry_msgs/Twist.h>
#include <torch/script.h>

class Hoge{
public:
  Hoge();
  void process();
private:
  ros::NodeHnadle nh;
  ros::Publisher vel_pub;
  torch::jit::script::Module module;
};

Hoge::Hoge(){
  vel_pub = nh.advertise<geometry_msgs::Twist>("/cmd_vel", 1);
  module = torch::jit::load("model.pt");
}

void Hoge::process(){
  ros::Rate loop_rate(10);
  while(ros::ok()){
    torch::Tensor input = torch::ones({1, 3, 64, 64}));
    auto output = module.foward({input}).toTensor();
    vel.linear.x = output[0].item<float>();
    vel.angular.z = output[1].item<float>();
    vel_pub.publish(vel);
    loop_rate.sleep();
    ros::spinOnce();
  }
}

int main(int args, char **argv){
  ros::init(args, args, "hoge");
  Hoge hoge;
  hoge.process();
  return 0;
}

参考URL

PYTORCH C++ API
LOADING A PYTORCH MODEL IN C++
PyTorchで学習したモデルをC++から使う @cashiwamochi
pythonで学習したDNNモデルをC++から利用する(PyTorch & libtorch版) @nmatsui

6
7
1

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
6
7