0
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

書籍「グラフ深層学習」を参考にGNNのグラフ埋め込みをやってみた_part2

Last updated at Posted at 2024-06-02

本記事の概要

  • 前回の記事に引き続き、GNNのグラフ埋め込みをpythonでの実装も含めてやってみたよ
    • 今回は グラフの大域情報 を反映する手法を扱ったよ
    • GNN のライブラリは使わずにやったよ
  • 書籍「グラフ深層学習」の4章を参考にしているよ
  • 簡単な理論とコードを載せているよ
  • 僕と同じくGNNビギナーの方の参考になればうれしいよ

モチベーション

前回、グラフ埋め込みの基本的な手法である DeepWalk について、簡単な理論とpythonコードによる解説を行いました。

参考書籍は引き続き「グラフ深層学習(2023, ヤオ マー &ジリアン タン)」です。
https://www.amazon.co.jp/dp/4910612122

DeepWalkは、グラフの近傍情報(接続)をもとにノードの特徴を取得し、埋め込み空間へと移す手法です。

一方で、接続していなくてもグラフ全体で見ると類似した(立ち位置が似ている)ノードというものも存在します。
埋め込みにはそのような情報も反映されている方が好ましい場合もありそうです。

今回は近傍情報だけでなくグラフの大域的な情報も考慮した埋め込み手法について、書籍「グラフ深層学習」とその参考文献をもとに解説します。

読んでいただきたい方

  • GNN に興味があるけど勉強したことのない方
  • GNN を勉強し始めた方
  • 前回の記事を読んだよ、という方
  • GNN に詳しく「こいつちゃんとわかってるかどうかを確かめてやろう」という猛者

お断り

「グラフ深層学習」から(直接・間接)引用する場合には、
「グラフ深層学習」p.20
のようにして出典とさせていただきます。簡易的であることをお許しください。
(数式部分に関しては、都度出典を明記してはいません。この点もお許しください。)
また、本記事でのグラフは「無向グラフ」であることを前提にします。

グラフ埋め込みの簡単なお気持ち(復習)

前回の記事で解説した通り、グラフ埋め込みは
  グラフの各ノードを低次元ベクトル表現に写像すること
「グラフ深層学習」p.93
を目指します。
グラフには様々な特徴があり、後のタスクにおいて有用な特徴を抽出した埋め込み表現を得ることが目標です。

Deep Walk アルゴリズム

グラフの特徴のひとつに、ノード同士の位置関係(近さ)があります。近くにあるノードは埋め込み空間でも近くにいてほしい、と考えるのは自然なことです。
よって、ノードの位置関係を抽出して埋め込み表現を得る手法のうち、ベーシックなものが Deep Walkアルゴリズム でした。

前回の記事 でも解説しましたがDeep Walk アルゴリズムの概要は以下の通りです。

  1. 各ノードをスタートとするウォーク(ノードを遷移する道のようなもの)をランダムにたくさん生成する
  2. 各ウォークにおいて、適当な幅内に存在するノードの組(共起しているノード組)を抽出する
  3. 抽出頻度をできるだけ再現できるように埋め込み表現を学習する

前回の記事でも提示した、グラフ[1] とその埋め込みベクトルの可視化の具体例を示します。

グラフ01.png

グラフ03.png

グラフ上で近いノードが、埋め込み空間上でも近くなるように表現されていることがわかります。

pythonの実装コードも前回の記事に示しているのでご参考ください。

Deep Walk アルゴリズム の限界

近いノードが埋め込み空間で近くに位置しているのは良いことなのですが、グラフ全体におけるノードの立ち位置を意識すると、また少し見方が変わってきます。

上のグラフでは、2 と 4 がやや中心的なノードになっていて、1 と 3 は 2 に接続、5 は 4 に接続しています。1,3,5 いずれも、それ以外のノードとの接続はありません。
この場合、1,3,5 は 中心的なノードにのみ接続しているという点で共通した立ち位置にある ように感じられます。
一方、埋め込み空間においては (1,3) と 5 は大きく離れていて、上記のような立ち位置が反映されているとはいえません。

ノードの特徴として、各ノードの接続や近さだけでなく、グラフの大域的な情報も有用だと考えられます。
よって、このような情報も反映させた埋め込み表現の学習も行われています。今回の記事は、そのような学習方法に関する解説です。

グラフの大域情報

まずは、グラフの大域情報を定式化します。
この点について「グラフ深層学習」では、第2章で触れられています。そのうち中心性に関する情報のいくつかをご紹介します。

中心性:各グラフの重要度を表す指標
「グラフ深層学習」p.28

以下、この中心性を表す値を「グラフ深層学習」にならって中心性スコアと呼びます。

次数中心性

各ノードの 接続しているノードの数次数 といいます。
たとえばグラフ[1]における各ノードの次数は次のようになります。

ノード番号 1 2 3 4 5 6 7
次数 1 3 1 4 1 2 2

次数が近いノードは、それぞれのノードの周辺での立ち位置がよく似ている可能性があります。よって、埋め込み表現を学習する際に次数を考慮することで立ち位置を反映した表現になる効果がありそうです。

固有ベクトル中心性

次数中心性において、グラフ[1]の 1, 3, 5 はすべて次数 1 なので同等です。
一方で、(1, 3) が接続している 2 と、5 が接続している 4 を比べるとどうでしょうか。

次のように感じる方もいると思います。(いてほしい、、、)

人物コメント01.png

…まあとにかく、(1, 3) と 5 は次数が同じでも若干の違いがありそうです(ということにします)。

このように、単純な次数ではなく、周囲のノードの中心性も考慮して中心性を測る方法として 固有ベクトル中心性 があります。

SNSでいうと、ユーザーを評価する際に、友達の人数だけでなく友達の中に巨大なインフルエンサーがいるかどうかを評価の対象にする、というイメージです(あまり気分の良くない表現を承知で書いてます笑)。

固有ベクトル中心性の(簡単な)定式化

グラフの隣接行列 $A \in \lbrace{0,1}\rbrace ^{N \times N}$ の固有ベクトル $\boldsymbol{c}$ は、次の方程式を満たすベクトル $\boldsymbol{c}$ として得られます。

\displaylines{
A\boldsymbol{c} = \lambda\boldsymbol{c}
}
  • $A$ は対称行列なので、重複を許して固有値 $\lambda_{1}, \cdots, \lambda_{N}$ はすべて実数です。

このとき、次のことが言えるそうです。

固有値のうち絶対値が最大のものに対応する固有ベクトルは、すべての成分が正となるように取れる

このような最大固有値に対応する固有ベクトルの成分は、中心性スコアとして望ましい性質を持っていると考えられます。

  1. 最大固有値に対応しているので、多くの情報量をもっている
  2. すべての値が正なので、そのまま重要度の尺度として扱いやすい

よって、このように
 最大固有値に対応する固有ベクトルの成分
を中心性スコアに用いるが、固有ベクトル中心性です。

上記の最大固有値に関する性質は、ペロン・フロベニウスの定理から分かるそうです。
数学的な詳細は、こちらの記事が詳しいようです。
https://zenn.dev/nagayu71/articles/247cdd62eece92

固有ベクトル中心性の具体例

グラフ[1]の隣接行列において、最大固有値の固有ベクトル成分は次のようになります(ライブラリによってはマイナスがついた値が出力される場合がありますが、その場合はすべてに$-1$をかければよいです)。

ノード番号 固有ベクトル成分
1 0.15
2 0.38
3 0.15
4 0.61
5 0.25
6 0.43
7 0.43

4 と 2 のスコアが大きいことに加えて、1,3 のスコアと 5 のスコアに差があります。これは先ほど述べたノードの重要度の直観に沿ったものです。

その他の中心性

「グラフ深層学習」では、Katz中心性媒介中心性といった中心性に触れられています。
興味のある方は、一度書籍をご覧いただければと思います。

大域情報を保存する埋め込み表現の学習アルゴリズム(LOG)

「グラフ深層学習」では、大域情報を保存するためのアルゴリズムとして以下の論文による提案手法が紹介されています。

Preserving Local and Global Information for Network Embedding(Ma et al., 2017)

ここからは、この論文をもとに提案手法を簡単に解説します。

Ma et al. は、大域情報と局所情報(前回記事の共起情報)の両方を保存する
 LOGフレームワーク (Local and Global Information Preserved Embedding)
を提案しています。

LOG は、Deep Walk などで抽出される共起情報に加えて、中心性スコアをもとにしたノードの相対順位の保存を目指します。
「グラフ深層学習」p.113

前提

前回の記事で紹介した通り、Deep Walk アルゴリズムによる学習は以下の用意に行われていました。

  • ランダムウォークから得られた共起リストをもとに、すべての組 $(v_{con} | v_{cen})$ の共起情報を抽出(頻度を求めて確率として扱う)
  • 埋め込みドメインの行列 $W_{cen}$ および $W_{con}$ をもとに、モデル$p(v_{con} | v_{cen})$ を構成
  • 共起情報と $p(v_{con} | v_{cen})$ から目的関数をもとに $W_{cen}$ および $W_{con}$ を学習

$W_{cen}$ および $W_{con}$ の学習過程に、相対順位を保存するアルゴリズムを追加していきます。

相対順位のモデル化

全ノードに大域的な重要度スコア(たとえば固有ベクトル中心性スコア)が割り振られているとき、この値をもとにした相対順位を考えます。
この相対順位に関する確率は、次のようにモデル化することができます。

\displaylines{
p_{\mathrm{global}} = \prod_{1 \leq i < j \leq N}p(v_{i}, v_{j})
}

$p(v_{i}, v_{j})$は、ノード $v_{i}$ が $v_{j}$ よりも上位に順位付けされる確率です。

埋め込み空間の表現を用いて、この確率 $p(v_{i}, v_{j})$ は次のようにモデル化することができます。

\displaylines{
p(v_{i}, v_{j}) = \sigma(f(\boldsymbol{u}_{i|cen}, \boldsymbol{u}_{i|con})-f(\boldsymbol{u}_{j|cen}, \boldsymbol{u}_{j|con}))
}

$\boldsymbol{u_{*|cen}}$ と $\boldsymbol{u_{*|con}}$ はノード$v_{*}$ の $W_{cen}$、$W_{con}$ における埋め込みベクトルです。これを何かしらの関数 $f(x)$ で変換します。
また $\sigma(x)$ はシグモイド関数です。$\boldsymbol{w}$は$\boldsymbol{u_{i}}, \boldsymbol{u_{j}}$と一緒に学習されるパラメータです。確率のモデル化ですから、シグモイド関数を使うのは自然な流れですね。

関数 $f$ による変換について、Ma et al. では例として次のような線形変換を挙げています。

\displaylines{
\boldsymbol{w}^{T}\boldsymbol{u_{i|con}} + \boldsymbol{w'}^{T}\boldsymbol{u_{i|cen}}
}

$\boldsymbol{w}$、$\boldsymbol{w'}$は学習対象のパラメータです。後述の実装では、この例に従って実装しています。

目的関数

相対順位の確率モデルをもとに、次の目的関数を最小化することで、大域情報の保存を目指すことができます。

\displaylines{
L_{\mathrm{global}} = -\log{p_{\mathrm{global}}}
}

また、DeepWalkアルゴリズムの 目的関数は以下のように定式化されていました。

\displaylines{
L_{\mathrm{local}} = -\sum_{v_{con} \in V}{\sum_{v_{cen} \in V}{\frac{\mathrm{exp}(f_{con}(v_{con})^{T}f_{cen}(v_{cen}))}{\sum_{v\in V}\mathrm{exp}(f_{con}(v)^{T}f_{cen}(v_{cen}))}}}
}

LOGの学習アルゴリズムにおける目的関数は、これら2つの目的関数を用いて次のように定式化されます。

\displaylines{
L = L_{\mathrm{local}} + \lambda L_{\mathrm{global}}
}

$\lambda$ は $L_{\mathrm{global}}$ をどれくらい重要視するかを決める、ハイパーパラメータです。

LOGアルゴリズムのまとめ

いったん、LOGアルゴリズムを整理します。

  1. 共起情報を学習するアルゴリズムに、大域情報を保存するアルゴリズムを追加したものである
  2. 大域情報はノードの相対順位で定式化される
  3. 相対順位は(シグモイド関数を用いた)確率でモデル化される
  4. 目的関数では、大域情報をどれくらい重要視するかをハイパーパラーメータによって調整する

目的とアルゴリズムの対応が、結構シンプルで明確だと思われます。
次項では、このLOGアルゴリズムを実装し、埋め込みベクトルがどのように学習されるかを確認します。

LOGアルゴリズムの実装

ここからは LOGアルゴリズムの実装を通して、理解を深めていけたらと思います。

前回記事と同様に、以下の方針で実装します。

  • pythonを使用
  • GNN 用のライブラリは使用しない
    • ただし、学習プロセスにおいて pytorch を使用

扱うグラフ

簡単な例として、次の グラフ [2] を埋めこみます。
グラフ07.png

15個のノードを、3つのグループとそれをつなぐノードに分類しています。

実装コード

まずは、学習完了までのコード全体を示します。

論文で提案されているアルゴリズムから、最適化を行うタイミングなどを若干変えています。

import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim

# シードの設定
seed = 42
np.random.seed(seed)
torch.manual_seed(seed)

# ハイパーパラメータの設定
n = 15  # ノード数
conb_list = [
    (1, 2), (2, 3), (2, 4),  # グループ1
    (9, 5), (9, 12), (5, 6), (5, 8),  # グループ2
    (12, 13), (12, 14), (12, 15),  # グループ3
    (2, 10), (10, 11), (11, 9), (7, 5), (11, 5)  # グループ間のエッジ
]

walk_iter = 30  # ランダムウォークの反復回数
walk_length = 10  # ランダムウォークの長さ
window_size = 3  # 共起性抽出のウィンドウサイズ
embedding_dim = 2  # 埋め込み次元
learning_rate = 0.03  # 学習率
lanmda = 0.0  # 大域情報を重要視する程度
epochs = 500  # エポック数

# 関数定義

def create_adjacency_matrix(n, conb_list):
    """
    隣接行列を作成する関数
    """
    adjacency_matrix = np.zeros((n, n))
    for conb in conb_list:
        adjacency_matrix[conb[0]-1, conb[1]-1] = 1
        adjacency_matrix[conb[1]-1, conb[0]-1] = 1
    return adjacency_matrix

def random_walk(adjacency_matrix, start_node:int, walk_length:int) -> list:
    """
    指定したノードからランダムウォークを実行する関数
    """
    walk = [start_node]
    for i in range(walk_length):
        next_node = np.random.choice(np.where(adjacency_matrix[walk[-1]-1] == 1)[0]) + 1
        walk.append(next_node)
    return walk

def random_walk_all_node(adjacency_matrix, iteration:int, walk_length:int) -> list:
    """
    全ノードからランダムウォークを実行する関数
    """
    R = []
    for node in range(1, n+1):
        walks = []
        for i in range(1, iteration+1):
            walk = random_walk(adjacency_matrix, node, walk_length)
            walks.append(walk)
        R = R + walks
    return R

def extract_cooccurrence(R, window_size:int) -> list:
    """
    ランダムウォークの結果から共起性を抽出する関数
    """
    cooccurrence = []
    for walk in R:
        for i in range(len(walk)):
            for j in range(i+1, min(i+window_size+1, len(walk))):
                cooccurrence.append([walk[i], walk[j]])
            for j in range(i-1, max(i-window_size-1, -1), -1):
                cooccurrence.append([walk[i], walk[j]])
    return cooccurrence

def create_cooccurrence_matrix(cooccurrence, n:int) -> np.ndarray:
    """
    共起性データから共起行列を作成する関数
    """
    cooccurrence_matrix = np.zeros((n, n))
    for c in cooccurrence:
        cooccurrence_matrix[c[0]-1, c[1]-1] += 1
        cooccurrence_matrix[c[1]-1, c[0]-1] += 1
    return cooccurrence_matrix

# 隣接行列の作成
adjacency_matrix = create_adjacency_matrix(n, conb_list)

# 全ノードからランダムウォークを実行
R = random_walk_all_node(adjacency_matrix, walk_iter, walk_length)

# 共起性の抽出
cooccurrence = extract_cooccurrence(R, window_size)

# 共起行列の作成
co_occurrence_matrix = create_cooccurrence_matrix(cooccurrence, n)

# 共起行列の行ごとに確率分布に変換
co_occurrence_probs = co_occurrence_matrix / co_occurrence_matrix.sum(axis=1, keepdims=True)
co_occurrence_probs_tensor = torch.tensor(co_occurrence_probs, dtype=torch.float32)


# 大域スコアの取得

def create_score_by_eigen(adjacency_matrix):

    # 隣接行列の固有値ベクトル
    eigenvalues, eigenvectors = np.linalg.eig(adjacency_matrix)

    # 最大固有値に対応する固有ベクトルの取得
    max_eigenvalue_index = np.argmax(eigenvalues)
    score = eigenvectors[:, max_eigenvalue_index]

    # 固有ベクトルの符号を正に揃える
    if np.any(score < 0):
        score = -score

    score = np.array(score)
    score = torch.tensor(score, dtype=torch.float32)

    # scoresの行列を作成
    score_matrix = score.unsqueeze(0).repeat(adjacency_matrix.shape[0], 1)

    # scoresの差分行列を作成
    score_diff = score_matrix.t() - score_matrix

    # 条件を満たすインデックスを取得
    condition_indices = (score_diff >= 0).nonzero(as_tuple=True)

    return condition_indices


# 相対順位に関するインデックスを作成
condition_indices = create_score_by_eigen(adjacency_matrix)

# 埋め込み行列の定義
W_cen = nn.Parameter(torch.randn(co_occurrence_matrix.shape[0], embedding_dim))
W_con = nn.Parameter(torch.randn(co_occurrence_matrix.shape[0], embedding_dim))

# パラメータの定義
w_1 = nn.Parameter(torch.randn(embedding_dim))
w_2 = nn.Parameter(torch.randn(embedding_dim))

# 最適化手法の定義
optimizer = optim.Adam([W_cen, W_con, w_1, w_2], lr=learning_rate)

# クロスエントロピー損失関数の定義
loss_function = nn.CrossEntropyLoss()


# 学習ループ
previous_loss = float('inf')
patience = 10  # 何エポック連続で損失が改善しない場合に学習を停止するか
trigger_times = 0

for epoch in range(epochs):
    optimizer.zero_grad()

    # 共起行列の再構築
    logits = torch.matmul(W_cen, W_con.t())

    # ソフトマックスを適用して確率に変換
    log_probs = nn.functional.log_softmax(logits, dim=1)

    # Fの差分行列を計算
    F = torch.mv(W_cen, w_1) + torch.mv(W_con, w_2)
    F_diff = F.unsqueeze(1) - F.unsqueeze(0)

    # 条件を満たすペアに対してシグモイドを計算
    log_sigmoid = nn.functional.logsigmoid(F_diff[condition_indices])

    # 損失の計算
    loss = loss_function(log_probs, co_occurrence_probs_tensor) - lanmda * torch.sum(log_sigmoid)

    # 勾配の計算とパラメータの更新
    loss.backward()
    optimizer.step()

    # 進捗表示
    if (epoch + 1) % 10 == 0:
        print(f'Epoch {epoch + 1}/{epochs}, Loss: {loss.item()}')

    # 損失の変化を確認
    if loss.item() < previous_loss:
        previous_loss = loss.item()
        trigger_times = 0
    else:
        trigger_times += 1

    # 一定エポック連続で改善がない場合、学習を停止
    if trigger_times >= patience:
        print(f'Early stopping at epoch {epoch + 1}')
        break

Deep Walk アルゴリズムに追加した点

Ma et al. を参考に、Deep Walk アルゴリズムに追加した点を示します。

  • パラメータ$\lambda$の設定
lanmda = 0.0  # 大域情報を重要視する程度
  • 相対順位を抽出するためのindexを取得する関数
# 大域スコアから順序情報の取得

def create_score_by_eigen(adjacency_matrix):

    # 隣接行列の固有値ベクトル
    eigenvalues, eigenvectors = np.linalg.eig(adjacency_matrix)

    # 最大固有値に対応する固有ベクトルの取得
    max_eigenvalue_index = np.argmax(eigenvalues)
    score = eigenvectors[:, max_eigenvalue_index]

    # 固有ベクトルの符号を正に揃える
    if np.any(score < 0):
        score = -score

    score = np.array(score)
    score = torch.tensor(score, dtype=torch.float32)
    print(score)

    # scoresの行列を作成
    score_matrix = score.unsqueeze(0).repeat(adjacency_matrix.shape[0], 1)

    # scoresの差分行列を作成
    score_diff = score_matrix.t() - score_matrix

    # 条件を満たすインデックスを取得
    condition_indices = (score_diff >= 0).nonzero(as_tuple=True)

    return condition_indices

# 相対順位に関するインデックスを作成
condition_indices = create_score_by_eigen(adjacency_matrix)
  • 相対順位のindexをもとにlog-シグモイド関数の値を計算
    # Fの差分行列を計算
    F = torch.mv(W_cen, w_1) + torch.mv(W_con, w_2)
    F_diff = F.unsqueeze(1) - F.unsqueeze(0)

    # 条件を満たすペアに対してシグモイドを計算し、和を取る
    log_sigmoid = nn.functional.logsigmoid(F_diff[condition_indices])
  • 損失関数の計算
    # 損失の計算
    loss = loss_function(log_probs, co_occurrence_probs_tensor) - lanmda * torch.sum(log_sigmoid)

結果の可視化

$\lambda$ の値を変化させて、埋め込み結果を可視化します。

λ=0 の場合

$\lambda=0$ の場合は大域情報を全く考慮しない、すなわち Deep Walk アルゴリズムと同じです。
グラフ08_加工.png

色が完全に分かれていて、グループごとにノードが固まっています。Deep Walk の特徴ですね。

λ=0.1 の場合

$\lambda=0.1$ に設定し、大域情報を少し保存するように学習させると、次のようになりました。

グラフ09_加工.png

グループごとに分かれてはいますが、$\lambda=0$ と比べて以下の点が変化しています。

  • グループごとの次数 $1$ のノード同士が近づいている
  • 全体的に木構造のようになっている
    • 順位を保存しようとしていることがわかる

λ=0.7 の場合

大域情報を多めに保存するように学習させた場合です。

グラフ10_加工.png

$\lambda=0.1$ の場合に挙げた点が、より顕著になったように感じます。

固有ベクトル中心性のスコアの確認

今回のグラフで、固有ベクトル中心性のスコアは以下のようになっていました。

ノード番号 固有ベクトル成分
1 0.046
2 0.125
3 0.046
4 0.046
5 0.547
6 0.199
7 0.199
8 0.199
9 0.462
10 0.207
11 0.442
12 0.279
13 0.102
14 0.102
15 0.102

ノード $5$ のスコアが一番高いため、$5$ を頂点とするようなイメージで埋め込まれたのだと考えられます。

発展事項

当然ながら、中心性の計算方法によって、埋め込み結果はいろいろと変わり得るのだと思います。
また、コミュニティ構造の保存など、別の観点での情報を保存するような手法もあるようです。

まとめ

以上、GNNビギナーによるグラフ埋め込みpart2の解説でした。
目的を意識しながら、アルゴリズムの改善部分を読み解くと理解しやすいように思います。

感想・ご指摘などいただけると大変うれしいです。
そして、ぜひ「グラフ深層学習」をまだ購入していない方は購入し、みんなで読み進めていきましょう!

0
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?