LoginSignup
7
5

More than 3 years have passed since last update.

Graph Convolutional Networks の実装メモ

Posted at 2019-07-14

内容

  • PyTorch で Graph Convolutional Networks を実装する。
  • 簡単な具体例で動作確認をする。

参照論文

SEMI-SUPERVISED CLASSIFICATION WITH GRAPH CONVOLUTIONAL NETWORKS

ライブラリー

import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np
import math

import networkx as nx

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

データ

Zachary’s Karate Club

# Zachary's karate club graph bundled with NetworkX; every node carries a
# 'club' attribute naming the faction it belongs to.
G = nx.karate_club_graph()

pos = nx.spring_layout(G)
# Colour each node by faction: red for 'Mr. Hi', blue for 'Officer'.
# G.nodes[...] is used because G.node[...] was removed in NetworkX 2.4.
# The dict lookup raises KeyError on an unexpected club instead of silently
# producing a colour list shorter than the node list.
club_to_color = {'Mr. Hi': 'r', 'Officer': 'b'}
color = [club_to_color[G.nodes[node]['club']] for node in G.nodes]

plt.figure(figsize=(5, 5))
nx.draw_networkx(G, pos=pos, node_size=200, node_color=color, with_labels=True)

（図：spring レイアウトで描画した空手クラブのグラフ。赤は 'Mr. Hi'、青は 'Officer' に所属するノード）

前処理

# Unweighted adjacency matrix A, with rows/columns in G.nodes order.
# weight=None counts every edge as 1, even if the graph stores edge weights.
A = nx.to_numpy_array(G, weight=None)
num_nodes = G.number_of_nodes()

# Renormalization trick from the reference paper:
#   A_tilde = A + I,  D_tilde_ii = sum_j A_tilde_ij,
#   A_hat   = D_tilde^{-1/2} @ A_tilde @ D_tilde^{-1/2}
# The degrees must come from A_tilde (i.e. D + I), not from A alone.
# Because of the self-loops every degree is >= 1, so the inverse sqrt is safe.
A_tilde = A + np.identity(num_nodes)
D_tilde = np.diag(A_tilde.sum(axis=1))
D_tilde_inv_sqrt = np.diag(1.0 / np.sqrt(np.diag(D_tilde)))
A_hat = D_tilde_inv_sqrt @ A_tilde @ D_tilde_inv_sqrt

# Featureless input: each node's feature vector is its one-hot identity row.
# (np.float was removed in NumPy 1.24; the default float64 dtype is used.)
X = np.identity(num_nodes)

# (*)

# Class labels: 0 for 'Mr. Hi', 1 for 'Officer'. The dict lookup raises
# KeyError on an unexpected club instead of silently producing a y that is
# shorter than, and misaligned with, the node list.
club_to_label = {'Mr. Hi': 0, 'Officer': 1}
y = [club_to_label[G.nodes[node]['club']] for node in G.nodes]

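(*) 定義の確認と他の定式化との比較をすること。参照論文での定義は次のとおり。

$$\tilde{A} = A + I_N,\qquad \tilde{D}_{ii} = \sum_j \tilde{A}_{ij},\qquad \hat{A} = \tilde{D}^{-\frac{1}{2}}\,\tilde{A}\,\tilde{D}^{-\frac{1}{2}}$$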

学習

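参照論文より引用。各層の伝播則：

$$H^{(l+1)} = \sigma\!\left(\tilde{D}^{-\frac{1}{2}}\,\tilde{A}\,\tilde{D}^{-\frac{1}{2}}\,H^{(l)}\,W^{(l)}\right)$$

2 層のモデル：

$$Z = f(X, A) = \mathrm{softmax}\!\left(\hat{A}\;\mathrm{ReLU}\!\left(\hat{A} X W^{(0)}\right) W^{(1)}\right)$$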

class GCN(nn.Module):
    """Two-layer graph convolutional network.

    The forward pass computes, row-wise,
        log_softmax(A_hat @ relu(A_hat @ X @ W_0) @ W_1).

    Args:
        A_hat: adjacency matrix, presumably the normalized one built during
            preprocessing. It is kept as ``self.A_hat`` but is NOT read by
            ``forward``, which takes its own ``A_hat`` argument.
        num_feat: input feature dimension (rows of ``W_0``).
        num_hidden: hidden-layer width (columns of ``W_0``, rows of ``W_1``).
        num_class: number of output classes (columns of ``W_1``).
    """

    def __init__(self, A_hat, num_feat, num_hidden, num_class):
        super().__init__()
        self.num_feat = num_feat
        self.num_hidden = num_hidden
        self.num_class = num_class
        self.A_hat = A_hat
        # Allocated uninitialized; real values come from reset_parameters().
        self.W_0 = nn.Parameter(torch.empty(num_feat, num_hidden))
        self.W_1 = nn.Parameter(torch.empty(num_hidden, num_class))
        self.reset_parameters()

    def reset_parameters(self):
        """Redraw W_0, then W_1, from U(-b, b) with b = 1 / sqrt(size(1))."""
        for weight in (self.W_0, self.W_1):
            bound = 1. / math.sqrt(weight.size(1))
            weight.data.uniform_(-bound, bound)

    def forward(self, X, A_hat):
        """Return per-row log-probabilities over ``num_class`` classes."""
        hidden = F.relu(A_hat.mm(X).mm(self.W_0))
        logits = A_hat.mm(hidden).mm(self.W_1)
        return F.log_softmax(logits, dim=1)

# Hyper-parameters: the input width equals the node count (one column per
# node in the feature matrix), a 10-unit hidden layer, and two classes.
num_feat = G.number_of_nodes()
num_hidden = 10
num_class = 2

model = GCN(A_hat, num_feat, num_hidden, num_class).to(device)

# GCN.forward returns log-probabilities, so NLLLoss gives the cross-entropy.
# Adam is used with PyTorch's default hyper-parameters.
loss_fun = nn.NLLLoss()
optimizer = optim.Adam(model.parameters())

# Copy the NumPy inputs into tensors on the training device.
A_hat_tensor, X_tensor = (torch.Tensor(arr).to(device) for arr in (A_hat, X))
y_tensor = torch.LongTensor(y).to(device)

# Per-epoch training curves, appended to by the training loop.
loss_hist, acc_hist = [], []

# Full-batch training: each epoch feeds the whole graph through the model.
for epoch in range(500):
    step = epoch + 1
    model.train()
    model.zero_grad()

    output = model(X_tensor, A_hat_tensor)
    loss = loss_fun(output, y_tensor)
    loss_value = loss.item()
    loss_hist.append(loss_value)

    # Training accuracy of this epoch's predictions, taken before the update.
    is_correct = torch.eq(torch.argmax(output, dim=1), y_tensor)
    acc = torch.mean(is_correct.type(torch.DoubleTensor)).numpy()
    acc_hist.append(acc)

    if step % 100 == 0:
        print('{}  Loss: {:.4f}  Acc: {:.4f}'.format(step, loss_value, acc))

    loss.backward()
    optimizer.step()

#for p in model.parameters():
#  print(p)   # 学習後の値の確認

（出力：100 エポックごとの損失と正解率のログ）

結果の可視化

損失の推移

# Training loss recorded at each epoch.
fig, ax = plt.subplots(figsize=(5, 3))
ax.plot(loss_hist)
ax.grid(True)
plt.show()

（図：エポックごとの損失の推移）

正解率の推移

# Training accuracy recorded at each epoch.
fig, ax = plt.subplots(figsize=(5, 3))
ax.plot(acc_hist)
ax.grid(True)
plt.show()

（図：エポックごとの正解率の推移）

7
5
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
7
5