はじめに
ナレッジグラフにおけるリンク予測モデルの1つであるGraILについてまとめる.
本記事は,Inductive Relation Prediction by Subgraph Reasoning に掲載されている内容に基づく.
背景
従来の埋め込みベースの手法は,知識グラフの基礎となる構成的な論理ルールを明示的に捕捉しておらず,また,学習時にエンティティの完全な集合を必要とする.
GraILは、局所的なサブグラフ構造を推論し,実体に依存しない表現を学習する.埋め込みベースのモデルとは異なり,GraILはinductiveであり,学習後に未知のエンティティやグラフに汎化することができる.
モデル
モデルの構築には,サブグラフの抽出,ノードラベリング,GNNによるスコア計算の3つのステップをとる.
サブグラフの抽出
2つのノードu, vからk-hop離れたノードの集合をN(u),N(v)とする.
このとき,以下の条件を満たすノードのサブグラフを考える.
- $N(u) \cap N(v)$
- ノードu, vの両方からk-hop以内
ノードラベリング
サブグラフ上のノードiに対して,$(d(i, u), d(i, v))$を考える.
ここで,$d(i, u)$はノードi, u間の最短経路である.
ノードu, v は,それぞれ $(0, 1)$, $(1, 0)$とそれぞれ表現する.
それぞれのベクトルに対して,$[OneHot(d(i, u)) \oplus OneHot(d(i, v))]$のようにone-hot表現にし,結合することによって各ノードをサブグラフ上のトポロジカルな位置として表現する.
GNN によるスコア計算
Message-Passing Scheme [Xu 2019]
隣接するノードの表現を繰り返し集約し更新していくことで,ノードを表現する.
$N(t)$をノードtの隣接ノードの集合としたとき,ノードtの隣接ノードから集約された表現 ($a_t^k$)と,ノードtにおけるk層目の表現 ($h_t^k$) は以下のように与える.
このとき,$h_i^0$はノードラベリングによって得た特徴量で初期化する.
$$
a_t^k = AGGREGATE^k(\set{h_s^{k-1}: s \in N (t) }, h_t^{k-1} )
$$
$$
h_t^k = COMBINE^k(h_t^{k-1}, a_t^k)
$$
Attention [Schlichtkrull 2017]
アテンションを追加し,AGGREGATE関数を次のように定義する.
$$
a_t^k = \sum_{r=1}^{R}\sum_{s \in N_{r}(t)} \alpha_{rr_tst}^k W_r^k h_s^{k-1}
$$
$R$: ナレッジグラフにおけるrelationの総数
$N_r(t)$: ノードtがrelation r で隣接するノードの集合
$W_r^t$: relation r における k層目のメッセージを伝達する変換行列
$\alpha_{rr_tst}^k$: ノードs, t間のrelation r のk層目のアテンションの重み
また,この重みは次のように与える.
$$
\alpha_{rr_tst}^k = \sigma(A_2^ks + b_2^k)
$$
$$
s = ReLU(A_1^k[h_s^{k-1} \oplus h_t^{k-1} \oplus e_r^a \oplus e_{r_t}^a ] + b_1^k)
$$
$e_r^a, e_{r_t}^a$: それぞれのrelationにおけるアテンションの埋め込み
Basis Sharing Mechanism [Schlichtkrull 2017]
COMBINE関数は,次にように与える.
また,Edge Dropoutを行うことで,集約しつつエッジをランダムに削除する.
$$
h_t^k = ReLU(W_{self}^k h_t^{k-1} + a_t^k)
$$
$W_{self}^k$: 各層の変換行列
サブグラフの表現
L層のMessage Passingの後,サブグラフ上のノードの集合を $\nu$ として,サブグラフの表現は次のように与える.
$$
h_{G_{(u, v, r_t)}}^L = \frac{1}{|\nu|} \sum_{i \in \nu} h_i^t
$$
リンク予測におけるスコア
サブグラフの表現($h_{G_{(u, v, r_t)}}^L$),対象ノードの表現 ($h_u^L$, $h_v^L$),対象のrelationの埋め込み ($e_{r_t}$) を結合することで,次のようにスコアを定義する.
$$
score(u, r_t, v) = W^T[ h_{G_{(u, v, r_t)}}^L \oplus h_u^L \oplus h_v^L \oplus e_{r_t}]
$$
また,JK-Connection Mechanism (Xu 2018) にように,断片的な層の表現を用いることで,パフォーマンスを向上させる.
学習
トリプルの先頭または末尾をランダムに別のエンティティと入れ替えることで,negative tripletをサンプルする [Bordes2013].
ロス関数を次のように与え,確率的勾配降下法によって学習を行う.
$$
L = \sum_{i=1}^{|\epsilon|} max(0, score(n_i) - score(p_i)+\gamma)
$$
$\epsilon$: 学習グラフにおけるエッジ/トリプレットの全ての集合
$p_i, n_i$: positive tripletと,negative triplet
$\gamma$: margin hyperparameter
終わりに
GraILと呼ばれるリンク予測モデルについてまとめた.
一方で,説明が不十分な点が残っており今後も追記していく.