LoginSignup
1
2

書籍「グラフ深層学習」を参考にGNNのグラフ埋め込みをやってみた

Posted at

本記事の概要

  • GNNのグラフ埋め込みをpythonでの実装も含めてやってみたよ
    • GNN のライブラリは使わずにやったよ
  • 書籍「グラフ深層学習」の4章を参考にしているよ
  • 簡単な理論とコードを載せているよ
  • 僕と同じくGNNビギナーの方の参考になればうれしいよ

モチベーション

グラフニューラルネットワーク(GNN)について耳にする機会が増えたこと、また今年に入って書籍が続けて発行されたことから、GNN勉強してみたい!という人が増えているのではないでしょうか。

わたしもその一人で、独学ですが「グラフ深層学習(2023, ヤオ マー &ジリアン タン)」を読み進めています。

「グラフ深層学習」はグラフ理論の紹介から始まり、4章でグラフの特徴を抽出する グラフ埋め込み を取り上げています。

LLMなどでも使われる「埋め込み(embedding)」ですが、グラフにおいても特徴抽出の手法(教師なし学習)として用いられるようです。
「グラフ深層学習」でも自然言語の埋め込みと同じような考えをもとに書かれており、初心者のわたしでもわかりやすく説明されていたので、自らの理解の定着もかねて記事を書くことにしました。

訳者の方による書籍の内容の記事がQiitaにありました。
こちら
こちらを読んで理解できるかたは、本記事で得られることはあまりない気がします笑

読んでいただきたい方

  • GNN に興味があるけど勉強したことのない方
  • GNN を勉強し始めた方
  • GNN に詳しく「こいつちゃんとわかってるかどうかを確かめてやろう」という猛者

お断り

「グラフ深層学習」から(直接・間接)引用する場合には、
「グラフ深層学習」p.20
のようにして出典とさせていただきます。簡易的であることをお許しください。
(数式部分に関しては、都度出典を明記してはいません。この点もお許しください。)
また、本記事でのグラフは「無向グラフ」であることを前提にします。

グラフ理論のさわり

まずはグラフ理論の基本を簡単に確認します。

グラフの定義

グラフ理論のグラフは、ノード(頂点)とエッジ(辺)からなります。
都市同士を道で結んだグラフであれば、都市がノード、道がエッジです。
SNSの友人関係をグラフで表す場合であれば、人がノード、友人であるという関係がエッジになります。

これを数学的に抽象化して、次のように定義されます。

定義

グラフは $G=\lbrace V, E \rbrace$ と表わされる。
$V=\lbrace v_{1}, v_{2}, …,v_{N} \rbrace$ は $N=|V|$個のノードからなる集合で、$E=\lbrace e_{1}, e_{2}, …,e_{M} \rbrace$ は $M=|E|$個のエッジからなる集合である。
「グラフ深層学習」p.20

グラフの例

たとえば、次のような グラフ[1] を考えます。
グラフ01.png

このグラフは丸がノード、線がエッジを表していて、$V$ と $E$ はそれぞれ次のようにかくことができます。

\displaylines{
V=\lbrace{1, 2, 3, 4, 5, 6, 7}\rbrace \\
E=\lbrace{(1, 2), (2, 3), (2, 4), (4, 5), (4, 6), (4, 7), (6, 7)}\rbrace
}

ここで、エッジに向きはないので $(1, 2)$ と $(2, 1)$ は区別されません。

このように、グラフをノードとエッジの集合を用いて表す(定義する)のがグラフ理論でのよくある表現のようです。

隣接行列

集合で定義されたグラフですが、行列と対応させることもできます。

隣接行列の定義

あるグラフ $G$ に対応する隣接行列を $A\in\lbrace{0, 1\rbrace}^{N \times N}$ と表す。
$A$ の $i$ 行 $j$ 列の要素 $A_{i, j}$ は、ノード $i$ がノード $j$ と接続しているときに $1$、接続していないときに $0$ をとります。
「グラフ深層学習」p.21

グラフ[1] の隣接行列は次のようになります(一度確認してみてください)。

\begin{pmatrix}
0 & 1 & 0 & 0 & 0 & 0 & 0\\
1 & 0 & 1 & 1 & 0 & 0 & 0\\
0 & 1 & 0 & 0 & 0 & 0 & 0\\
0 & 1 & 0 & 0 & 1 & 1 & 1\\
0 & 0 & 0 & 1 & 0 & 0 & 0\\
0 & 0 & 0 & 1 & 0 & 0 & 1\\
0 & 0 & 0 & 1 & 0 & 1 & 0
\end{pmatrix}

なお、ノード 1 がノード 2 と接続しているとき、当然ノード 2 もノード 1 と接続しています。よって、隣接行列 $A$ は対称行列になります。

グラフの性質

グラフ[1] を見ると、いくつかの特徴がみられます。

  • ノード 1 や 3 は 2 とだけつながっている。ノード 5 も同様。
  • ノード 2 や 4 は複数のノードとつながっている。

特に、ノード 4 は「中心的な役割を果たしているノード」のように見えます。

ここでは詳細は触れませんが、ノードの特徴を表す性質として以下のようなものがあります。

  • 次数:そのノードが隣接する他のノードの個数
  • 近傍:そのノードが隣接する全ノードの集合
    「グラフ深層学習」p.22, 23

また、ノードとノードを結ぶ道のようなものを考察対象にすることもあり、次のように定義されます。

  • ウォーク:ノードとエッジが交互で現れる列
  • トレイル:同じエッジを通らないウォーク
  • パス:同じノードを通らないウォーク
    「グラフ深層学習」p.24

この辺りは本記事の中では触れませんが、おそらく最適化とかの文脈で出てくるのかなと想像しています。

まとめ

グラフ理論のさわりを簡単に紹介しました。
「グラフ深層学習」では、他にもラプラシアン行列や固有値ベクトルを用いたスペクトルグラフ理論などが解説されています。GNN について学習する際には確実に押さえておくほうがよさそうです。

グラフ埋め込み

ここからは、グラフ埋め込みの目標、一般的な学習フレームワーク、具体的な学習フレームワーク例についてかきます。

グラフ埋め込みの目標

グラフ埋め込みは、
 グラフの各ノードを低次元ベクトル表現に写像すること
「グラフ深層学習」p.93
を目指しています。
GNN では、グラフを何かしらのベクトル表現によって表し、それを深層学習のフレームワークで使います。よって、タスクに応じた有用なベクトル表現を得ることが重要です(まだ深層学習の段階までは読み進められていない、、、)。
都市、SNNの例を挙げましたが、グラフを用いた現実のタスクではノードの数はかなり大きくなるのが通常です。それによって、たとえば隣接行列はサイズが大きい疎な行列になると考えられます。
そのような行列を低次元ベクトル表現に写像することで、その後の深層学習のフレームワークで扱いやすくすることができます。(このあたりは、自然言語処理の文脈と同じですね!)

ベクトル表現を得るうえで重要なのは元のグラフの情報をできるだけ保存することです。単に低次元に落とすのではなく、その後のタスクで使用したい情報をなるべく保存できる写像が良い写像といえます。

グラフ埋め込みのアルゴリズム

グラフ埋め込みのアルゴリズムを考える際に、次に 2 つのドメインを意識すると理解がしやすくなります。

  1. グラフドメイン:エッジの繋がり(グラフ構造)によってノードを表現
  2. 埋め込みドメイン:連続値を要素にもつベクトルによってノードを表現
    「グラフ深層学習」p.93

グラフ埋め込みは 1 → 2 への写像です。目的に合わせて 1 の情報をなるべく保存して写像することを目指して、写像を学習します。

一般的な学習フレームワークとしては、以下のようになるようです。
グラフ02.PNG
「グラフ深層学習」p.94 をもとに作成

  • グラフドメインから埋め込みドメインへ写像
    • この写像は マッピング関数 とよばれます。
  • 元のグラフからは保存したい情報 $I$ を抽出
  • 埋め込み表現からは保存したい情報を再構築した $I'$ を取得
  • $I$ と $I'$ をもとに目的関数を作成
  • 目的関数を最適化することで写像を学習

考え方は、よくある教師なし学習と同じですね。
$I$ は埋め込み後のタスクに合わせて保存したい情報を適切に選択するのが重要だと思われます。

具体例:DeepWalk アルゴリズム

ノード同士の関係性をとらえる視点の一つが「ノードの共起性」です。
グラフ上のウォークをランダムに生成したとき、近くにあるノードは共起する可能性が高いと考えられます。
このような共起性をもとにした埋め込み表現であれば、元のグラフの情報のうちノード同士の関係を保存した表現を得ることができていそうです。

ノードの共起性を保存する、代表的なグラフ埋め込みアルゴリズムの 1 つが DeepWalk です。
「グラフ深層学習」p.95
ここからは、DeepWalkを題材に、グラフ埋め込みの具体的なフレームワークをみていきましょう。

マッピング関数

マッピング関数は次のようになります。
$$f(v_{i})=\boldsymbol{u_{i}}=\boldsymbol{e_{i}}^{T}W$$
$e_{i}$ は $i$ 番目の要素のみが $1$ のワンホットベクトルです。
また $W\in\mathbb{R}^{N\times d}$ は、埋め込み次元を $d$ とする、学習対象の行列(パラメータ)です。

ランダムウォーク

共起性に関する情報を抽出するもっとも一般的な方法は、ランダムウォークです。
「グラフ深層学習」p.96

ランダムウォークは、次のようにしてウォークを生成します。

  1. ウォークの長さ $T$ を決める
  2. 始点ノード $v_{0}$ を決める
  3. $v_{0}$ からスタートしてグラフ上をエッジに沿ってランダムに移動し、$T$ 個のノードを訪れるまで繰り返す

このようにして得られたノードの列が「長さ $T$ のランダムウォーク」です。($T$ はハイパーパラメータです)

(具体例)

グラフ[1] で、始点をノード $1$ とした長さ $5$ のランダムウォークの例です。

 $1 → 2 → 4 → 5 → 4 $

言語モデルでいうところの 単語列 (文?)を生成したのと同じイメージだと理解しています。

ランダムウォークからの共起性の抽出

生成したランダムウォークから、言語モデルにおけるSkip-gramのアルゴリズムで共起性を抽出します。

  1. 中心ノード $v_{cen}$ を決める
  2. $v_{cen}$ の前後 $w$ 個のノードを 文脈ノード $v_{con}$ とする($w$ はハイパーパラメータ)
  3. ノードの組 $(v_{con}, v_{cen})$ を抽出し、共起リスト $I$ に追加していく

(具体例)

ランダムウォーク $1 → 2 → 4 → 5 → 4 $ から、前後 $w$ 個のノードの組を抽出したリスト

$$[(2, 1), (4, 1), (4, 2), (5, 2), …]$$

このようにしてつくられた $I$ には、ノードの共起性に関する情報が保存されていると考えられます。

共起性の再構築

ここからは埋め込みドメインの話です。
共起の組 $I$ は中心ノードと文脈ノードの組として抽出されています。よってマッピング関数としては、各ノードの、中心ノードとしての役割と文脈ノードとしての役割をもとにモデル化できるとよさそうです。

よって、2 つのマッピング関数を次のように定めます。

\displaylines{
f_{cen}(v_{i})=\boldsymbol{u_{i}}=\boldsymbol{e_{i}}^{T}W_{cen} \\
f_{con}(v_{i})=\boldsymbol{u_{i}}=\boldsymbol{e_{i}}^{T}W_{con}
}

$(v_{con}, v_{cen})$ の出現(頻度)に関するモデルは、ソフトマックス関数を用いた次のような条件付き確率で説明することができます。

\displaylines{
p(v_{con} | v_{cen})=\frac{\mathrm{exp}(f_{con}(v_{con})^{T}f_{cen}(v_{cen}))}{\sum_{v\in V}\mathrm{exp}(f_{con}(v)^{T}f_{cen}(v_{cen}))}
}

$p(v_{con} | v_{cen})$ が、グラフドメインからの情報 $I$ となるべく整合性が取れるように 2 つのマッピング関数を学習すればよさそうです。

つまり、以下のようなイメージで理解しています。

  • ランダムウォークから得られた共起リストをもとに、すべての組 $(v_{con} | v_{cen})$ の共起情報を抽出(頻度を求めて確率として扱う)
  • 埋め込みドメインの行列 $W_{cen}$ および $W_{con}$ をもとに、モデル$p(v_{con} | v_{cen})$ を構成
  • 共起情報と $p(v_{con} | v_{cen})$ から目的関数をもとに $W_{cen}$ および $W_{con}$ を学習

Deep Walk アルゴリズムのまとめ

Deep Walk アルゴリズムについて簡単に解説しました。
グラフ埋め込みのフレームワークを意識しながら追うことで、腑に落ちやすいのではないかと思います。

次のセクションでは、Deep Walk アルゴリズム の実装を行います。

Deep Walk アルゴリズムの実装

ここからは Deep Walk アルゴリズムの実装を通して、理解を深めていけたらと思います。

以下の方針で実装します。

  • pythonを使用
  • GNN 用のライブラリは使用しない
    • ただし、残念ながら完全スクラッチで実装するようなスキルはないため笑、学習プロセスにおいて pytorch を使用しています。

扱うグラフ

簡単な例として、グラフ [1] を埋めこみます。
グラフ01.png

ノードの関係性の特徴から、埋め込み後のベクトルが例えば次のようになっていることが期待されます。

  • (1,2,3) と (4,5,6,7) がそれぞれ近くにいる
  • 1 と 3 が近い
  • 6 と 7 が近い
    • 5 も近いが、6 と 7 とはやや離れている

実装コード

まずは、学習完了までのコード全体を示します。

import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim

# シードの設定
seed = 42
np.random.seed(seed)
torch.manual_seed(seed)

# グラフの設定
n = 7  # ノード数
conb_list = [[1, 2], [2, 3], [2, 4], [4, 5], [4, 6], [4, 7], [6, 7]]  # エッジの指定

# ハイパーパラメータの設定
walk_iter = 30  # ランダムウォークの反復回数
walk_length = 10  # ランダムウォークの長さ
window_size = 3  # 共起性抽出のウィンドウサイズ
embedding_dim = 2  # 埋め込み次元
learning_rate = 0.03  # 学習率
epochs = 200  # エポック数

# 関数定義

def create_adjacency_matrix(n, conb_list):
    """
    隣接行列を作成する関数
    """
    adjacency_matrix = np.zeros((n, n))
    for conb in conb_list:
        adjacency_matrix[conb[0]-1, conb[1]-1] = 1
        adjacency_matrix[conb[1]-1, conb[0]-1] = 1
    return adjacency_matrix

def random_walk(adjacency_matrix, start_node:int, walk_length:int) -> list:
    """
    指定したノードからランダムウォークを実行する関数
    """
    walk = [start_node]
    for i in range(walk_length):
        next_node = np.random.choice(np.where(adjacency_matrix[walk[-1]-1] == 1)[0]) + 1
        walk.append(next_node)
    return walk

def random_walk_all_node(adjacency_matrix, iteration:int, walk_length:int) -> list:
    """
    全ノードからランダムウォークを実行する関数
    """
    R = []
    for node in range(1, n+1):
        walks = []
        for i in range(1, iteration+1):
            walk = random_walk(adjacency_matrix, node, walk_length)
            walks.append(walk)
        R = R + walks
    return R

def extract_cooccurrence(R, window_size:int) -> list:
    """
    ランダムウォークの結果から共起性を抽出する関数
    """
    cooccurrence = []
    for walk in R:
        for i in range(len(walk)):
            for j in range(i+1, min(i+window_size+1, len(walk))):
                cooccurrence.append([walk[i], walk[j]])
            for j in range(i-1, max(i-window_size-1, -1), -1):
                cooccurrence.append([walk[i], walk[j]])
    return cooccurrence

def create_cooccurrence_matrix(cooccurrence, n:int) -> np.ndarray:
    """
    共起性データから共起行列を作成する関数
    """
    cooccurrence_matrix = np.zeros((n, n))
    for c in cooccurrence:
        cooccurrence_matrix[c[0]-1, c[1]-1] += 1
        cooccurrence_matrix[c[1]-1, c[0]-1] += 1
    return cooccurrence_matrix


# 隣接行列の作成
adjacency_matrix = create_adjacency_matrix(n, conb_list)

# 全ノードからランダムウォークを実行
R = random_walk_all_node(adjacency_matrix, walk_iter, walk_length)

# 共起性の抽出
cooccurrence = extract_cooccurrence(R, window_size)

# 共起行列の作成
co_occurrence_matrix = create_cooccurrence_matrix(cooccurrence, n)

# 共起行列の行ごとに確率分布に変換
co_occurrence_probs = co_occurrence_matrix / co_occurrence_matrix.sum(axis=1, keepdims=True)
co_occurrence_probs_tensor = torch.tensor(co_occurrence_probs, dtype=torch.float32)

# 埋め込み行列の定義
W_cen = nn.Parameter(torch.randn(co_occurrence_matrix.shape[0], embedding_dim))
W_con = nn.Parameter(torch.randn(co_occurrence_matrix.shape[0], embedding_dim))

# 最適化手法の定義
optimizer = optim.Adam([W_cen, W_con], lr=learning_rate)

# クロスエントロピー損失関数の定義
loss_function = nn.CrossEntropyLoss()



# 学習ループ
for epoch in range(epochs):
    optimizer.zero_grad()
    
    # 共起行列の再構築
    logits = torch.matmul(W_cen, W_con.t())
    
    # ソフトマックスを適用して確率に変換
    log_probs = nn.functional.log_softmax(logits, dim=1)
    
    # 損失の計算
    loss = loss_function(log_probs, co_occurrence_probs_tensor)
    
    # 勾配の計算とパラメータの更新
    loss.backward()
    optimizer.step()
    
    # 進捗表示
    if (epoch + 1) % 10 == 0:
        print(f'Epoch {epoch + 1}/{epochs}, Loss: {loss.item()}')

プロセス全体は、これまで解説してきた内容に沿っています。何点か個別に取り上げます。

ランダムウォークの設定

  • walk_iter:1 つのノードにつき、ランダムウォークの取得個数(今回は30)
  • walk_length:ランダムウォークの長さ(今回は10)
  • window_size:共起性を抽出するウィンドウ幅(今回は前後3つ)

共起情報 I と再構築情報 I'

  • co_occurrence_probs
    • ランダムウォークから作成した共起性の情報 $I$
  • log_probs
    • 埋め込み W_cenW_con から再構築した情報 $I'$
    • logits = torch.matmul(W_cen, W_con.t()) は、$f_{con}(v_{con})^{T}f_{cen}(v_{cen})$ をまとめて行列計算している

結果の可視化

作成された埋め込みベクトルを可視化します。
2 つのマッピング関数のベクトルが求められていますので、今回は平均をとったベクトルを可視化します。
次のようになりました。

グラフ03.png

(1,2,3) と (4,5,6,7) の関係、1 と 3 の近さ、6 と 7 の近さなど、元のグラフの特徴がわりと抽出できていそうです。

なお、今回は学習エポック数を200にしていましたが、loss関数の減少を見ると50エポックくらいでほぼ収束していました。

別のグラフ

もう少し複雑なグラフでも実験してみます。次のようなグラフを使用します。(GPT に作ってもらいました。)

グラフ04.png

ノードの色は仮想のグループを想定しています。

コードの変更点

グラフの設定部分を、次のように変更すれば実験できます。

n = 21  # ノード数
conb_list = [
    (1, 2), (1, 3), (2, 4), (2, 5), (3, 6), (3, 7),  # グループ1
    (8, 9), (8, 10), (9, 11), (9, 12), (10, 13), (10, 14),  # グループ2
    (15, 16), (15, 17), (16, 18), (16, 19), (17, 20), (17, 21),  # グループ3
    (4, 8), (5, 9), (6, 10), (7, 11),  # グループ間のエッジ
    (12, 16), (13, 17), (14, 18), (15, 19), (16, 20)  # グループ間のエッジ
]  # エッジの指定

可視化

埋め込みベクトルを可視化すると次のようになりました。

グラフ05.png

位置関係はおおむね保てているような気がします。
今回は、100エポックくらいで収束が確認できました。

発展事項

今回の解説、実験は非常に単純な例でした。発展事項としては以下のようなものが考えられるようです。

  • 学習の高速化
  • 複雑グラフへのあてはめ
  • 構造的な情報の抽出

このあたりはもう少し「グラフ深層学習」を読み進めて理解を深めていきたいと思っています。

まとめ

以上、GNNビギナーによるグラフ埋め込みの解説でした。
グラフ埋め込みは、言語モデルを意識しながら読み進めると理解がしやすいと思います。

感想・ご指摘などいただけると大変うれしいです。
そして、ぜひ「グラフ深層学習」をまだ購入していない方は購入し、みんなで読み進めていきましょう!

1
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
2