最近PyTorchのC++版をいじることが多いのですが、過去にはM1 MacでPytorchのc++版、libtorchをインストールして実行するなどの記事を書きました。
この記事では、英語でもなかなか出てこない情報である、PyTorch C++版でのデータセットとデータローダの作り方について忘備録的な意味でまとめておきます。
バージョン
- PyTorch 2.0 C++ API
やりたいこと
インプットデータとアウトプットデータを自分で指定した形で取り出したいです。例えば、
- インプット:
std::vector<torch::Tensor> - アウトプット:
std::vector<torch::Tensor>
のようなデータを取り出すことを考えます。インプットが複数、アウトプットが複数、という状況に対応できれば、多くのことに対応できるでしょう。
カスタムデータセット
例えば、
namespace Dataloader{
using inputtensors = std::vector<torch::Tensor>;
using outputtensors = std::vector<torch::Tensor>;
using Example = torch::data::Example<inputtensors,outputtensors>;
class SDataset : public torch::data::Dataset<SDataset,Example>
{
public:
explicit SDataset(const size_t & numofdata){
this -> numofdata = numofdata;
}
Example get(size_t index);
torch::optional<size_t> size() const {
return numofdata;
}
private:
size_t numofdata;
};
Example SDataset::get(size_t index){
std::vector<torch::Tensor> inputdata(2);
std::vector<torch::Tensor> outputdata(2);
inputdata.at(0) = at::linspace((double) index, (double)index+0.3, 4);
inputdata.at(1) = at::linspace((double) index, (double)index+3.0, 2);
outputdata.at(0) = torch::randn({4,2});
outputdata.at(1) = torch::randn({3,2});
return {inputdata,outputdata};
}
}
のようにします。クラスSDatasetはtorch::data::Dataset<SDataset,Example>を継承していますが、これによってPyTorchのデータセットとして扱うことができるようになります。なお、このExampleの部分がデータの入力出力形式を指定するもので、この場合は
using inputtensors = std::vector<torch::Tensor>;
using outputtensors = std::vector<torch::Tensor>;
using Example = torch::data::Example<inputtensors,outputtensors>;
としています。torch::data::Exampleがデータセットの形式になります。そして、その入力と出力は<inputtensors,outputtensors>で指定されています。
このデータセットのクラスは、
Example get(size_t index);
torch::optional<size_t> size() const {
return numofdata;
}
のように、getとsizeを最低限実装している必要があります。このサンプルでは適当な実装をしています。
このように定義したクラスを用いると、
auto sd = Dataloader::SDataset(100);
int batch_size = 10;
auto data_loader = torch::data::make_data_loader<torch::data::samplers::RandomSampler>(
sd, batch_size);
for (auto& batch : *data_loader ){
auto c = batch.size();
auto inputdata = batch.at(0).data;
auto outputdata = batch.at(0).target;
}
のようにすると、データローダーも利用することができます。ここでは、バッチサイズが10となるランダムバッチを作って取り出す、ということをやっています。バッチとなっている場合にはstd::vectorでデータが束ねられており、Exampleというクラスはdataがインプット、targetがアウトプットとなっています。今回は、これらがそれぞれstd::vector<torch::Tensor>になっているわけです。