グラフ向けの深層学習ライブラリDeep Graph Library(DGL)の基本的な使い方について紹介します。公式ドキュメントに事例やAPIの説明が詳細に載っていたりチュートリアルも豊富にありますが、DGLの一番基本的な動作(だと個人的に思っている)ノードの特徴量のmessageとreduceという2つの処理について、丁寧に説明している記事がなかったので説明してみます。
Deep Graph Library (DGL)とは?
New York UniversityとAWSが開発しているPytorch-basedの(?)グラフと対象としたDeep Learningのライブラリです。
画像や言語など従来よく研究されているデータ構造ではTensorFlow, Pytorch, Chainerなど有名なライブラリがあり、CNNやRNNなどが1つの関数(公式ではbuilding-blocksと言っている)になっていて、それを組み合わせることでモデルを作ることができます。しかし、グラフデータを対象としたライブラリではまだ決定版と言えるものがなく、DGLはその有力候補だと思います。
基本的な考え方
DGLの基本的な流れは以下の通りです。
- グラフの各ノードに特徴量を割り当てる
- あるノードから隣接ノードに1で割り当てた特徴量を送る(send)
- 2で隣接ノードから送られてきた特徴量を集約する(recv)
ここの2,3が公式ドキュメントやチュートリアルで丁寧な説明が少し足りないかなと思いました。唯一書いてあるのはPagerankのチュートリアルの真ん中くらいにある図くらいだと思います。以下、基本的な流れの詳細を説明します。説明ではDGL at a Glanceという一番最初のチュートリアルをベースに、さらに簡単にしたグラフ構造で説明します。コードもこのチュートリアルをベースにしています。
事前準備
公式チュートリアルはグラフ分野で有名なKarateデータでやっていますが、もっと簡単に5つのノードで構成されるグラフを作ります。
import dgl
def build_sample_graph():
g = dgl.DGLGraph()
g.add_nodes(5)
edge_list = [(0, 1), (0, 2), (0, 3), (0, 4), (1, 2), (2, 3)]
src, dst = tuple(zip(*edge_list))
# 無向グラフにするためsrcとdstを入れ替えて両方の向きにエッジを設定
g.add_edges(src, dst)
g.add_edges(dst, src)
return g
G = build_sample_graph()
可視化するとこんな感じです。
グラフの各ノードに特徴量を割り当てる
ここでは簡単に単位行列を各ノードの特徴量にします。ndataはノードの特徴量で、ここでは名前をhにしています。
G.ndata['h'] = torch.eye(5)
隣接ノードに特徴量を送る(send)
ソースとなるノードの特徴量('h')を隣接ノード(送り先では'msg'という名前で受信される)に送るコードは以下のように書きます。
def gcn_message(edges):
return {'msg' : edges.src['h']}
G.send(G.edges(), gcn_message)
この時点では、グラフ(特徴量)は以下のようになっています。図の黒字は元々の特徴量、赤字はノード0から隣接ノードに送った特徴量、同様に黄色はノード1が送ったもの、緑字はノード2が送ったもの、青字はノード3が送ったもの、紫字はノード4が送ったものです。
隣接ノードから送られてきた特徴量を集約する(recv)
最後に以下のようにして隣接ノードから送られてきた特徴量を集約します。送られてきた特徴量はnodeのmailboxというpropertyにあります。message処理で名前を'msg'としていたのでnodes.mailbox['msg']でデータにアクセスできます。
ここでは隣接ノードの特徴量の総和を新たにそのノードの特徴量('h')にセットしています。
def gcn_reduce(nodes):
return {'h' : torch.sum(nodes.mailbox['msg'], dim=1)}
G.recv(G.nodes(), gcn_reduce)
重みの学習
重みの学習はPytorchの他のモデルと同様に以下のようにやればできます。
optimizer.zero_grad()
loss.backward()
optimizer.step()
ここまでの全体のコードはこちらにあげました。
https://gist.github.com/k1ochiai/cd0279ca79dd74e91a2b5e1187928adb
今後の発展
Graph Convolutional NetworkやGraphSAGEなど最新のモデルを実装するには、この記事で説明したgcn_message, gcn_reduceという2つの関数を各手法になるように実装すればできます。
ちなみにDGLのリポジトリに結構最新のモデルも実装されているので使えそうです。例えばICLR 2019で発表された「How Powerful are Graph Neural Networks?」も実装があります(リンク)。