LoginSignup
10
3

More than 5 years have passed since last update.

DeepMindのLibraryの graph_nets を使ってとにかく「5」と出力するGraphNetを作る

Posted at

はじめに

GraphNets というライブラリが DeepMind からリリースされています。
このライブラリは この論文 の実装として公開されたもので、
端的に言うと 「ノード、エッジ、Globalに任意の属性(≒任意のtensor)を持ったグラフを入力にし、 構造は同じだが属性が更新されたグラフ を出力する」 というネットワークブロックを提供するものです。
何が面白そうかというと、グラフという知識表現として非常に表現力の高いデータ構造を扱うことができるということです。
以前から、グラフ畳み込みニューラルネットワークというものが(やや)注目されていましたが、それの汎用的なデータフォーマットを定めて、扱いやすくしてくれたような位置づけだと思います。

とはいうものの、現状ほとんどドキュメントも無く、デモのソースコードを見ても中々使い方がわかりません。
試行錯誤の末、なんとか GraphNetを学習させる道筋が見えてきたので、書いておこうと思います。

今回のお題は、「とにかくあるノードを5と出力するGraphNetを作る」です。

環境

  • Python 3.6
  • graph-nets: 1.0.0
    • 何かよくわからないけど、 pip install -U dm-sonnet==1.23 を入れておかないと動かなかった

graph_nets の使い方概要

いきなり見ると呪文のようですが、ライブラリ付属のデモコードや今回の私のコードを読む時に役に立ちそうなことを書いておきます。

GraphsTuple

大事な型として、 GraphsTuple というのがあって、 これはGraph構造を表現している Tensorの集まりです。

GraphsTuple は、 gn.utils_tf.data_dicts_to_graphs_tuple() のような関数で生成できます。
この data_dicts は 論文Box 3: Our definition of “graph” を意味しているようです。
ざっくり言うと、以下のようなものです。

  • n_node, n_edge: ノードとエッジの数
  • nodes: ノードの属性(attritbues)のリスト(や np.array)
  • edges: エッジの属性のリスト(やnp.array)
  • globals: Globalの属性のリスト(やnp.array)
  • receivers, senders: エッジがどこからどこに向かっているかを ノードのindexで表す

GraphNetたち (modules.*)

gn.modules.*GraphNetworkGraphIndependent がGraphNetを表していて、 GraphsTuple を受け取って、GraphsTuple を出力します。
ここに学習用の重みとかが詰まっている(正確には重みは引数で渡す)ので、 入力 → <GraphNet> → Output -> Loss という流れを作って、 Lossを Minimize することで学習が進行します。
出力される GraphsTuple もTensorなので、session.run() などで計算して値を取り出すことができます。

GraphNetworkGraphIndependent は、ノードやエッジやGlobalの属性をどういう感じで計算に用いるかで色々バリエーションがあるようで、
GraphNetwork は フルに畳み込む感じで、 GraphIndependent は ノードはノード属性のみ、エッジはエッジ属性のみ、というように独立して使うようです。

とにかくあるノードを5と出力するGraphNetを作る

今回は、 Hello Worldの気分で、図のようなGraphを入力として、出力は n2の属性=5 となるように学習させることを考えてみます。

image.png

コード

ソースコードをまるっと書くと以下のようになります。

import sonnet as snt
import graph_nets as gn
import tensorflow as tf

def get_graphs():
    # graphの構造と属性。 
    data_dict = dict(
        nodes=[[2.], [3.], [0.]], # node attrs
#        edges=[[0.], [0.]],    # edge attrs
#        globals=[0.],
        receivers=[2, 2],
        senders=[0, 1],
        n_node=3,
        n_edge=2,
    )
    return gn.utils_tf.data_dicts_to_graphs_tuple([data_dict])  # returns GraphTuple

# Create the graph network. 今回は GraphIndependent を使う。
graph_net_module = gn.modules.GraphIndependent(
#    edge_model_fn=lambda: snt.nets.MLP([8, 1]),
    node_model_fn=lambda: snt.nets.MLP([8, 1]),  # ここのMLPの出力次元が、Outputの次元になる
#    global_model_fn=lambda: snt.nets.MLP([8, 1]),
)

# 入力用のGraphを作る
input_graphs = get_graphs()
input_graph_tr = gn.utils_tf.make_runnable_in_session(input_graphs)  # egdes: None -> edges: tf.no_op() みたいな変換をしている

# GraphNet に InputのGraphを通して、 OutputのGraphを得る
output_graphs = graph_net_module(input_graph_tr)  

# Loss関数。 node[2] が 5 になるようにLossを作る。要するに他のノードの値は一切見ていない!
loss_op_tr = tf.reduce_sum((output_graphs.nodes[2] - 5) ** 2)  # reduce_sum 要らないかも

# Optimizer. ここまでくると、いつもの Tensorflow。
learning_rate = 1e-2
optimizer = tf.train.AdamOptimizer(learning_rate)
step_op = optimizer.minimize(loss_op_tr)

# Init Session
try:
  sess.close()
except NameError:
  pass
sess = tf.Session()
sess.run(tf.global_variables_initializer())

# Run Training

loss_history = []
for _ in range(1000):
    train_values = sess.run({
      "loss": loss_op_tr,
      "step": step_op,
      "outputs": output_graphs 
    })
    loss_history.append(train_values['loss'])

from pprint import pprint
pprint(train_values)
print(train_values['outputs'].nodes[2])

import matplotlib.pyplot as plt
plt.plot(loss_history)

実行した出力

上記を実行すると、こんな感じの出力になります。

{'loss': 0.018520564,
 'outputs': GraphsTuple(nodes=array([[6.091499 ],
       [6.7052937],
       [4.8639097]], dtype=float32), edges=None, receivers=array([2, 2], dtype=int32), senders=array([0, 1], dtype=int32), globals=None, n_node=array([3], dtype=int32), n_edge=array([2], dtype=int32)),
 'step': None}
[4.8639097]

image.png

4.86とまだ5になってないですね。ちなみにあと1000回ループするとほとんど5になります。

さいごに

とても高度な技術で何の役にも立たない学習をさせましたが、何がどうなっているのかさっぱりだったので、これを作るのに3時間くらいかかりました...
しかしようやく仕組みがわかってきたので、もう少し違うことを試してみようと思います。

10
3
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
10
3