14
7

More than 5 years have passed since last update.

Fast Graph Representation Learning with PyTorch Geometric

Last updated at Posted at 2019-05-22

非公開(限定共有)にしていた輪講用メモを訳あって公開したものです.
現時点ではほとんどただの論文翻訳になっていますが,気が向いたら書き足します.

書誌情報他

概要・背景

グラフや点群,多様体といった不規則構造データ向けの深層学習ライブラリ PyTorch Geometric (PyG) の紹介.
専用の CUDA カーネルと効率的なミニバッチ管理によってスパース GPU 高速化を活用し,高スループットを実現.

GNNs ではかなりスパースで不規則でサイズの決まっていないデータを扱うため,今までは GPU での高スループットを出しづらく,実装が難しかった.

他の PyTorch 向けライブラリに Deep Graph Library があるが,記事投稿時点では PyG の方が注目されている模様(Star 数 2100 vs 3700).

特徴

グラフの扱い

ノード数を $N$,ノードの特徴次元数を $F$,エッジ数を $E$,エッジの特徴次元数を $D$ として

  • グラフ: $\mathcal{G} = (\boldsymbol{X}, (\boldsymbol{I}, \boldsymbol{E}))$
  • ノード特徴行列: $\boldsymbol{X} \in \mathbb{R}^{N\times F}$
  • エッジインデックス(COOrdinate フォーマット): $\boldsymbol{I} \in \mathbb{N}^{2\times E}$
  • エッジ特徴行列(オプション): $\boldsymbol{E}\in\mathbb{R}^{E\times D}$

のようにグラフを扱う.
後のコード例のコメントでも説明される.

Neighborhood Aggregation

Neighborhood aggregation or message passing の手法を,gather/scatter を用いて計算.
eq.png
gather_scatter.png

疎行列を用いた実装に対して,次数の低いグラフや non-coalesced 入力において利点があるらしい.
加えて,aggregation で central node や複数次エッジ情報を使えるようになるとのこと.

PyG が提供する MessagePassing インターフェースにより,ユーザは以下の3点に集中できる.

  • message メソッド($\phi$ に対応)の作成
  • update メソッド($\gamma$ に対応)の作成
  • aggregation 手法の選択

たとえばコードを次のように書ける(GitHub の README から拝借).

import torch
from torch.nn import Sequential as Seq, Linear as Lin, ReLU
from torch_geometric.nn import MessagePassing

class EdgeConv(MessagePassing):
    def __init__(self, F_in, F_out):
        super(EdgeConv, self).__init__(aggr='max')  # "Max" aggregation.
        self.mlp = Seq(Lin(2 * F_in, F_out), ReLU(), Lin(F_out, F_out))

    def forward(self, x, edge_index):
        # x has shape [N, F_in]
        # edge_index has shape [2, E]
        return self.propagate(edge_index, x=x)  # shape [N, F_out]

    def message(self, x_i, x_j):
        # x_i has shape [E, F_in]
        # x_j has shape [E, F_in]
        edge_features = torch.cat([x_i, x_j - x_i], dim=1)  # shape [E, 2 * F_in]
        return self.mlp(edge_features)  # shape [E, F_out]

近年提案された neighborhood aggregation functions は,ユーザが作成しなくとも既に PyG に組み込み済み.

  • GCN,SGC,GraphSAGE,GAT,APPNP,などなど

他に,点群や多様体,複数次エッジ特徴をもつデータ向けに

  • PointCNN,MoNet,などなど

さらに high-level 実装として

  • autoencoding graphs,aggregating jumping knowledge,などなど

充実している.

Mini-batch Handling

複数のグラフインスタンス(サイズが異なってもよい)に対応.
これは自動的に

  • ブロック対角な隣接行列を作成
  • ノード次元方向に特徴行列を concat

することで実現.
接続のないグラフ間ではメッセージが交換されないので,neighborhood aggregation を修正することなしに複数インスタンスに適用可能.

実装の正しさの評価

ライブラリ実装がまともであるか確認するため,以下の3つの問題について各種手法の実験を行い,元論文の結果と比較している(省略).

  • 半教師ありノード分類
  • グラフ分類
  • 点群分類

実行時間の計測

種々のデータセット・モデルについて,シングル GPU での訓練時間(200 epoch)を計測.
Deep Graph Library(Degree Bucketing 利用)と比べると,最大40倍ほど速い.
Deep Graph Library(gather/scatter 利用)と比べるとあまり変わらない……
ただし,GAT についてはスパース softmax カーネルを最適化しているため,最大7倍ほど高速.
table4.png

14
7
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
14
7