6
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

c++版PyTorchでカスタムデータセットを作る

6
Posted at

最近PyTorchのC++版をいじることが多いのですが、過去にはM1 MacでPytorchのc++版、libtorchをインストールして実行するなどの記事を書きました。

この記事では、英語でもなかなか出てこない情報である、PyTorch C++版でのデータセットとデータローダの作り方について忘備録的な意味でまとめておきます。

バージョン

  • PyTorch 2.0 C++ API

やりたいこと

インプットデータとアウトプットデータを自分で指定した形で取り出したいです。例えば、

  • インプット:std::vector<torch::Tensor>
  • アウトプット:std::vector<torch::Tensor>

のようなデータを取り出すことを考えます。インプットが複数、アウトプットが複数、という状況に対応できれば、多くのことに対応できるでしょう。

カスタムデータセット

例えば、

namespace Dataloader{

    using inputtensors = std::vector<torch::Tensor>;
    using outputtensors = std::vector<torch::Tensor>;
    using Example = torch::data::Example<inputtensors,outputtensors>;

     class SDataset : public torch::data::Dataset<SDataset,Example>
    {
        public:
            explicit SDataset(const size_t & numofdata){
                this -> numofdata = numofdata;
            }
            Example get(size_t index);
            torch::optional<size_t> size() const {
                return numofdata;
            }
        private:
            size_t numofdata;
    };

    Example SDataset::get(size_t index){
        std::vector<torch::Tensor> inputdata(2);
        std::vector<torch::Tensor> outputdata(2);

        inputdata.at(0) = at::linspace((double) index, (double)index+0.3, 4);
        inputdata.at(1) = at::linspace((double) index, (double)index+3.0, 2);
        outputdata.at(0) =  torch::randn({4,2});
        outputdata.at(1) =  torch::randn({3,2});
        return {inputdata,outputdata};
    }
}

のようにします。クラスSDatasettorch::data::Dataset<SDataset,Example>を継承していますが、これによってPyTorchのデータセットとして扱うことができるようになります。なお、このExampleの部分がデータの入力出力形式を指定するもので、この場合は

    using inputtensors = std::vector<torch::Tensor>;
    using outputtensors = std::vector<torch::Tensor>;
    using Example = torch::data::Example<inputtensors,outputtensors>;

としています。torch::data::Exampleがデータセットの形式になります。そして、その入力と出力は<inputtensors,outputtensors>で指定されています。

このデータセットのクラスは、

Example get(size_t index);
torch::optional<size_t> size() const {
    return numofdata;
}

のように、getsizeを最低限実装している必要があります。このサンプルでは適当な実装をしています。

このように定義したクラスを用いると、

auto sd = Dataloader::SDataset(100);
int batch_size = 10;
auto data_loader = torch::data::make_data_loader<torch::data::samplers::RandomSampler>(
                            sd, batch_size); 
for (auto& batch : *data_loader ){
    auto c = batch.size();
    auto inputdata = batch.at(0).data;
    auto outputdata = batch.at(0).target;
}

のようにすると、データローダーも利用することができます。ここでは、バッチサイズが10となるランダムバッチを作って取り出す、ということをやっています。バッチとなっている場合にはstd::vectorでデータが束ねられており、Exampleというクラスはdataがインプット、targetがアウトプットとなっています。今回は、これらがそれぞれstd::vector<torch::Tensor>になっているわけです。

6
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
6
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?