はじめに
pytorch_geometric のデータ形式がなんとも見にくいのでうまく可視化する方法を調査した備忘録です.
環境
- python 3.12
- torch 2.4.1
- torch_geometric 2.6.1
- networkx 3.2.1
pytorch_geometric のデータ形式
pytorch_geometric のデータ形式は エッジの始点及び終点で表される配列edge_index
とノード上の信号を表す配列 x
からなります.
import torch
from torch_geometric.data import Data
edge_index = torch.tensor([[0, 1, 1, 2],
[1, 0, 2, 1]], dtype=torch.long)
x = torch.tensor([[-1], [0], [1]], dtype=torch.float)
data = Data(x=x, edge_index=edge_index)
このグラフを可視化すると下図のような感じになります.
networkx によるグラフの可視化
pytorch_geometric の Data
オブジェクトを networkx の Graph
オブジェクトに変換する関数があるので,それを使います.
from torch_geometric.utils import to_networkx
g = to_networkx(data)
# 可視化
nx.draw(G=g, with_labels=True)
nx.draw()
により,先ほど示したグラフの図が表示されます(再掲).
from_networkx()
により,pytorch_geometric の Data
オブジェクトに戻すことができますが,グラフ信号は失われます.
from torch_geometric.utils import from_networkx
data = from_networkx(g)
print(data.edge_index)
# >>> tensor([[0, 1, 1, 2],
# [1, 0, 2, 1]])
print(data.x)
# >>> None
他にも networkit, trimesh, cugraph, dgl, rdmol といったライブラリとの相互変換もできるようです.
詳しくは公式ドキュメントをご参照ください.
torch_geometric.utils
の下の方にいろいろ載ってます.