LoginSignup
7
7

More than 1 year has passed since last update.

PyG (PyTorch Geometric) で GAT (Graph Attention Networks)

Last updated at Posted at 2022-07-20

グラフ構造を深層学習する PyG (PyTorch Geometric) を Google Colaboratory 上で使ってみました。まずは GAT (Graph Attention Networks) を用いて、node property prediction (頂点のラベルを予測)です。

PyG (PyTorch Geometric) インストール

PyG (PyTorch Geometric) のレポジトリは https://github.com/pyg-team/pytorch_geometric にあります。

import os
import torch
os.environ['TORCH'] = torch.__version__
print(torch.__version__)

!pip install -q torch-scatter -f https://data.pyg.org/whl/torch-${TORCH}.html
!pip install -q torch-sparse -f https://data.pyg.org/whl/torch-${TORCH}.html
!pip install -q git+https://github.com/pyg-team/pytorch_geometric.git

import torch_geometric
1.12.0+cu113
[K     |████████████████████████████████| 7.9 MB 2.9 MB/s 
[K     |████████████████████████████████| 3.5 MB 2.7 MB/s 
[?25h  Building wheel for torch-geometric (setup.py) ... [?25l[?25hdone

データセットのロード

ベンチマークなどで使えるデータセットは https://pytorch-geometric.readthedocs.io/en/latest/modules/datasets.html に置いてあります。

from torch_geometric import datasets

KarateClub データセット

その中で最もデータサイズの小さい KarateClub データセットを試しに使ってみましょう。

dataset = datasets.KarateClub()
dataset
KarateClub()

次のようにして、データの内容を外観できます。

print("number of graphs:\t\t",len(dataset))
print("number of classes:\t\t",dataset.num_classes)
print("number of node features:\t",dataset.num_node_features)
print("number of edge features:\t",dataset.num_edge_features)
number of graphs:		 1
number of classes:		 4
number of node features:	 34
number of edge features:	 0

データの本体は .data にあります。次のようにして、中に入っているデータの形状(shape)を外観できます。

dataset.data
Data(x=[34, 34], edge_index=[2, 156], y=[34], train_mask=[34])

詳細は次の通りです。

print("edge_index:\t\t",dataset.data.edge_index.shape)
print(dataset.data.edge_index)
print("\n")
print("train_mask:\t\t",dataset.data.train_mask.shape)
print(dataset.data.train_mask)
print("\n")
print("x:\t\t",dataset.data.x.shape)
print(dataset.data.x)
print("\n")
print("y:\t\t",dataset.data.y.shape)
print(dataset.data.y)
edge_index:		 torch.Size([2, 156])
tensor([[ 0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  1,  1,
          1,  1,  1,  1,  1,  1,  1,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  3,
          3,  3,  3,  3,  3,  4,  4,  4,  5,  5,  5,  5,  6,  6,  6,  6,  7,  7,
          7,  7,  8,  8,  8,  8,  8,  9,  9, 10, 10, 10, 11, 12, 12, 13, 13, 13,
         13, 13, 14, 14, 15, 15, 16, 16, 17, 17, 18, 18, 19, 19, 19, 20, 20, 21,
         21, 22, 22, 23, 23, 23, 23, 23, 24, 24, 24, 25, 25, 25, 26, 26, 27, 27,
         27, 27, 28, 28, 28, 29, 29, 29, 29, 30, 30, 30, 30, 31, 31, 31, 31, 31,
         31, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 33, 33, 33, 33, 33,
         33, 33, 33, 33, 33, 33, 33, 33, 33, 33, 33, 33],
        [ 1,  2,  3,  4,  5,  6,  7,  8, 10, 11, 12, 13, 17, 19, 21, 31,  0,  2,
          3,  7, 13, 17, 19, 21, 30,  0,  1,  3,  7,  8,  9, 13, 27, 28, 32,  0,
          1,  2,  7, 12, 13,  0,  6, 10,  0,  6, 10, 16,  0,  4,  5, 16,  0,  1,
          2,  3,  0,  2, 30, 32, 33,  2, 33,  0,  4,  5,  0,  0,  3,  0,  1,  2,
          3, 33, 32, 33, 32, 33,  5,  6,  0,  1, 32, 33,  0,  1, 33, 32, 33,  0,
          1, 32, 33, 25, 27, 29, 32, 33, 25, 27, 31, 23, 24, 31, 29, 33,  2, 23,
         24, 33,  2, 31, 33, 23, 26, 32, 33,  1,  8, 32, 33,  0, 24, 25, 28, 32,
         33,  2,  8, 14, 15, 18, 20, 22, 23, 29, 30, 31, 33,  8,  9, 13, 14, 15,
         18, 19, 20, 22, 23, 26, 27, 28, 29, 30, 31, 32]])


train_mask:		 torch.Size([34])
tensor([ True, False, False, False,  True, False, False, False,  True, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False,  True, False, False, False, False, False,
        False, False, False, False])


x:		 torch.Size([34, 34])
tensor([[1., 0., 0.,  ..., 0., 0., 0.],
        [0., 1., 0.,  ..., 0., 0., 0.],
        [0., 0., 1.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 1., 0., 0.],
        [0., 0., 0.,  ..., 0., 1., 0.],
        [0., 0., 0.,  ..., 0., 0., 1.]])


y:		 torch.Size([34])
tensor([1, 1, 1, 1, 3, 3, 3, 1, 0, 1, 3, 1, 1, 1, 0, 0, 3, 1, 0, 1, 0, 1, 0, 0,
        2, 2, 0, 0, 2, 0, 0, 2, 0, 0])

NetworkX を使って可視化します。ノード(頂点)は、目的変数 y に応じて色分けしています。

import matplotlib.pyplot as plt
import networkx as nx
from torch_geometric.utils.convert import to_networkx, from_networkx

G = to_networkx(dataset.data, to_undirected=False)
color = ["r", "g", "m", "c", "y"]

plt.figure(figsize=(8, 8))
pos = nx.spring_layout(G, k=0.3)
node_color = [color[y] for y in dataset.data.y]
nx.draw_networkx_nodes(G, pos, alpha=0.5, node_color=node_color)
nx.draw_networkx_edges(G, pos, alpha=0.1)
nx.draw_networkx_labels(G, pos)
plt.axis("off")
plt.savefig("G.png")
plt.show()

Qiita_PyG_(PyTorch_Geometric)で_GAT(Graph_Attention_Networks)_15_0.png

PubMed データセット

次の例として、生命科学や生物医学に関する文献データベースである PubMed データセットを見てみます。

dataset = datasets.Planetoid(root=".", name="PubMed")
dataset
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.pubmed.x
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.pubmed.tx
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.pubmed.allx
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.pubmed.y
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.pubmed.ty
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.pubmed.ally
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.pubmed.graph
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.pubmed.test.index
Processing...
Done!





PubMed()

次のようにして、データの内容を外観できます。

print("number of graphs:\t\t",len(dataset))
print("number of classes:\t\t",dataset.num_classes)
print("number of node features:\t",dataset.num_node_features)
print("number of edge features:\t",dataset.num_edge_features)
number of graphs:		 1
number of classes:		 3
number of node features:	 500
number of edge features:	 0

データの本体は .data にあります。次のようにして、中に入っているデータの形状(shape)を外観できます。

dataset.data
Data(x=[19717, 500], edge_index=[2, 88648], y=[19717], train_mask=[19717], val_mask=[19717], test_mask=[19717])

詳細は次の通りです。

print("edge_index:\t\t",dataset.data.edge_index.shape)
print(dataset.data.edge_index)
print("\n")
print("x:\t\t",dataset.data.x.shape)
print(dataset.data.x)
print("\n")
print("y:\t\t",dataset.data.y.shape)
print(dataset.data.y)
edge_index:		 torch.Size([2, 88648])
tensor([[    0,     0,     0,  ..., 19714, 19715, 19716],
        [ 1378,  1544,  6092,  ..., 12278,  4284, 16030]])


x:		 torch.Size([19717, 500])
tensor([[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        [0.1046, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        ...,
        [0.0000, 0.0194, 0.0080,  ..., 0.0000, 0.0000, 0.0000],
        [0.1078, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0266, 0.0000,  ..., 0.0000, 0.0000, 0.0000]])


y:		 torch.Size([19717])
tensor([1, 1, 0,  ..., 2, 0, 2])

NetworkX でネットワークを描画したいところですが、描画してもよく分からないので省略します。

train_val_test_split

ノードラベルを train, val, test に分割する train_val_test_split を次のように実装します。

def train_val_test_split(data, val_ratio: float = 0.15,
                             test_ratio: float = 0.15):
    rnd = torch.rand(len(data.x))
    train_mask = [False if (x > val_ratio + test_ratio) else True for x in rnd]
    val_mask = [False if (val_ratio + test_ratio >= x) and (x > test_ratio) else True for x in rnd]
    test_mask = [False if (test_ratio >= x) else True for x in rnd]
    return torch.tensor(train_mask), torch.tensor(val_mask), torch.tensor(test_mask)

GAT (Graph Attention Networks)

GATConv https://pytorch-geometric.readthedocs.io/en/latest/modules/nn.html#torch_geometric.nn.conv.GATConv を使った GAT (Graph Attention Networks) クラスを実装します。

import torch.nn.functional as F
from torch_geometric.nn import GATConv

class GAT(torch.nn.Module):
    def __init__(self):
        super(GAT, self).__init__()
        self.hid = 8
        self.in_head = 8
        self.out_head = 1
        
        
        self.conv1 = GATConv(dataset.num_features, self.hid, heads=self.in_head, dropout=0.6)
        self.conv2 = GATConv(self.hid*self.in_head, dataset.num_classes, concat=False,
                             heads=self.out_head, dropout=0.6)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
                
        x = F.dropout(x, p=0.6, training=self.training)
        x = self.conv1(x, edge_index)
        x = F.elu(x)
        x = F.dropout(x, p=0.6, training=self.training)
        x = self.conv2(x, edge_index)
        
        return F.log_softmax(x, dim=1)

cuda が使えれば cuda を、使えなければ cpu を使うようにします。

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device = "cpu"

KarateClub データセットのノードラベル予測

データセットとモデルを次のようにします。

dataset = datasets.KarateClub()
data = dataset[0].to(device)
model = GAT().to(device)

train, val, test データセットを次のようにセットします。

train_mask, val_mask, test_mask = train_val_test_split(data)

data.train_mask = train_mask
data.val_mask = val_mask
data.test_mask = test_mask

次のようにして学習します。

optimizer = torch.optim.Adam(model.parameters(), lr=0.005, weight_decay=5e-4)

losses = []
for epoch in range(2000):
    model.train()
    optimizer.zero_grad()
    out = model(data)
    loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])
    losses.append(loss.detach().numpy())

    if epoch % 200 == 0:
        print(loss)
    
    loss.backward()
    optimizer.step()
tensor(1.4991, grad_fn=<NllLossBackward0>)
tensor(0.7117, grad_fn=<NllLossBackward0>)
tensor(0.9274, grad_fn=<NllLossBackward0>)
tensor(0.4589, grad_fn=<NllLossBackward0>)
tensor(0.4229, grad_fn=<NllLossBackward0>)
tensor(0.4132, grad_fn=<NllLossBackward0>)
tensor(0.1550, grad_fn=<NllLossBackward0>)
tensor(0.7361, grad_fn=<NllLossBackward0>)
tensor(0.5670, grad_fn=<NllLossBackward0>)
tensor(0.3323, grad_fn=<NllLossBackward0>)

学習曲線は次のようになりました。

import matplotlib.pyplot

plt.plot(losses, alpha=0.8)

Qiita_PyG_(PyTorch_Geometric)で_GAT(Graph_Attention_Networks)_38_1.png

今回、目的変数 y は4種類のラベルになりますが、その予測結果は次のようになります。

model(data)
tensor([[-8.9359e+00, -1.1426e+00, -5.7290e+00, -3.8914e-01],
        [-6.4364e+00, -3.2092e-01, -1.4640e+00, -3.1798e+00],
        [-5.8728e+00, -3.6128e-01, -1.4878e+00, -2.5966e+00],
        [-7.9264e+00, -1.9967e-01, -1.9631e+00, -3.2134e+00],
        [-1.4734e+01, -5.6874e+00, -1.6281e+01, -3.3947e-03],
        [-7.4613e+00, -2.6100e+00, -5.5670e+00, -8.1134e-02],
        [-8.0405e+00, -2.9955e+00, -9.3368e+00, -5.1738e-02],
        [-9.7762e+00, -2.5239e-01, -1.5019e+00, -8.0933e+00],
        [-1.0885e+01, -8.3299e-02, -3.1077e+00, -3.3467e+00],
        [-1.4758e+00, -1.2810e+00, -1.5274e+00, -1.2853e+00],
        [-8.8328e+00, -2.0148e+00, -8.0043e+00, -1.4367e-01],
        [-1.4758e+00, -1.2810e+00, -1.5274e+00, -1.2853e+00],
        [-6.7221e+00, -1.1259e-01, -2.5291e+00, -3.6671e+00],
        [-6.8040e+00, -1.8940e-01, -1.9213e+00, -3.6878e+00],
        [-1.4758e+00, -1.2810e+00, -1.5274e+00, -1.2853e+00],
        [-1.4758e+00, -1.2810e+00, -1.5274e+00, -1.2853e+00],
        [-1.7997e+01, -7.4172e+00, -1.4998e+01, -6.0135e-04],
        [-1.4183e+00, -1.3070e+00, -1.4834e+00, -1.3456e+00],
        [-2.5639e-04, -1.7825e+01, -8.2728e+00, -1.3858e+01],
        [-2.5783e+00, -1.0899e+00, -6.1406e-01, -3.0638e+00],
        [-1.6390e-04, -1.8391e+01, -8.7232e+00, -1.3795e+01],
        [-2.4957e+00, -1.0830e+00, -6.0055e-01, -3.4909e+00],
        [-1.4758e+00, -1.2810e+00, -1.5274e+00, -1.2853e+00],
        [-3.1565e-01, -2.9938e+00, -1.5305e+00, -5.4833e+00],
        [-1.4758e+00, -1.2810e+00, -1.5274e+00, -1.2853e+00],
        [-1.4758e+00, -1.2810e+00, -1.5274e+00, -1.2853e+00],
        [-2.1341e-01, -6.6087e+00, -1.6809e+00, -5.3788e+00],
        [-4.2086e+00, -2.8116e+00, -7.8278e-02, -8.0508e+00],
        [-2.2204e+00, -1.5440e+00, -4.3808e-01, -3.4222e+00],
        [-6.4373e-06, -2.5067e+01, -1.1945e+01, -1.6719e+01],
        [-2.7569e+00, -1.2134e+00, -5.0864e-01, -3.2697e+00],
        [-1.1346e+01, -5.7865e-01, -8.5537e-01, -4.2540e+00],
        [-3.9678e-01, -4.7832e+00, -1.2040e+00, -3.9543e+00],
        [-1.4352e+00, -3.7600e+00, -3.2454e-01, -4.1481e+00]],
       grad_fn=<LogSoftmaxBackward0>)

学習済みのモデルの性能は次のように求まります。

model.eval()
_, pred = model(data).max(dim=1)
correct = float(pred[data.test_mask].eq(data.y[data.test_mask]).sum().item())
acc = correct / data.test_mask.sum().item()
print('Accuracy: {:.4f}'.format(acc))
Accuracy: 0.8333

PubMed データセットのノードラベル予測

データセットとモデルを次のようにします。

dataset = datasets.Planetoid(root=".", name="PubMed")
data = dataset[0].to(device)
model = GAT().to(device)

train, val, test データセットを次のようにセットします。

train_mask, val_mask, test_mask = train_val_test_split(data)

data.train_mask=train_mask
data.val_mask=val_mask
data.test_mask=test_mask

次のようにして学習します。

optimizer = torch.optim.Adam(model.parameters(), lr=0.005, weight_decay=5e-4)

losses = []
for epoch in range(1000):
    model.train()
    optimizer.zero_grad()
    out = model(data)
    loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])
    losses.append(loss.detach().numpy())

    if epoch % 200 == 0:
        print(loss)
    
    loss.backward()
    optimizer.step()
tensor(1.0939, grad_fn=<NllLossBackward0>)
tensor(0.6320, grad_fn=<NllLossBackward0>)
tensor(0.6077, grad_fn=<NllLossBackward0>)
tensor(0.6039, grad_fn=<NllLossBackward0>)
tensor(0.6311, grad_fn=<NllLossBackward0>)
tensor(0.6243, grad_fn=<NllLossBackward0>)
tensor(0.6043, grad_fn=<NllLossBackward0>)
tensor(0.6226, grad_fn=<NllLossBackward0>)
tensor(0.6192, grad_fn=<NllLossBackward0>)
tensor(0.6110, grad_fn=<NllLossBackward0>)

学習曲線は次のようになりました。

import matplotlib.pyplot

plt.plot(losses, alpha=0.8)

Qiita_PyG_(PyTorch_Geometric)で_GAT(Graph_Attention_Networks)_50_1.png

学習済みのモデルの性能は次のように求まります。

model.eval()
_, pred = model(data).max(dim=1)
correct = float(pred[data.test_mask].eq(data.y[data.test_mask]).sum().item())
acc = correct / data.test_mask.sum().item()
print('Accuracy: {:.4f}'.format(acc))
Accuracy: 0.8562

おわり

また、そのうち PyG (PyTorch Geometric) についていろいろ試してみたいと思います。

7
7
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
7
7