11
5

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

PyTorch Geometricのdocsを真面目に読む

Last updated at Posted at 2023-05-28

経緯と内容

研究でPyTorch Geometricを真面目にやることになりそうなので、Introduction by Exampleやその周辺のドキュメントをちゃんと読むことにした。

Introduction by Example

とりあえず読めというやつ。

Data Handling of Graphs

  • graphはオブジェクト(node)とそのつながり方(edge)によって規定される構造である
  • PyGでは、一つのグラフはtorch_geometric.data.Dataのインスタンスとして表現される:
    • data.x: nodeの特徴行列 [num_nodes, num_node_features]
    • data.edge_index: グラフのつながり方をCOO fromat(後述)というshapeが[2, num_edge]、データ型がtorch.longであるようなフォーマットでノード間のつながりを定義する。
    • data.edge_attr:edgeの特徴行列 [num_edges, num_edge_features]
    • data.y: node-levelの教師ラベル[num_nodes, *]もしくはgraph-levelの教師ラベル[1, *]。※node classificationとgraph classificationを同時にやりたいとか、複数タスクを同時に学習させる場合はどう渡すんだろうね※
    • data.pos: nodeの位置行列 [num_nodes, num_dimensions]
  • このattributeのどれも必須ではないし、何なら3Dメッシュを定義するdata.face、みたいな感じで自由に追加できる。torch_geometric.data.Dataクラスを継承して作るんかね?
  • 例えば、unweightedかつundirectedで、1次元のnode featureを持つようなグラフは以下のように定義される:
import torch
from torch_geometric.data import Data

edge_index = torch.tensor([[0, 1, 1, 2],
                           [1, 0, 2, 1]], dtype=torch.long)

# [0,1,1,2]が始点のnode indexのlistで、[1,0,2,1]が終点のnode indexのlist
# undirectedなので、双方向に渡しているっぽい
# この場合、[0]-[1]-[2]という風につながっている
# 全結合をよく使うから、全結合を簡潔に表せる方法があると嬉しい
# 普通(?)にadjency matrixとして渡すことはできんのか?

x = torch.tensor([[-1], [0], [1]], dtype=torch.float)

data = Data(x=x, edge_index=edge_index)
>>> Data(edge_index=[2, 4], x=[3, 1])
  • (source_index, target_index)のリストとしてもedge_indexを渡すことができ、その場合は転置してcontiguousメソッドをcallすればいいらしい
import torch
from torch_geometric.data import Data

edge_index = torch.tensor([[0, 1],
                           [1, 0],
                           [1, 2],
                           [2, 1]], dtype=torch.long)
x = torch.tensor([[-1], [0], [1]], dtype=torch.float)

data = Data(x=x, edge_index=edge_index.t().contiguous())

image.png

  • print(data)で簡単なinformationを表示できる

    • 今回の例の場合はData(edge_index=[2, 4], x=[3, 1])と出る
  • edge_indexの各要素は{0, 1, ..., num_nodes-1}の範囲にいる必要があり、これを検証するためにはdata.validate(raise_on_error=True)を実行すればよい。

    • 試してみたところ、正常の場合Trueが返る。失敗した場合、以下のようなエラーを吐く
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[8], line 1
----> 1 data.validate(raise_on_error=True)

File ~/miniconda3/envs/torch2-pyg/lib/python3.10/site-packages/torch_geometric/data/data.py:565, in Data.validate(self, raise_on_error)
    563     if num_nodes is not None and self.edge_index.max() >= num_nodes:
    564         status = False
--> 565         warn_or_raise(
    566             f"'edge_index' contains larger indices than the number "
    567             f"of nodes ({num_nodes}) in '{cls_name}' "
    568             f"(found {int(self.edge_index.max())})", raise_on_error)
    570 return status

File ~/miniconda3/envs/torch2-pyg/lib/python3.10/site-packages/torch_geometric/data/data.py:990, in warn_or_raise(msg, raise_on_error)
    988 def warn_or_raise(msg: str, raise_on_error: bool = True):
    989     if raise_on_error:
--> 990         raise ValueError(msg)
    991     else:
    992         warnings.warn(msg)

ValueError: 'edge_index' contains larger indices than the number of nodes (3) in 'Data' (found 100)
  • ほかにもいろいろ便利なattributeやmethodが用意されているよ。詳しくはドキュメントを読んでね
print(data.keys)
>>> ['x', 'edge_index']

print(data['x'])
>>> tensor([[-1.0],
            [0.0],
            [1.0]])

for key, item in data:
    print(f'{key} found in data')
>>> x found in data
>>> edge_index found in data

'edge_attr' in data
>>> False

data.num_nodes
>>> 3

data.num_edges
>>> 4

data.num_node_features
>>> 1

data.has_isolated_nodes()
>>> False

data.has_self_loops()
>>> False

data.is_directed()
>>> False

# Transfer data object to GPU.
device = torch.device('cuda')
data = data.to(device)

COO formatについて

Common Benchmark Datasets

  • scikit-learnやpytorchと同様、torch_geometric.datasetsにはいろいろなベンチマーク用のデータセットが実装されている

Mini-batches

  • 複数グラフをbatch単位でハンドリングするために、mini-batchという概念が実装されている
  • data.edge_indexdata.xdata.yをbatchごとにまとめる
    • 例えば、$n$個のグラフをまとめたい場合、以下のようにconcatする。※グラフごとにnum_nodesが異なるのでstackはできない※
\begin{eqnarray}
A = \left[
\begin{array}{ccc}
A_{1} & & \\
& \ddots & \\
& & A_n
\end{array}
\right]
, \quad
X = \left[
\begin{array}{c}
X_1 \\
\vdots \\
X_n
\end{array}
\right]
, \quad
Y = \left[
\begin{array}{c}
Y_1 \\
\vdots \\
Y_n
\end{array}
\right]
\end{eqnarray}
  • ※例えば$X$は、(batch中のnum_nodesの総数, num_features)のような形になる※
  • どのnode indexがそれぞれどのグラフに対応しているかを、data.batchとして与えてあげる
\mathrm{batch} = [0 \ \cdots \ 0 \ 1 \ \cdots \ n-2 \ n-1 \ \cdots \ n-1 ]^{\top}
  • torch.utils.scatter関数を使うことにより、グラフごとの要約統計量を計算したりすることができる
from torch_geometric.datasets import TUDataset
from torch_geometric.loader import DataLoader
from torch_geometric.utils import scatter

dataset = TUDataset(root='/tmp/ENZYMES', name='ENZYMES', use_node_attr=True)
loader = DataLoader(dataset, batch_size=32, shuffle=True)

i = 0
print("="*100)
for batch in loader:
    print(batch)
    print(batch.num_graphs)
    x = scatter(batch.x, batch.batch, dim=0, reduce="mean")
    print(x.size())
    print("="*100)
    i += 1
    if i == 5:
        break
====================================================================================================
DataBatch(edge_index=[2, 3602], x=[908, 21], y=[32], batch=[908], ptr=[33])
32
torch.Size([32, 21])
====================================================================================================
DataBatch(edge_index=[2, 4084], x=[1042, 21], y=[32], batch=[1042], ptr=[33])
32
torch.Size([32, 21])
====================================================================================================
DataBatch(edge_index=[2, 3904], x=[1037, 21], y=[32], batch=[1037], ptr=[33])
32
torch.Size([32, 21])
====================================================================================================
DataBatch(edge_index=[2, 3852], x=[1135, 21], y=[32], batch=[1135], ptr=[33])
32
torch.Size([32, 21])
====================================================================================================
DataBatch(edge_index=[2, 3996], x=[1047, 21], y=[32], batch=[1047], ptr=[33])
32
torch.Size([32, 21])
====================================================================================================

Data Transforms

  • torchvisionで画像をtransformするのと同様、torch_geometric.transformsにはいくつかの便利なtransformが実装されている
    • torch_geometric.transforms.KNNGraph(k:int)を使うことで、data.posに基づき各nodeと各nodeのk-th neighbor nodesをつなぐ
    • torch_geometric.transforms.RandomJitter(translate:float)を使うことで、data.posをランダムにずらすことができる
import torch_geometric.transforms as T
from torch_geometric.datasets import ShapeNet

dataset = ShapeNet(root='/tmp/ShapeNet', categories=['Airplane'],
                    pre_transform=T.KNNGraph(k=6))

dataset[0]
>>> Data(edge_index=[2, 15108], pos=[2518, 3], y=[2518])
import torch_geometric.transforms as T
from torch_geometric.datasets import ShapeNet

dataset = ShapeNet(root='/tmp/ShapeNet', categories=['Airplane'],
                    pre_transform=T.KNNGraph(k=6),
                    transform=T.RandomJitter(0.01))

dataset[0]
>>> Data(edge_index=[2, 15108], pos=[2518, 3], y=[2518])
  • pre_transformはディスクに保存する前に実行され、transformはディスクにいったんdatasetが保存されてから実行される

Learning Methods on Graphs

Cora datasetをベンチマークとしたGCNの学習の例は以下の通り。

# load dataset
from torch_geometric.datasets import Planetoid

dataset = Planetoid(root='/tmp/Cora', name='Cora')
>>> Cora()

# model definition
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv

class GCN(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = GCNConv(dataset.num_node_features, 16)
        self.conv2 = GCNConv(16, dataset.num_classes)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index

        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, training=self.training)
        x = self.conv2(x, edge_index)

        return F.log_softmax(x, dim=1)

# training
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = GCN().to(device)
data = dataset[0].to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)

model.train()
for epoch in range(200):
    optimizer.zero_grad()
    out = model(data)
    loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])
    loss.backward()
    optimizer.step()

# evaluation
model.eval()
pred = model(data).argmax(dim=1)
correct = (pred[data.test_mask] == data.y[data.test_mask]).sum()
acc = int(correct) / int(data.test_mask.sum())
print(f'Accuracy: {acc:.4f}')
>>> Accuracy: 0.8150

Creating Message Passing Networks

Creating Your Own Datasets

Loading Graphs from CSV

Advanced Mini-Batching

Other Informative Links

11
5
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
11
5

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?