PyGはグラフの機械学習に有用だが、データセットの自作が最初は難しかったので最低限必要な知識を記しておく。
InMemoryDataset
CPUメモリに収まるデータセットを構成できる便利なクラス。
概観
最初にインスタンスを生成すると、process
メソッドが呼び出されてデータセットが処理される。
torch.save()
でデータがdata.pt
に保存され、次回以降高速に読み込める。
最低限知っておけばよいメソッド
-
raw_file_names
処理前の生のデータのファイルがあればここに書いておく。 -
processed_file_names
処理されたデータが、__init__
で渡したディレクトリの下にこのファイル名で保存される。 -
process
グラフのノード, エッジ, 重み, 特徴量の設定などを行う。基本的にtorch_geometric.data.Data
クラスでグラフを表現し、torch.save()
でファイルに保存される。
import torch
from torch_geometric.data import Data, InMemoryDataset
class MyDataset(InMemoryDataset):
def __init__(self, root="データセットを保存するパス") -> None:
super(MyDataset, self).__init__(root)
self.data, self.slices = torch.load(self.processed_paths[0])
@property
def raw_file_names(self) -> list:
return []
@property
def processed_file_names(self) -> list:
return ["data.pt"]
def process(self) -> None:
data = Data(...)
self.data, self.slices = self.collate([data])
torch.save((self.data, self.slices), self.processed_paths[0])
Dataクラスの使い方
グラフの表現にはtorch_geometric.data.Data
クラスを使用する(素のPyTorchにもData
クラスがあるので一応区別して書く)。
最低限知っておけばよいアトリビュート
-
x
: ノードの特徴量、例えば座標など。
インデックスがノード番号に対応する。
x = torch.tensor(
[
[0, 1],
[1, 2],
[1, 0],
[2, 2],
[2, 0]
]
)
# shape: (ノード数, ノードの特徴量の次元)
-
edge_index
: エッジのインデックス(COO format)。エッジの始点を行ベクトルで並べ、同じ順で終点を並べる。
edge_index = torch.tensor(
[
[0, 0, 1, 2, 4],
[1, 2, 4, 4, 3]
]
)
# shape: (エッジ数, 2)
-
edge_attr
: エッジの特徴量、例えば重みなど。
edge_index
と同じ順で対応させる。
edge_attr = torch.tensor(
[
[1, 3, 2, 6, 1]
]
)
# shape: (エッジ数, エッジの特徴量の次元)
-
y
: ノードのラベル、例えば属するコミュニティなど。予測タスクのための学習などに用いる。x
と同じ順で、ノード番号に対応させる。
y = torch.tensor(
[
[0, 0, 0, 1, 1]
]
)
# shape: (ノード数, 任意)