LoginSignup
6
5

More than 3 years have passed since last update.

LibTorch(PyTorch C++)APIメモ

Last updated at Posted at 2020-09-23

ほぼ自分用です。随時更新します。
LibTorch環境構築はこの記事を参照。

公式資料が有能なので、まずはそこをチェック。

float型vectorから1次元torch::Tensorへの変換

torch::Tensor tsr = torch::tensor(torch::ArrayRef<float>(vec));

torch::Tensor型vectorからtorch::Tensorへの変換

torch::Tensor tsr = torch::stack(tsr_vec);

1次元torch::Tensorからstd::vectorへの変換

torch::Tensor tsr = torch::randn({ 3 });
tsr= tsr.contiguous();
std::vector<float> v(tsr.data_ptr<float>(), tsr.data_ptr<float>() + tsr.numel());

1次元torch::Tensorのスライス

参考:https://pytorch.org/cppdocs/notes/tensor_indexing.html

namespace ti = torch::indexing;
torch::Tensor tsr2 = tsr.index({ ti::Slice(i, i + slice_size) });

1次元生配列からtorch::Tensorへの変換

// おそらく生配列でなくとも変換可能
// 先頭ポインタ、サイズ、型が分かっており、メモリ上に連続に配置されていればOK?
float raw_arr[] = {1.f, 2.f, 3.f};
torch::Tensor tsr = torch::from_blob(raw_arr/*先頭ポインタ*/, 3/*size*/, torch::TensorOptions().dtype(torch::kFloat32));

GPU使用可能かチェック

torch::DeviceType device_type;
if (torch::cuda::is_available()) {
    std::cout << "CUDA available! Running on GPU.\n";
    device_type = torch::kCUDA;
}
else {
    std::cout << "Running on CPU.\n";
    device_type = torch::kCPU;
}

CPU/GPUへ送る & 型変換

auto model = std::make_shared<Model>();
model->to(torch::kCUDA); // torch::Device型でなくても良い。
model->to(torch::kCPU);

tsr.to(torch::kFloat64);
tsr.to(torch::kCUDA);

tensorサイズ取得

size_t tsr_size = tsr.size(0 /*dim*/);
c10::IntArrayRef tsr_sizes = tsr.sizes();

torch::Tensorから組み込み型の値を取得

// tsrがスカラーでないとエラーを吐くので注意。
// 配列を変換したい場合、tsr.index({i})やtsr[i]で要素にアクセスしてから変換。
double d = tsr.item().toDouble();

その他

  • torch::Tensorは微分可能、at::Tensorは微分不可(らしい)
    • 両者に速度の違いはほとんどないため、基本的にtorch::Tensorを使う
  • torch_cuda.dllが見つからないエラー
    • 作成したexeと同じ場所に、cudaとtorch関連のDLLをコピーしてきたか?
  • カスタムデータローダーをインスタンス化できないエラー
6
5
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
6
5