2
3

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

PyTorch Geometricのデータセットの自作するための簡単なまとめ

Last updated at Posted at 2023-11-25

PyGはグラフの機械学習に有用だが、データセットの自作が最初は難しかったので最低限必要な知識を記しておく。

InMemoryDataset

CPUメモリに収まるデータセットを構成できる便利なクラス。

概観

最初にインスタンスを生成すると、processメソッドが呼び出されてデータセットが処理される。
torch.save()でデータがdata.ptに保存され、次回以降高速に読み込める。

最低限知っておけばよいメソッド

  • raw_file_names
    処理前の生のデータのファイルがあればここに書いておく。

  • processed_file_names
    処理されたデータが、__init__で渡したディレクトリの下にこのファイル名で保存される。

  • process
    グラフのノード, エッジ, 重み, 特徴量の設定などを行う。基本的にtorch_geometric.data.Dataクラスでグラフを表現し、torch.save()でファイルに保存される。

import torch
from torch_geometric.data import Data, InMemoryDataset


class MyDataset(InMemoryDataset):
    def __init__(self, root="データセットを保存するパス") -> None:
        super(MyDataset, self).__init__(root)
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_file_names(self) -> list:
        return []

    @property
    def processed_file_names(self) -> list:
        return ["data.pt"]

    def process(self) -> None:
        data = Data(...)
        self.data, self.slices = self.collate([data])

        torch.save((self.data, self.slices), self.processed_paths[0])

Dataクラスの使い方

グラフの表現にはtorch_geometric.data.Dataクラスを使用する(素のPyTorchにもDataクラスがあるので一応区別して書く)。

最低限知っておけばよいアトリビュート

スクリーンショット 2023-11-26 9.17.34.png

  • x: ノードの特徴量、例えば座標など。
    インデックスがノード番号に対応する。
x = torch.tensor(
    [
        [0, 1],
        [1, 2],
        [1, 0],
        [2, 2],
        [2, 0]
    ]
)
# shape: (ノード数, ノードの特徴量の次元)
  • edge_index: エッジのインデックス(COO format)。エッジの始点を行ベクトルで並べ、同じ順で終点を並べる。
edge_index = torch.tensor(
    [
        [0, 0, 1, 2, 4],
        [1, 2, 4, 4, 3]
    ]
)
# shape: (エッジ数, 2)
  • edge_attr: エッジの特徴量、例えば重みなど。
    edge_indexと同じ順で対応させる。
edge_attr = torch.tensor(
    [
        [1, 3, 2, 6, 1]
    ]
)
# shape: (エッジ数, エッジの特徴量の次元)
  • y: ノードのラベル、例えば属するコミュニティなど。予測タスクのための学習などに用いる。xと同じ順で、ノード番号に対応させる。
y = torch.tensor(
    [
        [0, 0, 0, 1, 1]
    ]
)
# shape: (ノード数, 任意)
2
3
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
3

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?