はじめに
GraphNets というライブラリが DeepMind からリリースされています。
このライブラリは この論文 の実装として公開されたもので、
端的に言うと 「ノード、エッジ、Globalに任意の属性(≒任意のtensor)を持ったグラフを入力にし、 構造は同じだが属性が更新されたグラフ を出力する」 というネットワークブロックを提供するものです。
何が面白そうかというと、グラフという知識表現として非常に表現力の高いデータ構造を扱うことができるということです。
以前から、グラフ畳み込みニューラルネットワークというものが(やや)注目されていましたが、それの汎用的なデータフォーマットを定めて、扱いやすくしてくれたような位置づけだと思います。
とはいうものの、現状ほとんどドキュメントも無く、デモのソースコードを見ても中々使い方がわかりません。
試行錯誤の末、なんとか GraphNetを学習させる道筋が見えてきたので、書いておこうと思います。
今回のお題は、「とにかくあるノードを5と出力するGraphNetを作る」です。
環境
- Python 3.6
- graph-nets: 1.0.0
- 何かよくわからないけど、
pip install -U dm-sonnet==1.23
を入れておかないと動かなかった
- 何かよくわからないけど、
graph_nets の使い方概要
いきなり見ると呪文のようですが、ライブラリ付属のデモコードや今回の私のコードを読む時に役に立ちそうなことを書いておきます。
GraphsTuple
大事な型として、 GraphsTuple
というのがあって、 これはGraph構造を表現している Tensorの集まりです。
GraphsTuple は、 gn.utils_tf.data_dicts_to_graphs_tuple()
のような関数で生成できます。
この data_dicts は 論文 の Box 3: Our definition of “graph”
を意味しているようです。
ざっくり言うと、以下のようなものです。
- n_node, n_edge: ノードとエッジの数
- nodes: ノードの属性(attritbues)のリスト(や np.array)
- edges: エッジの属性のリスト(やnp.array)
- globals: Globalの属性のリスト(やnp.array)
- receivers, senders: エッジがどこからどこに向かっているかを ノードのindexで表す
GraphNetたち (modules.*)
gn.modules.*
の GraphNetwork
や GraphIndependent
がGraphNetを表していて、 GraphsTuple
を受け取って、GraphsTuple
を出力します。
ここに学習用の重みとかが詰まっている(正確には重みは引数で渡す)ので、 入力 → <GraphNet> → Output -> Loss
という流れを作って、 Lossを Minimize することで学習が進行します。
出力される GraphsTuple
もTensorなので、session.run()
などで計算して値を取り出すことができます。
GraphNetwork
や GraphIndependent
は、ノードやエッジやGlobalの属性をどういう感じで計算に用いるかで色々バリエーションがあるようで、
GraphNetwork
は フルに畳み込む感じで、 GraphIndependent
は ノードはノード属性のみ、エッジはエッジ属性のみ、というように独立して使うようです。
とにかくあるノードを5と出力するGraphNetを作る
今回は、 Hello Worldの気分で、図のようなGraphを入力として、出力は n2の属性=5 となるように学習させることを考えてみます。
コード
ソースコードをまるっと書くと以下のようになります。
import sonnet as snt
import graph_nets as gn
import tensorflow as tf
def get_graphs():
# graphの構造と属性。
data_dict = dict(
nodes=[[2.], [3.], [0.]], # node attrs
# edges=[[0.], [0.]], # edge attrs
# globals=[0.],
receivers=[2, 2],
senders=[0, 1],
n_node=3,
n_edge=2,
)
return gn.utils_tf.data_dicts_to_graphs_tuple([data_dict]) # returns GraphTuple
# Create the graph network. 今回は GraphIndependent を使う。
graph_net_module = gn.modules.GraphIndependent(
# edge_model_fn=lambda: snt.nets.MLP([8, 1]),
node_model_fn=lambda: snt.nets.MLP([8, 1]), # ここのMLPの出力次元が、Outputの次元になる
# global_model_fn=lambda: snt.nets.MLP([8, 1]),
)
# 入力用のGraphを作る
input_graphs = get_graphs()
input_graph_tr = gn.utils_tf.make_runnable_in_session(input_graphs) # egdes: None -> edges: tf.no_op() みたいな変換をしている
# GraphNet に InputのGraphを通して、 OutputのGraphを得る
output_graphs = graph_net_module(input_graph_tr)
# Loss関数。 node[2] が 5 になるようにLossを作る。要するに他のノードの値は一切見ていない!
loss_op_tr = tf.reduce_sum((output_graphs.nodes[2] - 5) ** 2) # reduce_sum 要らないかも
# Optimizer. ここまでくると、いつもの Tensorflow。
learning_rate = 1e-2
optimizer = tf.train.AdamOptimizer(learning_rate)
step_op = optimizer.minimize(loss_op_tr)
# Init Session
try:
sess.close()
except NameError:
pass
sess = tf.Session()
sess.run(tf.global_variables_initializer())
# Run Training
loss_history = []
for _ in range(1000):
train_values = sess.run({
"loss": loss_op_tr,
"step": step_op,
"outputs": output_graphs
})
loss_history.append(train_values['loss'])
from pprint import pprint
pprint(train_values)
print(train_values['outputs'].nodes[2])
import matplotlib.pyplot as plt
plt.plot(loss_history)
実行した出力
上記を実行すると、こんな感じの出力になります。
{'loss': 0.018520564,
'outputs': GraphsTuple(nodes=array([[6.091499 ],
[6.7052937],
[4.8639097]], dtype=float32), edges=None, receivers=array([2, 2], dtype=int32), senders=array([0, 1], dtype=int32), globals=None, n_node=array([3], dtype=int32), n_edge=array([2], dtype=int32)),
'step': None}
[4.8639097]
4.86とまだ5になってないですね。ちなみにあと1000回ループするとほとんど5になります。
さいごに
とても高度な技術で何の役にも立たない学習をさせましたが、何がどうなっているのかさっぱりだったので、これを作るのに3時間くらいかかりました...
しかしようやく仕組みがわかってきたので、もう少し違うことを試してみようと思います。