この論文を読んだ理由
- この論文の手法を取り入れた論文を数多く見かけたため.
読んだところ
- 全体
解いている課題
- GNNにおける集約において,近傍ノードがグラフ構造に強く依存してしまっている問題.具体的には,中心のノードは近傍が多く,端のノードは近傍が少ない.
- ノードごとに適当的に近傍の範囲を定めたい.
提案手法のアプローチ
各層の出力を集めてきた後,Concat/Max-pooling/LSTM-attention 演算を行う.
-
Concat...この場合は適応的ではない.Concatのあと線形層に通すことで全てのホップ数を考慮することが出来る.
-
Max-pooling...各層の出力のうち,一番情報が多い要素(値が大きい要素)を取ることで,適応的に集約するノードを選択できる.
-
LSTM-attention...双方向のLSTMで各層の重要度を計算し,重み付き和で集約を行う.重要なホップ数を定め,適応的に集約するノードを選択できる.
実装例
2層のGCNにJumping Knowledgeを取り入れたモデルです.
参考:https://github.com/pyg-team/pytorch_geometric/blob/master/benchmark/kernel/gcn.py
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, JumpingKnowledge
class GCNWithJK(torch.nn.Module):
def __init__(self, in_channels, hidden_channels, out_channels, mode='cat'):
super().__init__()
self.conv1 = GCNConv(in_channels, hidden_channels)
self.conv2 = GCNConv(hidden_channels, hidden_channels)
self.jump = JumpingKnowledge(mode)
if mode == 'cat':
# concatするので次元がhidden_channelsの2倍になる.
self.lin = Linear(hidden_channels * 2, out_channels)
else:
self.lin = Linear(hidden_channels, out_channels)
def reset_parameters(self):
self.conv1.reset_parameters()
self.conv2.reset_parameters()
self.jump.reset_parameters()
self.lin.reset_parameters()
def forward(self, x, edge_index):
xs = [] # 各層の出力を格納する
# 一層目
x = self.conv1(x, edge_index)
x = F.relu(x)
xs += [x]
# 二層目
x = self.conv2(x, edge_index)
x = F.relu(x)
xs += [x]
# 各層からの出力を集約
x = self.jump(xs)
# 最終層の次元へ変換
x = self.lin(x)
return F.log_softmax(x, dim=-1)