非公開(限定共有)にしていた輪講用メモを訳あって公開したものです.
現時点ではほとんどただの論文翻訳になっていますが,気が向いたら書き足します.
書誌情報他
概要・背景
グラフや点群,多様体といった不規則構造データ向けの深層学習ライブラリ PyTorch Geometric (PyG) の紹介.
専用の CUDA カーネルと効率的なミニバッチ管理によってスパース GPU 高速化を活用し,高スループットを実現.
GNNs ではかなりスパースで不規則でサイズの決まっていないデータを扱うため,今までは GPU での高スループットを出しづらく,実装が難しかった.
他の PyTorch 向けライブラリに Deep Graph Library があるが,記事投稿時点では PyG の方が注目されている模様(Star 数 2100 vs 3700).
特徴
グラフの扱い
ノード数を $N$,ノードの特徴次元数を $F$,エッジ数を $E$,エッジの特徴次元数を $D$ として
- グラフ: $\mathcal{G} = (\boldsymbol{X}, (\boldsymbol{I}, \boldsymbol{E}))$
- ノード特徴行列: $\boldsymbol{X} \in \mathbb{R}^{N\times F}$
- エッジインデックス(COOrdinate フォーマット): $\boldsymbol{I} \in \mathbb{N}^{2\times E}$
- エッジ特徴行列(オプション): $\boldsymbol{E}\in\mathbb{R}^{E\times D}$
のようにグラフを扱う.
後のコード例のコメントでも説明される.
Neighborhood Aggregation
Neighborhood aggregation or message passing の手法を,gather/scatter を用いて計算.
疎行列を用いた実装に対して,次数の低いグラフや non-coalesced 入力において利点があるらしい.
加えて,aggregation で central node や複数次エッジ情報を使えるようになるとのこと.
PyG が提供する MessagePassing
インターフェースにより,ユーザは以下の3点に集中できる.
-
message
メソッド($\phi$ に対応)の作成 -
update
メソッド($\gamma$ に対応)の作成 - aggregation 手法の選択
たとえばコードを次のように書ける(GitHub の README から拝借).
import torch
from torch.nn import Sequential as Seq, Linear as Lin, ReLU
from torch_geometric.nn import MessagePassing
class EdgeConv(MessagePassing):
def __init__(self, F_in, F_out):
super(EdgeConv, self).__init__(aggr='max') # "Max" aggregation.
self.mlp = Seq(Lin(2 * F_in, F_out), ReLU(), Lin(F_out, F_out))
def forward(self, x, edge_index):
# x has shape [N, F_in]
# edge_index has shape [2, E]
return self.propagate(edge_index, x=x) # shape [N, F_out]
def message(self, x_i, x_j):
# x_i has shape [E, F_in]
# x_j has shape [E, F_in]
edge_features = torch.cat([x_i, x_j - x_i], dim=1) # shape [E, 2 * F_in]
return self.mlp(edge_features) # shape [E, F_out]
近年提案された neighborhood aggregation functions は,ユーザが作成しなくとも既に PyG に組み込み済み.
- GCN,SGC,GraphSAGE,GAT,APPNP,などなど
他に,点群や多様体,複数次エッジ特徴をもつデータ向けに
- PointCNN,MoNet,などなど
さらに high-level 実装として
- autoencoding graphs,aggregating jumping knowledge,などなど
充実している.
Mini-batch Handling
複数のグラフインスタンス(サイズが異なってもよい)に対応.
これは自動的に
- ブロック対角な隣接行列を作成
- ノード次元方向に特徴行列を concat
することで実現.
接続のないグラフ間ではメッセージが交換されないので,neighborhood aggregation を修正することなしに複数インスタンスに適用可能.
実装の正しさの評価
ライブラリ実装がまともであるか確認するため,以下の3つの問題について各種手法の実験を行い,元論文の結果と比較している(省略).
- 半教師ありノード分類
- グラフ分類
- 点群分類
実行時間の計測
種々のデータセット・モデルについて,シングル GPU での訓練時間(200 epoch)を計測.
Deep Graph Library(Degree Bucketing 利用)と比べると,最大40倍ほど速い.
Deep Graph Library(gather/scatter 利用)と比べるとあまり変わらない……
ただし,GAT についてはスパース softmax カーネルを最適化しているため,最大7倍ほど高速.