
Graph Autoencoders (GAE) and Variational Graph Autoencoders (VGAE) with PyG (PyTorch Geometric)


I tried out PyG (PyTorch Geometric), a library for deep learning on graph structures, on Google Colaboratory. The theme this time is using Graph Autoencoders (GAE) and Variational Graph Autoencoders (VGAE).

Installing PyG (PyTorch Geometric)

The PyG (PyTorch Geometric) repository is at https://github.com/pyg-team/pytorch_geometric. The code in this article is based on the tutorial documentation at https://pytorch-geometric.readthedocs.io/en/latest/index.html.

import os
import torch

torch.manual_seed(53)
os.environ['TORCH'] = torch.__version__
print(torch.__version__)

!pip install -q torch-scatter -f https://data.pyg.org/whl/torch-${TORCH}.html
!pip install -q torch-sparse -f https://data.pyg.org/whl/torch-${TORCH}.html
!pip install -q git+https://github.com/pyg-team/pytorch_geometric.git

import torch_geometric
1.12.0+cu113
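# Use the GPU when one is available ...
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# ... but this notebook then overrides that choice and runs on the CPU.
device = "cpu"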

Building our own datasets

Once again I build the datasets myself. The reason is that it makes it easier to judge by eye whether a generated network (graph structure) "looks like the real thing".

GridDataset

First, a dataset containing a grid-shaped network.

import numpy as np
from scipy.spatial import distance
from torch_geometric.data import Data, InMemoryDataset

class GridDataset(InMemoryDataset):
    def __init__(self, transform = None):
        super().__init__('.', transform)

        f = lambda x: np.linalg.norm(x) - np.arctan2(x[0], x[1])
        embeddings = []
        ys = []
        for x in range(-10, 11, 2):
            for y in range(-10, 11, 2):
                embeddings.append([x, y])
                ys.append(f([x, y]))
        embeddings = torch.tensor(embeddings, dtype=torch.float)
        y2 = []
        for y in ys:
            if y > np.array(ys).mean():
                y2.append(1)
            else:
                y2.append(0)
        ys = torch.tensor(y2, dtype=torch.float)

        dist_matrix = distance.cdist(embeddings, embeddings, metric='euclidean')
        edges = []
        edge_attr = []
        for i in range(len(dist_matrix)):
            for j in range(len(dist_matrix)):
                if i < j:
                    if dist_matrix[i][j] == 2:
                        edges.append([i, j])
                        edge_attr.append(abs(f(embeddings[i]) - f(embeddings[j])))
                    elif dist_matrix[i][j] < 3 and (
                        embeddings[i][0] == embeddings[j][1] or
                        embeddings[i][1] == embeddings[j][0]
                    ):
                        edges.append([i, j])
                        edge_attr.append(abs(f(embeddings[i]) - f(embeddings[j])))

        edges = torch.tensor(edges, dtype=torch.long).T
        # Store the |f(i) - f(j)| differences as floats; torch.long would truncate them.
        edge_attr = torch.tensor(edge_attr, dtype=torch.float)
        data = Data(x=embeddings, edge_index=edges, y=ys, edge_attr=edge_attr)
        self.data, self.slices = self.collate([data])
        self.data.num_nodes = len(embeddings)

    def layout(self):
        return {i:x.detach().numpy() for i, x in enumerate(self.data.x)}

    def node_color(self):
        c = {0:"red", 1:"blue"}
        return [c[int(x.detach().numpy())] for (i, x) in enumerate(self.data.y)]
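As a quick sanity check on the grid layout, the following is a minimal NumPy-only sketch (it does not use PyG). It rebuilds the same node coordinates and counts only the axis-aligned edges, that is, node pairs at distance exactly 2. The extra diagonal edges that GridDataset adds under its second condition are deliberately left out, and the names `coords` and `axis_edges` are introduced here only for illustration.

```python
import numpy as np

# Same coordinates as GridDataset: x and y each run over -10, -8, ..., 10.
coords = np.array(
    [[x, y] for x in range(-10, 11, 2) for y in range(-10, 11, 2)],
    dtype=float,
)

# Pairwise Euclidean distances via broadcasting.
diff = coords[:, None, :] - coords[None, :, :]
dist = np.sqrt((diff ** 2).sum(axis=-1))

# Count each undirected pair (i < j) at distance exactly 2 once.
iu, ju = np.triu_indices(len(coords), k=1)
axis_edges = int((dist[iu, ju] == 2.0).sum())

num_nodes = len(coords)
print(num_nodes, axis_edges)
```

The grid has 11 x 11 nodes, and each of the 11 rows and 11 columns contributes 10 axis-aligned edges.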

Visualized with NetworkX, the network looks like this.

import networkx as nx
import matplotlib.pyplot as plt

dataset = GridDataset()
G = torch_geometric.utils.convert.to_networkx(dataset.data)
plt.figure(figsize=(12,12))
nx.draw_networkx(G, pos=dataset.layout(), with_labels=False, alpha=0.5, node_color=dataset.node_color())

[Figure: the GridDataset network drawn with NetworkX]

ColonyDataset

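Networks with the structure above rarely show up in the real world, so next I built a network whose degree distribution is more uneven. I named it ColonyDataset because it somehow reminded me of E. coli colonies. What, it doesn't look like one? Sorry.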

import numpy as np
from scipy.spatial import distance
from torch_geometric.data import Data, InMemoryDataset

class ColonyDataset(InMemoryDataset):
    def __init__(self, transform = None):
        super().__init__('.', transform)

        f = lambda x: np.linalg.norm(x) - np.arctan2(x[0], x[1])
        embeddings = []
        ys = []

        for x in range(-10, 11, 5):
            for y in range(-10, 11, 5):
                embeddings.append([x, y])
                ys.append(f([x, y]))
                for theta in range(max(0, 15 - abs(x) - abs(y))):
                    x2 = x + np.sin(theta + np.random.rand()) * abs(17 - theta) * 0.1
                    y2 = y + np.cos(theta + np.random.rand()) * abs(17 - theta) * 0.1
                    embeddings.append([x2, y2])
                    ys.append(f([x2, y2]))
                
        embeddings = torch.tensor(embeddings, dtype=torch.float)
        y2 = []
        for y in ys:
            if y > np.array(ys).mean():
                y2.append(1)
            else:
                y2.append(0)
        ys = torch.tensor(y2, dtype=torch.float)

        dist_matrix = distance.cdist(embeddings, embeddings, metric='euclidean')
        edges = []
        edge_attr = []
        for i in range(len(dist_matrix)):
            for j in range(len(dist_matrix)):
                if i < j:
                    if dist_matrix[i][j] == 5 or dist_matrix[i][j] < 2:
                        edges.append([i, j])
                        edge_attr.append(abs(f(embeddings[i]) - f(embeddings[j])))

        edges = torch.tensor(edges).T
        edge_attr = torch.tensor(edge_attr)
        data = Data(x=embeddings, edge_index=edges, y=ys, edge_attr=edge_attr)
        self.data, self.slices = self.collate([data])
        self.data.num_nodes = len(embeddings)

    def layout(self):
        return {i:x.detach().numpy() for i, x in enumerate(self.data.x)}
    
    def node_color(self):
        c = {0:"red", 1:"blue"}
        return [c[int(x.detach().numpy())] for (i, x) in enumerate(self.data.y)]

Visualized with NetworkX, the network looks like this.

import networkx as nx
import matplotlib.pyplot as plt

dataset = ColonyDataset()
G = torch_geometric.utils.convert.to_networkx(dataset.data)
plt.figure(figsize=(12,12))
nx.draw_networkx(G, pos=dataset.layout(), with_labels=False, alpha=0.5, node_color=dataset.node_color())

[Figure: the ColonyDataset network drawn with NetworkX]
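One way to put a number on "uneven degree distribution" is to count the degree of every node. The sketch below is NumPy-only and uses a hypothetical toy edge list rather than the real dataset. The helper `node_degrees` is a name made up for this illustration, and it assumes each undirected edge is stored once, as both dataset classes above do (only pairs with i < j are added).

```python
import numpy as np

def node_degrees(edge_index, num_nodes):
    """Degree of each node for an undirected edge list of shape (2, E)."""
    # Each undirected edge (i, j) contributes one to both endpoints.
    endpoints = np.concatenate([edge_index[0], edge_index[1]])
    return np.bincount(endpoints, minlength=num_nodes)

# Toy example (hypothetical): hub node 0 connected to nodes 1..4, plus edge 1-2.
toy_edges = np.array([[0, 0, 0, 0, 1],
                      [1, 2, 3, 4, 2]])
deg = node_degrees(toy_edges, num_nodes=5)
print(deg)                      # degree per node
print(deg.max() / deg.mean())   # a crude indicator of how skewed the degrees are
```

For the datasets in this article, something like `node_degrees(dataset.data.edge_index.numpy(), dataset.data.num_nodes)` should give the corresponding counts, under the same stored-once assumption.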

Choosing the dataset

The dataset to use is selected as follows.

# The later assignment takes effect (ColonyDataset here); swap the two lines to use GridDataset.
use_dataset = GridDataset
use_dataset = ColonyDataset

Because the task here is to predict whether an edge exists, the edges are split into training and test sets.

dataset = use_dataset()
data = dataset.data
data = torch_geometric.utils.train_test_split_edges(data)
/usr/local/lib/python3.7/dist-packages/torch_geometric/deprecation.py:12: UserWarning: 'train_test_split_edges' is deprecated, use 'transforms.RandomLinkSplit' instead
  warnings.warn(out)
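To make the split concrete, here is a rough NumPy sketch of what an edge split for link prediction involves: shuffle the edges, hold out a fraction for validation and testing, and sample an equal number of non-edges as negative examples. This is not PyG's implementation. The function names are hypothetical, and the ratios 0.05 and 0.1 are the defaults I believe `train_test_split_edges` uses. As the warning above says, `transforms.RandomLinkSplit` is the recommended replacement in newer PyG versions.

```python
import numpy as np

def split_edges(edges, val_ratio=0.05, test_ratio=0.1, seed=0):
    """Randomly split an (E, 2) array of undirected edges into train/val/test."""
    rng = np.random.default_rng(seed)
    perm = rng.permutation(len(edges))
    n_val = int(np.floor(val_ratio * len(edges)))
    n_test = int(np.floor(test_ratio * len(edges)))
    val = edges[perm[:n_val]]
    test = edges[perm[n_val:n_val + n_test]]
    train = edges[perm[n_val + n_test:]]
    return train, val, test

def sample_negative_edges(edges, num_nodes, k, seed=0):
    """Sample k node pairs (i < j) that do not appear in the edge list."""
    rng = np.random.default_rng(seed)
    existing = {(min(i, j), max(i, j)) for i, j in edges.tolist()}
    negatives = set()
    while len(negatives) < k:
        i, j = (int(v) for v in rng.integers(0, num_nodes, size=2))
        pair = (min(i, j), max(i, j))
        if i != j and pair not in existing:
            negatives.add(pair)
    return np.array(sorted(negatives))

# Toy example (hypothetical): a ring of 20 nodes.
ring = np.array([[i, (i + 1) % 20] for i in range(20)])
train_e, val_e, test_e = split_edges(ring)
neg_e = sample_negative_edges(ring, num_nodes=20, k=len(test_e))
print(len(train_e), len(val_e), len(test_e), len(neg_e))
```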

Graph Autoencoders (GAE)

We start with the Graph Autoencoder (GAE).

The encoder is built by stacking GCNConv layers.

class GCNEncoder(torch.nn.Module):
    def __init__(self, in_channels, out_channels):
        super(GCNEncoder, self).__init__()
        self.conv1 = torch_geometric.nn.GCNConv(
            in_channels, 4 * out_channels, cached=True
            ) 
        self.conv1b = torch_geometric.nn.GCNConv(
            4 * out_channels, 16 * out_channels, cached=True
            ) 
        self.conv1c = torch_geometric.nn.GCNConv(
            16 * out_channels, 32 * out_channels, cached=True
            ) 
        self.conv1d = torch_geometric.nn.GCNConv(
            32 * out_channels, 16 * out_channels, cached=True
            ) 
        self.conv1e = torch_geometric.nn.GCNConv(
            16 * out_channels, 4 * out_channels, cached=True
            ) 
        self.conv2 = torch_geometric.nn.GCNConv(
            4 * out_channels, out_channels, cached=True
            ) 

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index).relu()
        x = self.conv1b(x, edge_index).relu()
        x = self.conv1c(x, edge_index).relu()
        x = self.conv1d(x, edge_index).relu()
        x = self.conv1e(x, edge_index).relu()
        return self.conv2(x, edge_index)
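For intuition about what each stacked layer does, the following is a minimal NumPy sketch of a single GCN propagation step, using the standard rule of symmetric normalization with self-loops. It is only an illustration under simplifying assumptions: there is no bias term, no caching, and a dense adjacency matrix is used. `gcn_layer` and the toy arrays are names introduced here, not part of PyG.

```python
import numpy as np

def gcn_layer(x, edges, weight):
    """One GCN propagation step: D^{-1/2} (A + I) D^{-1/2} X W."""
    n = x.shape[0]
    adj = np.zeros((n, n))
    adj[edges[:, 0], edges[:, 1]] = 1.0
    adj[edges[:, 1], edges[:, 0]] = 1.0    # treat edges as undirected
    adj_hat = adj + np.eye(n)              # add self-loops
    deg = adj_hat.sum(axis=1)
    d_inv_sqrt = 1.0 / np.sqrt(deg)
    norm_adj = d_inv_sqrt[:, None] * adj_hat * d_inv_sqrt[None, :]
    return norm_adj @ x @ weight

# Toy path graph 0-1-2 with 2 input features and 3 output channels.
x_toy = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
edges_toy = np.array([[0, 1], [1, 2]])
w_toy = np.ones((2, 3))
h = np.maximum(gcn_layer(x_toy, edges_toy, w_toy), 0.0)  # ReLU, as in the encoder
print(h)
```

Each node's output mixes its own features with those of its neighbours, so stacking six layers as above lets information travel up to six hops.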

Here is the code for training and testing.

def train():
    model.train()
    optimizer.zero_grad()
    z = model.encode(x, train_pos_edge_index)
    loss = model.recon_loss(z, train_pos_edge_index)
    loss.backward()
    optimizer.step()
    return float(loss)

def test(pos_edge_index, neg_edge_index):
    model.eval()
    with torch.no_grad():
        z = model.encode(x, train_pos_edge_index)
    return model.test(z, pos_edge_index, neg_edge_index)
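The reconstruction loss used in `train()` can be sketched in NumPy as follows. As far as I understand it, `recon_loss` scores each node pair with the sigmoid of the inner product of their embeddings and adds the binary cross-entropy over the positive edges to that over sampled negative edges. In the sketch the negative edges are passed in explicitly, and the function and variable names are my own.

```python
import numpy as np

EPS = 1e-15  # small constant that keeps log() finite

def sigmoid(v):
    return 1.0 / (1.0 + np.exp(-v))

def recon_loss(z, pos_edges, neg_edges):
    """Binary cross-entropy over positive and negative node pairs."""
    pos_score = sigmoid((z[pos_edges[:, 0]] * z[pos_edges[:, 1]]).sum(axis=1))
    neg_score = sigmoid((z[neg_edges[:, 0]] * z[neg_edges[:, 1]]).sum(axis=1))
    pos_loss = -np.log(pos_score + EPS).mean()
    neg_loss = -np.log(1.0 - neg_score + EPS).mean()
    return pos_loss + neg_loss

# With all-zero (uninformative) embeddings every score is 0.5.
z_zero = np.zeros((4, 2))
pos_toy = np.array([[0, 1], [1, 2]])
neg_toy = np.array([[0, 3], [2, 3]])
loss_zero = recon_loss(z_zero, pos_toy, neg_toy)
print(loss_zero)
```

With uninformative embeddings the loss is 2 ln 2, about 1.386, which is a useful baseline when reading a training log.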

GCNEncoder is plugged in as the encoder of GAE, and the remaining settings are made here.

epochs = 500
out_channels = 16

num_features = dataset.num_features
model = torch_geometric.nn.GAE(GCNEncoder(num_features, out_channels))
model = model.to(device)

x = data.x.to(device)
train_pos_edge_index = data.train_pos_edge_index.to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

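Training runs as follows. Whenever the AP on the test edges improves, the current model is copied to best_model, and only those epochs are printed.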

import copy

loss_hist = []
auc_hist = []
ap_hist = []
best_score = None
for epoch in range(1, epochs + 1):
    loss = train()
    auc, ap = test(data.test_pos_edge_index, data.test_neg_edge_index)
    loss_hist.append(loss)
    auc_hist.append(auc)
    ap_hist.append(ap)
    if best_score is None or best_score < ap:
        best_score = ap
        best_model = copy.deepcopy(model)
        print('Epoch: {:03d}, AUC: {:.4f}, AP: {:.4f}, Loss: {}'.format(epoch, auc, ap, loss))
Epoch: 001, AUC: 0.3816, AP: 0.4461, Loss: 1.3883142471313477
Epoch: 002, AUC: 0.4190, AP: 0.5396, Loss: 28.98647689819336
Epoch: 003, AUC: 0.4468, AP: 0.5558, Loss: 2.6154823303222656
Epoch: 007, AUC: 0.5818, AP: 0.6184, Loss: 1.3535693883895874
Epoch: 008, AUC: 0.6011, AP: 0.6215, Loss: 1.3591862916946411
Epoch: 009, AUC: 0.6011, AP: 0.6249, Loss: 1.356323003768921
Epoch: 019, AUC: 0.6273, AP: 0.6257, Loss: 1.1663818359375
Epoch: 020, AUC: 0.6875, AP: 0.6798, Loss: 1.1575605869293213
Epoch: 021, AUC: 0.7307, AP: 0.7208, Loss: 1.151369571685791
Epoch: 022, AUC: 0.7616, AP: 0.7486, Loss: 1.1134968996047974
Epoch: 024, AUC: 0.7400, AP: 0.7662, Loss: 1.162852168083191
Epoch: 025, AUC: 0.7739, AP: 0.8188, Loss: 1.1479345560073853
Epoch: 026, AUC: 0.7971, AP: 0.8392, Loss: 1.088531494140625
Epoch: 028, AUC: 0.8187, AP: 0.8483, Loss: 1.163226842880249
Epoch: 029, AUC: 0.8241, AP: 0.8636, Loss: 1.1547914743423462
Epoch: 031, AUC: 0.8403, AP: 0.8695, Loss: 1.1620863676071167
Epoch: 047, AUC: 0.8557, AP: 0.8755, Loss: 1.0460742712020874
Epoch: 048, AUC: 0.8642, AP: 0.8834, Loss: 1.0296823978424072
Epoch: 054, AUC: 0.8819, AP: 0.8876, Loss: 0.9689010381698608
Epoch: 055, AUC: 0.8927, AP: 0.9002, Loss: 0.9880993366241455
Epoch: 056, AUC: 0.8966, AP: 0.9042, Loss: 0.9827735424041748
Epoch: 078, AUC: 0.8943, AP: 0.9057, Loss: 0.9343335628509521
Epoch: 079, AUC: 0.9005, AP: 0.9136, Loss: 0.9135234951972961
Epoch: 121, AUC: 0.8997, AP: 0.9142, Loss: 0.9891144037246704
Epoch: 128, AUC: 0.9043, AP: 0.9154, Loss: 0.9204936027526855
Epoch: 132, AUC: 0.9144, AP: 0.9266, Loss: 0.9302182197570801
Epoch: 147, AUC: 0.9151, AP: 0.9295, Loss: 0.8523568511009216
Epoch: 154, AUC: 0.9182, AP: 0.9295, Loss: 0.8805981278419495
Epoch: 157, AUC: 0.9136, AP: 0.9296, Loss: 0.8795582056045532
Epoch: 163, AUC: 0.9244, AP: 0.9356, Loss: 0.8179492354393005
Epoch: 164, AUC: 0.9252, AP: 0.9393, Loss: 0.8485832810401917
Epoch: 198, AUC: 0.9267, AP: 0.9437, Loss: 0.8306972980499268
Epoch: 237, AUC: 0.9190, AP: 0.9469, Loss: 0.7869285941123962
Epoch: 241, AUC: 0.9213, AP: 0.9482, Loss: 0.7720121145248413
Epoch: 243, AUC: 0.9236, AP: 0.9491, Loss: 0.7884184122085571
Epoch: 244, AUC: 0.9236, AP: 0.9499, Loss: 0.7504099607467651
Epoch: 247, AUC: 0.9298, AP: 0.9521, Loss: 0.7511246800422668
Epoch: 248, AUC: 0.9375, AP: 0.9549, Loss: 0.7275109887123108
Epoch: 250, AUC: 0.9414, AP: 0.9558, Loss: 0.7692357897758484
Epoch: 458, AUC: 0.9282, AP: 0.9560, Loss: 0.6997581720352173
Epoch: 459, AUC: 0.9298, AP: 0.9564, Loss: 0.7822049260139465
Epoch: 460, AUC: 0.9321, AP: 0.9574, Loss: 0.7706275582313538
Epoch: 461, AUC: 0.9321, AP: 0.9580, Loss: 0.7076061367988586
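To clarify what the two logged metrics measure, here is a small NumPy sketch of AUC and AP computed from scores of positive and negative node pairs. I believe `model.test` delegates to scikit-learn for these, so treat the code below as an illustration of the definitions (with a simplified treatment of ties in AP), not as the library's implementation. The scores are made-up toy values.

```python
import numpy as np

def auc_score(pos_scores, neg_scores):
    """Probability that a random positive pair outscores a random negative pair."""
    pos = np.asarray(pos_scores)[:, None]
    neg = np.asarray(neg_scores)[None, :]
    wins = (pos > neg).sum() + 0.5 * (pos == neg).sum()
    return wins / (pos.size * neg.size)

def average_precision(pos_scores, neg_scores):
    """Mean of the precision values taken at the rank of each positive pair."""
    scores = np.concatenate([pos_scores, neg_scores])
    labels = np.concatenate([np.ones(len(pos_scores)), np.zeros(len(neg_scores))])
    order = np.argsort(-scores, kind="stable")   # highest score first
    labels = labels[order]
    hits = np.cumsum(labels)
    precision = hits / np.arange(1, len(labels) + 1)
    return (precision * labels).sum() / labels.sum()

pos_toy = np.array([0.9, 0.6, 0.4])   # scores of true edges (toy values)
neg_toy = np.array([0.7, 0.3, 0.2])   # scores of non-edges (toy values)
auc_toy = auc_score(pos_toy, neg_toy)
ap_toy = average_precision(pos_toy, neg_toy)
print(auc_toy, ap_toy)
```

In this toy case 7 of the 9 positive/negative pairs are ordered correctly, so AUC is 7/9, and AP averages the precisions 1, 2/3 and 3/4.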

The training history looks like this.

import matplotlib.pyplot as plt

plt.title("GAE")
plt.plot(loss_hist, label="Loss")
plt.grid()
plt.legend()
plt.yscale('log')
plt.show()
plt.title("GAE")
plt.plot(auc_hist, label="AUC")
plt.plot(ap_hist, label="AP")
plt.grid()
plt.legend()
plt.show()

[Figure: GAE training loss (log scale)]

[Figure: GAE test AUC and AP]

Training worked, so next comes the prediction of edges. Using the z computed by the encoder, we compute a predicted adjacency matrix (or something like one).

z = best_model.encode(x, train_pos_edge_index)
prob_adj = z @ z.T
prob_adj = prob_adj - torch.diagonal(prob_adj)
prob_adj
tensor([[ 0.0000, -0.0693, -2.5038,  ..., -3.5280, -2.4522, -1.7510],
        [ 0.0406,  0.0000, -2.2706,  ..., -3.5723, -2.5015, -1.7879],
        [ 0.4745,  0.5979,  0.0000,  ..., -3.6429, -2.6418, -1.8696],
        ...,
        [-0.9027, -1.0568, -3.9960,  ...,  0.0000,  0.4723,  0.7083],
        [-0.9423, -1.1015, -4.1102,  ..., -0.6431,  0.0000,  0.2936],
        [-0.8921, -1.0387, -3.9890,  ..., -1.0580, -0.3573,  0.0000]],
       grad_fn=<SubBackward0>)
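One detail worth noting in the output above is that the matrix is not symmetric. The NumPy sketch below, on hypothetical random embeddings, shows why: subtracting the diagonal vector broadcasts along columns, so entry (i, j) becomes z_i·z_j minus z_j·z_j. For comparison it also computes the sigmoid of the inner products, which is the usual GAE decoder and stays symmetric. This is only meant to explain the numbers, not to replace the article's approach.

```python
import numpy as np

rng = np.random.default_rng(0)
z_toy = rng.normal(size=(4, 2))        # hypothetical embeddings: 4 nodes, 2 dims

scores = z_toy @ z_toy.T               # symmetric inner products
diag = np.diagonal(scores)

# Broadcasting subtracts diag[j] from every entry of column j,
# which zeroes the diagonal but makes the result asymmetric.
shifted = scores - diag

# The usual GAE decoder instead squashes the inner products with a sigmoid.
probs = 1.0 / (1.0 + np.exp(-scores))

print(np.allclose(shifted, shifted.T), np.allclose(probs, probs.T))
```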

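Node pairs whose value is at or above a threshold are treated as adjacent. A threshold of 0 might also work, but here it is chosen so that roughly as many entries are kept as there are edges in the original dataset.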

prob_adj_values = prob_adj.detach().cpu().numpy().flatten()
prob_adj_values.sort()
dataset = use_dataset()
threshold = prob_adj_values[-len(dataset.data.edge_attr)]
dataset.data.edge_index = (prob_adj >= threshold).nonzero(as_tuple=False).t()
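The thresholding step can be sketched on a small matrix as follows. This is a NumPy illustration with made-up scores, and `top_k_edges` is a hypothetical helper name. It takes the k-th largest entry as the threshold and keeps every entry at or above it.

```python
import numpy as np

def top_k_edges(score_matrix, k):
    """Return the threshold and the (row, col) pairs scoring among the k largest."""
    values = np.sort(score_matrix.flatten())
    threshold = values[-k]                       # k-th largest value
    rows, cols = np.nonzero(score_matrix >= threshold)
    return threshold, np.stack([rows, cols])

toy = np.array([[0.0, 0.9, -1.0],
                [0.2, 0.0, 0.5],
                [-0.3, 0.1, 0.0]])
thr, edge_index_toy = top_k_edges(toy, k=3)
print(thr)
print(edge_index_toy)
```

Two caveats follow from this construction: ties at the threshold can yield more than k entries, and because the matrix is treated as directed, (i, j) and (j, i) are selected independently. If the threshold falls to 0 or below, the zero diagonal entries are selected too and show up as self-loops.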

A network is then generated from the edges predicted this way.

import networkx as nx
import matplotlib.pyplot as plt

G = torch_geometric.utils.convert.to_networkx(dataset.data)
plt.figure(figsize=(12,12))
nx.draw_networkx(G, pos=dataset.layout(), with_labels=False, alpha=0.5, node_color=dataset.node_color())

[Figure: network rebuilt from the edges predicted by GAE]

...Hmm, it arguably retains some features of the original network, but there still seems to be room for improvement.

Variational Graph Autoencoders (VGAE)

Next is the Variational Graph Autoencoder (VGAE). As before, we build a VariationalGCNEncoder to plug in as the encoder.

class VariationalGCNEncoder(torch.nn.Module):
    def __init__(self, in_channels, out_channels):
        super(VariationalGCNEncoder, self).__init__()
        self.conv1 = torch_geometric.nn.GCNConv(
            in_channels, 4 * out_channels, cached=True
            ) 
        self.conv1b = torch_geometric.nn.GCNConv(
            4 * out_channels, 16 * out_channels, cached=True
            ) 
        self.conv1c = torch_geometric.nn.GCNConv(
            16 * out_channels, 32 * out_channels, cached=True
            ) 
        self.conv1d = torch_geometric.nn.GCNConv(
            32 * out_channels, 16 * out_channels, cached=True
            ) 
        self.conv1e = torch_geometric.nn.GCNConv(
            16 * out_channels, 4 * out_channels, cached=True
            ) 
        self.conv_mu = torch_geometric.nn.GCNConv(
            4 * out_channels, out_channels, cached=True
            )
        self.conv_logstd = torch_geometric.nn.GCNConv(
            4 * out_channels, out_channels, cached=True
            )

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index).relu()
        x = self.conv1b(x, edge_index).relu()
        x = self.conv1c(x, edge_index).relu()
        x = self.conv1d(x, edge_index).relu()
        x = self.conv1e(x, edge_index).relu()
        return self.conv_mu(x, edge_index), self.conv_logstd(x, edge_index)
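The encoder returns two tensors, a mean and a log standard deviation per node, because VGAE samples the embedding instead of computing it deterministically. The following is a minimal NumPy sketch of that sampling step (the reparameterization trick), under the assumption that the mean is returned unchanged at evaluation time. The function mirrors the idea only, and its name and signature are chosen for this illustration.

```python
import numpy as np

def reparametrize(mu, logstd, rng, training=True):
    """Sample z = mu + eps * exp(logstd) with eps ~ N(0, I); return mu at eval time."""
    if not training:
        return mu
    eps = rng.standard_normal(mu.shape)
    return mu + eps * np.exp(logstd)

rng = np.random.default_rng(0)
mu_toy = np.zeros((5, 2))        # hypothetical means for 5 nodes, 2 dims
logstd_toy = np.zeros((5, 2))    # log std 0, i.e. standard deviation 1
z_train = reparametrize(mu_toy, logstd_toy, rng, training=True)
z_eval = reparametrize(mu_toy, logstd_toy, rng, training=False)
print(z_train)
print(z_eval)
```

Writing the sample as a deterministic function of `mu`, `logstd` and external noise is what allows gradients to flow back into both output heads of the encoder.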

Here is the code for training and testing. The only difference from GAE is the KL term added to the loss.

def train():
    model.train()
    optimizer.zero_grad()
    z = model.encode(x, train_pos_edge_index)
    loss = model.recon_loss(z, train_pos_edge_index)
    loss = loss + (1 / data.num_nodes) * model.kl_loss()
    loss.backward()
    optimizer.step()
    return float(loss)

def test(pos_edge_index, neg_edge_index):
    model.eval()
    with torch.no_grad():
        z = model.encode(x, train_pos_edge_index)
    return model.test(z, pos_edge_index, neg_edge_index)
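The extra term in `train()` is the KL divergence between each node's Gaussian and a standard normal prior. A NumPy sketch of the closed form, as I understand `kl_loss` to compute it (summed over latent dimensions, averaged over nodes), is shown below with toy inputs. The names are introduced here for illustration.

```python
import numpy as np

def kl_loss(mu, logstd):
    """KL(N(mu, exp(logstd)^2) || N(0, I)), summed over dims, averaged over nodes."""
    per_node = -0.5 * np.sum(
        1.0 + 2.0 * logstd - mu ** 2 - np.exp(logstd) ** 2, axis=1
    )
    return per_node.mean()

# Matching the prior exactly gives zero divergence.
kl_zero = kl_loss(np.zeros((3, 2)), np.zeros((3, 2)))
# Shifting every mean to 1 (2 dims, unit variance) gives 0.5 * 2 = 1.
kl_shift = kl_loss(np.ones((3, 2)), np.zeros((3, 2)))
print(kl_zero, kl_shift)
```

In the training code above this term is scaled by `1 / data.num_nodes` before being added to the reconstruction loss, which keeps it from dominating on small graphs.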

VariationalGCNEncoder is plugged in as the encoder of VGAE, and the remaining settings are made here.

epochs = 500
out_channels = 16

num_features = dataset.num_features
model = torch_geometric.nn.VGAE(
    VariationalGCNEncoder(num_features, out_channels)
    )
model = model.to(device)

x = data.x.to(device)
train_pos_edge_index = data.train_pos_edge_index.to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

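Training runs as follows. Whenever the AP on the test edges improves, the current model is copied to best_model, and only those epochs are printed.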

import copy

loss_hist = []
auc_hist = []
ap_hist = []
best_score = None
for epoch in range(1, epochs + 1):
    loss = train()
    auc, ap = test(data.test_pos_edge_index, data.test_neg_edge_index)
    loss_hist.append(loss)
    auc_hist.append(auc)
    ap_hist.append(ap)
    if best_score is None or best_score < ap:
        best_score = ap
        best_model = copy.deepcopy(model)
        print('Epoch: {:03d}, AUC: {:.4f}, AP: {:.4f}, Loss: {}'.format(epoch, auc, ap, loss))
Epoch: 001, AUC: 0.3731, AP: 0.4432, Loss: 3.864459276199341
Epoch: 002, AUC: 0.4028, AP: 0.5502, Loss: 18.978899002075195
Epoch: 003, AUC: 0.4282, AP: 0.5533, Loss: 5.359278202056885
Epoch: 004, AUC: 0.4151, AP: 0.5543, Loss: 2.7068936824798584
Epoch: 005, AUC: 0.4082, AP: 0.5561, Loss: 3.064584493637085
Epoch: 015, AUC: 0.4846, AP: 0.5571, Loss: 1.5492334365844727
Epoch: 016, AUC: 0.4715, AP: 0.5571, Loss: 1.5692222118377686
Epoch: 017, AUC: 0.4946, AP: 0.5910, Loss: 1.55221426486969
Epoch: 018, AUC: 0.5679, AP: 0.6442, Loss: 1.5478882789611816
Epoch: 019, AUC: 0.6620, AP: 0.7120, Loss: 1.5624383687973022
Epoch: 020, AUC: 0.7137, AP: 0.7338, Loss: 1.5453048944473267
Epoch: 021, AUC: 0.7353, AP: 0.7453, Loss: 1.506240725517273
Epoch: 022, AUC: 0.7623, AP: 0.7691, Loss: 1.4864426851272583
Epoch: 023, AUC: 0.7600, AP: 0.7808, Loss: 1.468309760093689
Epoch: 024, AUC: 0.7600, AP: 0.7853, Loss: 1.4616118669509888
Epoch: 025, AUC: 0.7454, AP: 0.7881, Loss: 1.4413743019104004
Epoch: 029, AUC: 0.7585, AP: 0.8036, Loss: 1.4201958179473877
Epoch: 030, AUC: 0.7778, AP: 0.8204, Loss: 1.2653322219848633
Epoch: 031, AUC: 0.7878, AP: 0.8252, Loss: 1.299864649772644
Epoch: 033, AUC: 0.7994, AP: 0.8289, Loss: 1.2688895463943481
Epoch: 051, AUC: 0.8040, AP: 0.8343, Loss: 1.2111725807189941
Epoch: 052, AUC: 0.8056, AP: 0.8437, Loss: 1.2464674711227417
Epoch: 053, AUC: 0.8102, AP: 0.8484, Loss: 1.22722327709198
Epoch: 054, AUC: 0.8318, AP: 0.8598, Loss: 1.2544875144958496
Epoch: 055, AUC: 0.8511, AP: 0.8734, Loss: 1.2202423810958862
Epoch: 056, AUC: 0.8627, AP: 0.8793, Loss: 1.2152200937271118
Epoch: 057, AUC: 0.8742, AP: 0.8829, Loss: 1.2272169589996338
Epoch: 062, AUC: 0.8711, AP: 0.8852, Loss: 1.1996705532073975
Epoch: 077, AUC: 0.8881, AP: 0.8916, Loss: 1.1212635040283203
Epoch: 078, AUC: 0.8943, AP: 0.8941, Loss: 1.137722373008728
Epoch: 085, AUC: 0.8989, AP: 0.9014, Loss: 1.1602625846862793
Epoch: 094, AUC: 0.8935, AP: 0.9029, Loss: 1.0816481113433838
Epoch: 099, AUC: 0.9020, AP: 0.9099, Loss: 1.131610631942749
Epoch: 100, AUC: 0.9028, AP: 0.9101, Loss: 1.047181487083435
Epoch: 107, AUC: 0.9020, AP: 0.9104, Loss: 1.1225495338439941
Epoch: 112, AUC: 0.9005, AP: 0.9166, Loss: 1.0279103517532349
Epoch: 116, AUC: 0.9020, AP: 0.9196, Loss: 1.0380473136901855
Epoch: 120, AUC: 0.9059, AP: 0.9242, Loss: 1.1030912399291992
Epoch: 125, AUC: 0.9051, AP: 0.9244, Loss: 0.9978398084640503
Epoch: 126, AUC: 0.9090, AP: 0.9296, Loss: 1.039618730545044
Epoch: 131, AUC: 0.9174, AP: 0.9303, Loss: 1.0598678588867188
Epoch: 137, AUC: 0.9205, AP: 0.9328, Loss: 1.0251210927963257
Epoch: 141, AUC: 0.9221, AP: 0.9354, Loss: 1.0087817907333374
Epoch: 145, AUC: 0.9244, AP: 0.9402, Loss: 1.054215908050537
Epoch: 147, AUC: 0.9290, AP: 0.9415, Loss: 1.0088623762130737
Epoch: 148, AUC: 0.9336, AP: 0.9455, Loss: 1.0282702445983887
Epoch: 198, AUC: 0.9321, AP: 0.9457, Loss: 1.033915400505066
Epoch: 199, AUC: 0.9344, AP: 0.9482, Loss: 0.9904683828353882
Epoch: 283, AUC: 0.9329, AP: 0.9490, Loss: 0.9304066896438599
Epoch: 311, AUC: 0.9360, AP: 0.9509, Loss: 0.9486202597618103
Epoch: 312, AUC: 0.9406, AP: 0.9549, Loss: 0.9481338262557983
Epoch: 358, AUC: 0.9406, AP: 0.9560, Loss: 1.0044310092926025

The training history looks like this.

import matplotlib.pyplot as plt

plt.title("VGAE")
plt.plot(loss_hist, label="Loss")
plt.grid()
plt.legend()
plt.yscale('log')
plt.show()
plt.title("VGAE")
plt.plot(auc_hist, label="AUC")
plt.plot(ap_hist, label="AP")
plt.grid()
plt.legend()
plt.show()

[Figure: VGAE training loss (log scale)]

[Figure: VGAE test AUC and AP]

Training worked, so next comes the prediction of edges. Using the z computed by the encoder, we compute a predicted adjacency matrix (or something like one).

z = best_model.encode(x, train_pos_edge_index)
prob_adj = z @ z.T
prob_adj = prob_adj - torch.diagonal(prob_adj)
prob_adj
tensor([[ 0.0000, -0.2043, -1.2594,  ..., -2.0548, -1.1037, -0.7604],
        [ 0.1644,  0.0000, -1.0084,  ..., -2.1810, -1.2070, -0.8390],
        [ 0.2415,  0.1238,  0.0000,  ..., -2.0992, -1.3040, -0.9518],
        ...,
        [-1.2879, -1.7827, -2.8331,  ...,  0.0000, -0.1195, -0.0661],
        [-1.2505, -1.7225, -2.9517,  ..., -1.0332,  0.0000,  0.0888],
        [-1.1217, -1.5690, -2.8140,  ..., -1.1944, -0.1257,  0.0000]],
       grad_fn=<SubBackward0>)

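Node pairs whose value is at or above a threshold are treated as adjacent. A threshold of 0 might also work, but here it is chosen so that roughly as many entries are kept as there are edges in the original dataset.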

prob_adj_values = prob_adj.detach().cpu().numpy().flatten()
prob_adj_values.sort()
dataset = use_dataset()
threshold = prob_adj_values[-len(dataset.data.edge_attr)]
dataset.data.edge_index = (prob_adj >= threshold).nonzero(as_tuple=False).t()

A network is then generated from the edges predicted this way.

import networkx as nx
import matplotlib.pyplot as plt

G = torch_geometric.utils.convert.to_networkx(dataset.data)
plt.figure(figsize=(12,12))
nx.draw_networkx(G, pos=dataset.layout(), with_labels=False, alpha=0.5, node_color=dataset.node_color())

[Figure: network rebuilt from the edges predicted by VGAE]

Hmm... it is hard to say whether GAE or VGAE is better.
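Since judging by eye is inconclusive, one option would be to compare the predicted edge set with the original one numerically. The sketch below is plain Python with hypothetical toy edge lists. `edge_overlap` is a helper name introduced here, and it treats edges as undirected and ignores self-loops. Because the encoder saw the training edges, such numbers would describe how well the network is reproduced, not held-out performance.

```python
def edge_overlap(true_edges, pred_edges):
    """Precision, recall, and Jaccard index between two undirected edge sets."""
    def canon(pairs):
        # Store each undirected edge once as (min, max) and drop self-loops.
        return {(min(i, j), max(i, j)) for i, j in pairs if i != j}

    t, p = canon(true_edges), canon(pred_edges)
    common = len(t & p)
    precision = common / len(p) if p else 0.0
    recall = common / len(t) if t else 0.0
    union = len(t | p)
    jaccard = common / union if union else 0.0
    return precision, recall, jaccard

# Toy example (hypothetical edge lists).
true_toy = [(0, 1), (1, 2), (2, 3)]
pred_toy = [(1, 0), (2, 1), (0, 3), (3, 3)]
precision, recall, jaccard = edge_overlap(true_toy, pred_toy)
print(precision, recall, jaccard)
```

With the tensors used in this article, the edge lists could be obtained with something like `edge_index.t().tolist()` for both the original and the predicted `edge_index`.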

Switching the dataset

I ran the same computation with GridDataset.

GAE

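[Figure: GAE training loss on GridDataset (log scale)]

[Figure: GAE test AUC and AP on GridDataset]

[Figure: GridDataset network rebuilt from the edges predicted by GAE]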

VGAE

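[Figure: VGAE training loss on GridDataset (log scale)]

[Figure: VGAE test AUC and AP on GridDataset]

[Figure: GridDataset network rebuilt from the edges predicted by VGAE]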

GridDataset turned out to be harder than I expected. It was also a good reminder that, unsurprisingly, changing the dataset changes the character of the generated network.
