3
3

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

GCNを用いずにRNNでグラフを生成する手法:GraphRNN

Last updated at Posted at 2020-10-31

目的

強化学習を用いた構造最適化を行いたいと考えている.
その際,グラフを用いて構造を表現することが出来そうな為,グラフを生成するGraphRNNを読んでみることにした.これはその論文を読んだ備忘録である.
数学的観点に関して飛ばして,ざっとアルゴリズムを知りたい方は,提案手法のAbstractとGraphRNNのフレームワーク,グラフのシーケンス学習,モデルの構成,学習時の条件のみを見れば良いと思う.

論文紹介・画像引用・GIF引用

GraphRNN: Generating Realistic Graphs with Deep Auto-regressive Models

提案手法

Abstract

  • 無向グラフを対象にしている.

  • このアプローチの重要なアイデアは,異なるノード順序の下でのグラフをシーケンスとして扱っており,グラフ畳み込みニューラルネットワークを使わずにRNN的手法を用いてグラフ生成を可能にしている.

  • グラフ生成手順を以下の図のように,グラフの情報を保持するGraphレベルのRNNと,新規に追加するノードに関して,一つ一つの既存ノードが隣接するかどうかを一つ一つ出力していくEdgeレベルのRNNから構成されている.

  • これらはGraphVAEと比べ,複雑なエッジの情報を持つ多様なサイズのモデルを生成することが出来る.

  • 学習時,学習の複雑さを大幅に削減する為,シーケンス生成方法はBFSノード順序付けのスキームを導入した.

数学的定義

  • グラフは$G=(V,E)$で定義する.ここでの$V,E$はノード,エッジを示している.
    $\pi$は置換関数を表す.つまり,(1,2,3)を(2,1,3)とか(3,1,2)にするものである.

  • $\Pi$は$n!$通りのノードの置換の全てのパターンの置換関数$\pi$を示す.

グラフをシーケンスとしてモデリング(スキップ可能)

グラフから列にマッピングする関数$f_S$を定義する.ここでのグラフ$G \sim p(G)$は$\pi$によって並べられた$n$ノードから出来上がっており,列(シーケンス)は次の通り表される.
$$S^{\pi}=f_S(G,\pi)=(S_1^{\pi},...,S_n^{\pi})$$
ここで,$S^{\pi}$のそれぞれの要素は,$S_i^{\pi}\in \lbrace0,1 \rbrace ^{i-1},i\in \lbrace1,...,n \rbrace$という風に表される.これは,隣接ベクトルとして扱い,新規ノード$\pi(v_i)$と既に存在するノード$\pi(v_j), j \in \lbrace1,...,i-1 \rbrace$とで隣接しているかどうかを0,1で示している.
よって,$S^{\pi}$の要素$S_i^{\pi}$はまとめると次の式で表される.
$$S_i^{\pi}=(A_{1,i}^{\pi},...,A_{i-1,i}^{\pi})^T,\forall i \in \lbrace2,...,n \rbrace$$

また,$S^{\pi}$から固有のグラフ$G$を求めるための関数を$f_G$とし,$f_G(S^{\pi})=G$となる.

この定義式の基,グラフの分布$p(G)$は次のように表現することができる.
$$p(G)=\sum_{S^{\pi}}p(S^{\pi}) \mathbb{1}[f_G(S^{\pi})=G]$$
$$p(S^{\pi})=\Pi_{i=1}^{n+1} p(S_i^{\pi}|S_1^{\pi},...,S_{i-1}^{\pi})$$
なお,ここでは$n+1$回目をシーケンスの終了EOSとした.また,今後は,$p(S_i^{\pi}|S_1^{\pi},...,S_{i-1}^{\pi})$を簡単のため,$p(S_i^{\pi}|S_{<i}^{\pi})$と表現する.

GraphRNNのフレームワーク

先の定義を用いることで,$p(G)$を$p(S^{\pi})$の形で表現することができた.
ここで,本手法では$p(S_i^{\pi}|S_{<i}^{\pi})$においてニューラルネットワークを用い,複雑な分布を表現できるようにする.
ここで,RNNを用いた.このRNNは,次の状態遷移関数と出力関数の二つから構成される.
$$h_i=f_{trans}(h_{i-1},S_{i-1}^{\pi})$$
$$\theta_i=f_{out}(h_i)$$

$h_i \in \mathbb{R}^d$は,グラフの状態をエンコードした特徴ベクトルである.$S_{i-1}^{\pi}$は最近作られたノード$i-1$がどのノードと隣接しているかを表すベクトル.$\theta_i$は次のノードの隣接ベクトルの分布を決定するためのベクトル($S_i^{\pi}\sim P_{\theta_i}$).$P_{\theta_i}$は2値ベクトル上の分布を示している.
一般的に,ここでの$f_{trans},f_{out}$は任意のNNを利用してよい.また,これらを用いたグラフ生成アルゴリズムの概略図は次に示される.

上で述べたのは全般的に使えるフレームワークである.次に,本手法で利用したアルゴリズムの詳細に関しては次の節で触れる.

GraphRNNの種類(スキップ可能)

GRNNの種類として,本研究では二つの手法を提案している.
遷移関数$f_{trans}$(グラフレベルRNN)は同じGated Recurrent Unit(GRU
))(LSTMをシンプルにしたモデル)を利用している.
対して$f_{out}$(エッジレベルモデル)においては異なる二つのものを提案する.異なるポイントは,$f_{out}$においてRNNを使用しているかしていないかの違い.

これらは**それぞれSGDと$S^{\pi}$に関する最尤度損失を用いて学習させる.**詳しくいうと,観測されたすべてのグラフシーケンスにおける$\Pi p_{model}(S^{\pi})$を最適化することでNNのパラメータを学習させる.

Multivariate Bernoulli

こっちはシンプルなGRNNのアプローチであり,GraphRNN-Sと呼ぶ.
このアプローチでは$p(S_i^{\pi}|S_{<i}^{\pi})$をシンプルに多変量ベルヌーイ分布(ベルヌーイ分布とは,結果が0か1かのものの確率分布を示している.多変量は,0,1となる変数が複数あるものを指している.)としてモデルする.
そしてこの多変量ベルヌーイ分布$p(S_i^{\pi}|S_{<i}^{\pi})$は,$f_{out}$によって出力された$\theta_i \in \mathbb{R}^{i-1}$によって決定される.
ここで,我々は$f_{out}$を一層のMLPとSigmoid関数より作成した.
具体的にいうと,ベクトル$\theta_i$の要素$\theta_i[j]$はエッジ$(i,j)$の存在確率として解釈することができる.そしてこのエッジの存在確率$\theta_i$を基にエッジをサンプルし,新しい$S_i^{\pi}$を作成する.

Dependent Bernoulli sequence

このアルゴリズムでは網羅的にエッジの情報を処理するため,$p(S_i^{\pi}|S_{<i}^{\pi})$を条件付き確率分布として次のように扱う.
$$p(S_i^{\pi}|S_{<i}^{\pi})= \Pi_{j=1}^{i-1} p(S_{i,j}^{\pi}|S_{i,<j}^{\pi},S_{<i}^{\pi})$$

ここでの$S_{i,j}^{\pi}$はバイナリスカラーとして,1のときはノード$\pi (v_{i+1})$がノード$\pi (v_{j})$と隣接していることを示している.

そしてこの確率の下,このモデルでは,先と違い,$f_{out}$においてもRNNを利用することによって,全体的なアルゴリズムとして以下の図のような階層的RNNモデルとなる.
また,$f_{out}$(エッジレベルのRNN)は$f_{trans}$(グラフレベルRNN)と同じくGRU(Gated Reccurent Unit)を利用しており,MLPを用いてエッジの存在確率をスカラーとして順々に出力している.

グラフのシーケンス学習

この研究での重大な発見は,ランダムなノード順列の下で学習するのではなく,BFS(幅優先探索)によるノード順列を利用しても,一般性を失うことなくグラフ生成の学習を行えたことである.
この下,隣接情報を示す$S^{\pi}$を生成する関数$f_S(G,\pi)$を次のように置き換えることができる.
$$S^{\pi}=f_S(G,BFS(G,\pi))$$

ここでの$BFS()$はBFS関数を表しており,グラフ$G$とランダム順序を示す$\pi$を入力とし,そして$\pi(v_1)$を最初としてBFSに基づくシーケンスを出力とする.

グラフの生成時,BFSを用いたノードの順序を利用することには二つのメリットがある.

一つ目は学習データが少なくてよくなることである.
BFSにおける$\pi$は,複数の$\pi$における$BFS(G,\pi)$が同じシーケンスを出力する.よって,$BFS(G,\pi)$を学習するときは$\pi$を基に学習するときよりデータ数が少なく済む.

二つ目は,エッジレベルRNNにおいて推定すべきエッジの数を減らすことができることである.具体的に言うと,ある程度ノードが追加されたとき,もし新しいノードをBFSの順序の下で追加するときは,新しいノードはBFSシーケンスの中でも最後の方のいくつかのノードとしかつながらない.よってグラフ作成時,新しいノードからエッジを伸ばすかどうかを考慮するノード群が少なくて済み,生成コストが少なく,学習が安定する.この研究では,この考慮するノードの個数はハイパーパラメータとして固定している.

モデルの構成

GraphRNNのモデル構成は,グラフレベルのRNNは128次元の特徴ベクトルを保持する4層のGRUで,エッジレベルのRNNも16次元の特徴ベクトルを保持する4層のGRUで構成されている.エッジレベルRNNにおいては,隣接ベクトルを出力する為,GRUで求まった16次元の特徴ベクトルはMLPとReLUを用いて8次元ベクトルに変換した後,さらに別のMLP,Sigmoidを用いて確率を表すスカラー値に変換している.

エッジレベルRNNはグラフレベルRNNの出力を入力として毎回隣接ベクトルを生成するとき(下図における$h_i$から下に伸びる青のベクトル)に初期化される.この時,128次元を16次元に落とし込むため,全結合層を利用している.
また,論文では見当たらなかったが,エッジレベルRNNの情報を同じくグラフレベルRNNの入力に加えるため,全結合層を利用して16次元を128次元にしていると考えられる.

学習時の条件

グラフを学習させるとき,既存のグラフに対し,ランダムな$\pi$を基にBFS順のシーケンスを作成する.
そして,これらを学習データとする.なお,学習時のGrapRNNは,逐一エッジの隣接は推定させるものの,その推測に基づくエッジは実際に連結させず,学習データに基づいてエッジを連結させる.
テスト時は,推測させたエッジの隣接情報を基にグラフを生成させる.
それぞれSGDと生成された$S^{\pi}$に関する最尤度損失を用いて学習させる.

## 実験 以下のグラフ生成方法と比較し,高品質なグラフの生成力に関する実験検証を行った.

Traditional baselines

  • E-R model
  • B-A model
  • Kronecker graph model: 学習用パラメータを含む
  • mixed-membership stochastic block model: 学習用パラメータを含む

Deeplearning baselines

  • GraphVAE
  • DeepGMG

結果

本手法では,Maximum Mean Discrepancy(MMD)測定に即した評価手法を提案する(詳しいことは省く).
このMMDスコアは二つのデータ分布における近さを示しており,スコアが低ければ低いほど近いということを示す.

結果,他の手法と比べて良い結果を出した.

また,GraphRNNが生成したグラフとtrainデータ,そして他の手法が生成したグラフを以下の図に示す.結果,他の手法と比べてもGraphRNNがデータセットの構造を掴むことができていることを示している.

まとめ

この研究は,GCNを用いず,グラフを列としてとらえることにより,RNNを利用するだけでグラフ生成を行っている.
そしてそれを基に,GraphVAEよりも表現性能が上回った点が興味深かった.

3
3
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
3
3

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?