5
10

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

Graph Neural Network を学ぶ①

Last updated at Posted at 2024-06-18

はじめに

 こんにちは。今回は深層学習に関するお勉強の一環で「Graph Neural Network(GNN)」について自分なりに整理してみようと思います。
これまでCNNやTransformerなどのモデルについては触れたりしてたのですが、GNNはあまりその機会がなかったので勉強してみたいというモチベーションがありました。
(なので、記述・認識に誤りがあればご指摘いただけますと励みになります🙇)

1. Graph Neural Network(GNN)

1.1. グラフの構造を持つデータ

 実世界に存在するデータには、頻繁にグラフの構造を持つデータが存在しています(ex. SNSのユーザ間の関係、論文の引用関係、分子の持つ化学構造)。

最近論文調査を行う際にConnected Papersというサービスを使うことがあるのですが、こちらで確認できる論文同士の引用関係はまさにこの種類のデータの例と言えるわけです(下図)。

image.png

このようなグラフデータは、固定された順序や位置を持たない(非ユークリッド)データであり、深層学習で用いられる画像やテキストなどの固定の次元を持つ(ユークリッド)データと比較すると、データ構造の柔軟性が高いといえます。
複雑な関連性をモデルとして扱える一方で学習の難易度や計算コストなどの問題に対処するようなアルゴリズムの設計が必要になります。


 ここで、もう少し一般的なグラフの定義について、離散数学的(参考)にまとめると、

  • 1つのグラフ($G(V,E)$)は以下の2つの部分から構成されます。
  1. 頂点(vertex)、点(point)またはノード(node)と呼ばれるものの集合$V=(V(G)$
  2. 辺(edge)と呼ばれるものの集合$E=E(G))$(無機が付与されない(=順序のない)2つの頂点の対)

例えば、下図のグラフには4つのノード $V = \lbrace A, B, C, D \rbrace$ があります。
辺の集合 $E$ は以下のペアで構成されます。
$E = \lbrace (A, B)$,$(A, C)$,$(B, D)$,$(C, D) \rbrace$

1.2. Guraph Neural NetWork(GNN)とは

  グラフデータ上で動作するニューラルネットワークのことをGuraph Neural NetWork(GNN)といいます。冒頭にも記載したようなしたようなSNSのユーザ間の関係をはじめとする様々なケースで利用されています。

image.png

(画像引用元 Fig.6)

GNNでは、グラフで表現されるデータに対して、いろいろな種類の推論を行っていきます。詳細は後述しますが、ノードのクラス分類、クラスタリング、グラフのクラス分類、エッジの接続予測などがその例です。これらの推論では、グラフの構造の他に、最大2つの情報を取り出して、ネットワークに入力します(下記、先ほどのSNSのユーザ関係を例に記載)。

0)グラフの構造 (ex.上図(e))
1)ノードが保持する特徴量(ex.ユーザ(=ノード)ごとの年齢・性別・趣味等)
2)エッジに付帯する特徴量(ex.やりとりの頻度、いつからフォローしているか等)

ノードを$i$とした時に、1)ノードの保持する特徴量は$ \mathbf{x_i} \in \mathbb{R}^D $として表現されます。(D=特徴量の次元数)
今回参考にさせていただいた書籍では、2つの特徴量のうち、主に前者に焦点が当てられていました(後者はあまりメジャーではないとのこと)。
グラフ内の各ノードはそれぞれ特徴量ベクトルを持ち、これがニューラルネットワークのユニットとして機能します。

1.3. GNNの基本演算

 一概にGNNといっても、その内部で行われる基本的な演算やアーキテクチャは多岐にわたります。GNNの主要な目的は、グラフの構造情報やノード・エッジの特徴量を効果的に活用し、様々な推論タスクを遂行することです。そのために、多くの手法が提案されており、それぞれが異なる方法でノードの情報集約や更新を行います。

詳細

レビュー論文「A Comprehensive Survey on Graph Neural Networks」(Wu et al., 2021)に基づき、いくつかの主要なGNN手法を紹介します。

  1. リカレントGNN(Recurrent GNN, RecGNN)

    • ノード間のメッセージパッシングを繰り返し、ノードの状態が収束するまで情報を伝達します。
    • 例:GNN*(Scarselli et al., 2009)、GraphESN

  2. 畳み込みGNN(Convolutional GNN, ConvGNN)

    • グラフデータに対して畳み込み操作を一般化し、隣接ノードからの情報を集約します。
    • スペクトルベース:グラフのスペクトル特性を利用して畳み込みを定義します。例:Spectral CNN, ChebNet, GCN
    • 空間ベース:隣接ノードからの情報を直接集約します。例:GraphSage, GAT

  3. グラフオートエンコーダ(Graph Autoencoder, GAE)

    • ノードを低次元の潜在空間にエンコードし、そこからグラフ構造を再構築します。
    • 例:GAE, VGAE

  4. 空間-時間GNN(Spatial-Temporal GNN, STGNN)

    • 時間的変化を伴うグラフデータを扱うために、空間的および時間的な依存関係を同時にモデル化します。
    • 例:ST-GCN, DCRNN

しかし、基本的な仕組みはほぼ共通しているため、本説では基本となる演算についてまず触れていきたいと思います。


まず、各ノードは状態$$ \mathbf{z}_i \in \mathbb{R}^D$$を持ちます。各ノード$i$に特徴$\mathbf{x}_i$が与えられたとして、GNNは、最初に各ノードの状態を$\mathbf{z}_i = \mathbf{x}_i $ と初期化し、その後各ノードの状態$\mathbf{z}_i$を、次のように反復的に更新します(l = 1, 2, …)。

$$ \mathbf{z}_i^{(l+1)} = \mathbf{f}^{(l)}(\mathbf{z}
_i^{(l)}, \lbrace \mathbf{z}_j^{(l)} \mid j \in \mathcal{N}_i \rbrace \ ) \tag{1.3.1} $$

  • $\mathbf{z}_i^{(l)}$... $l$回目の更新後のノード$i$の状態。($\mathbf{z}_i^{(0)} = \mathbf{x}_i$)
  • $\mathcal{N}_i$ ... ノード$i$の隣接ノード

$\mathbf{z}_i$の更新式(1.3.1)は、ノード$i$に接続するノード$j(\in \mathcal{N}_i)$の状態$\mathbf{z}_j^{(l)}$を集約し、自ノードの状態$\mathbf{z}_i^{(l)}$と合わせて新しい$\mathbf{z}_i^{(l+1)}$を得る計算といえます。この計算は、グラフのすべてのノードについて独立に並行して行います。なお、$\mathbf{f}^{(l)}$には学習可能な重みが含まれ、モデルの種類によって異なる部分になります(下図)。

image.png

image.png

((画像引用元 図7.5))

更新の繰り返しは多層ネットワークで表現できます(上図)。各更新ステップ($l = 1, 2, \ldots$)はネットワークの1層に対応し、1回の更新ごとにノードの現状態が隣接ノードに伝播します。このようにして、ノードの情報はグラフ上の複数のエッジを介して他のノードに伝播されます。書籍では、より詳細に以下のように記載されていました

1回の更新のたびにノードの現状態が隣接ノードに伝播するので、$L$回の更新を行うと、各ノードの情報は、 グラフ上$L$個のエッジをたどって到達可能な範囲にあるノードに伝播します。

(出典:岡谷貴之. 深層学習 / 岡谷貴之著. 改訂第2版, 東京, 講談社, 2022.)

各ノードの最新の状態$\mathbf{z}_i^{(L)}$は、初期状態$\mathbf{x}_i$に、周辺ノードの状態を反映したものになります。この更新により得られるノードの状態$\mathbf{z}_i^{(L)}$を ノードの埋め込み(embedding) と呼び、さまざまな推論タスクで重要な役割を果たします。

 GNNはCNNなどと異なり、層間の伝播計算(式(1.3.1))において、次の層に接続されるユニットがグラフの構造を反映して異なる(多様化する) 点が特徴です。具体的には、ノード$i$に対応するユニットが受け取る入力は、要素数$\mid\mathcal{N}_i\mid$の集合の形をとります。この点において、GNNとTransformerには共通点があります。GNNでは「隣接ノード間の相互作用」を考慮し、Transformerでは「全要素間の相互作用」を考慮します。言い換えれば、Transformerは全ノードが互いに接続されたグラフ構造を扱っているとも言えます。


さて、GNNの全体を眺めてその特徴をなんとなく捉えられました。次は$\mathbf{z}_i$の更新式(1.3.1)が含む$\mathbf{f}^{(l)}$について考えます。改めて整理すると、更新のための処理は大きく1)隣接ノードの集約(aggregate)2) 自ノードの状態との結合(combine)に分けられます。

**参考**

後述するGCNの連続する2つの層に着目すると、1次元CNNのそれとの特徴的な違いが見えてきます。
前者は、後者のように前の層からの入力を「重み」を使って区別せず、「自ノードかそれ以外か」によってのみ区別されるわけですね。

image.png

2. 代表的なGNN

GNNにはどのような種類があるのかについては前章▶️詳細に記載しました。本章では、その中でいくつかの代表的なGNNを取り上げ、更新式の設計について眺めていきたいと思います。

2.1. Graph Convolutional Network(GCN)

グラフ畳み込みニューラルネットワークの更新式は以下の通りです(論文)。

\mathbf{z}_i^{(l+1)} = ReLU (\sum_{j \in \mathcal{N}_i \cup \lbrace i \rbrace} \frac{1}{\sqrt{deg(i)+deg(j)}} W^{(l+1)}z_j^{(l)} ) \tag{2.1.1}
  • $deg(i)$、$deg(j)$ ... ノード$i$、$j$の次数(=ノードに入るエッジの数)

これらは、スペクトラルグラフ理論から導出されており、これらの平方根で除算することで、ノードごとにエッジの数が異なることの影響をなくす効果があります。

詳細

正規化の必要性

  • グラフデータでは、各ノードの次数が異なる場合が多く、次数が大きいノードは多くの隣接ノードから情報を受け取り、次数が小さいノードは少ない情報しか受け取りません。
  • このため、単純に隣接ノードからの特徴量を合計するだけでは、次数が異なるノード間で情報量に偏りが生じます。次数が大きいノードの情報が過剰に強調され、小さいノードの情報が埋もれてしまうことがあります。

式(2.1.1)について
ノード $i$ の特徴量 $\mathbf{z}_i^{(l+1)}$ を更新するために、以下の手順が取られます:

  1. 隣接ノードの特徴量の重み付け:

    • 隣接ノード $j$ の特徴量 $\mathbf{z}_j^{(l)}$ に重み行列 $W^{(l+1)}$ を掛けます。

  2. 次数に基づくスケーリング:

    • 各隣接ノード $j$ の特徴量を、そのノード $i$ および $j$ の次数の平方根でスケーリングします。具体的には、$\frac{1}{\sqrt{deg(i) + deg(j)}}$ という因子を掛けます。(次数をそのまま使用すると、次数の大きいノードは過剰に抑制され、次数の小さいノードは過剰に強調される可能性があります。平方根をとることで、スケーリングの度合いが適度に調整され、極端な抑制や強調が緩和されます。)

スケーリングの効果

  • 平滑化:次数が大きいノードと小さいノードの情報が均等に扱われるようになります。これにより、各ノードが受け取る情報が平滑化され、過剰に強調されることが防がれます。
  • 対称性の確保:ノード $i$ とその隣接ノード $j$ の関係を対称的に扱うことで、情報の双方向性が保たれます。

スペクトルグラフ理論 ... グラフの性質をラプラシアン行列の固有値や固有ベクトルを通じて分析する理論です。この理論は、グラフ構造とその性質を深く理解するのに役立ちます。(こちらの記事の作者様は非常にわかりやすくGNNに使う基礎的な理論をまとめてくださっています。)

2.2. Graph Attention Network (GAT)

グラフ注意ネットワーク(GAT)は、ノード間の関連性を考慮して情報を集約するための手法です。GATの更新式は以下の通りです:

\mathbf{z}_i^{(l+1)} = \text{ReLU} \left( \sum_{j \in \mathcal{N}_i \cup \lbrace i \rbrace} a_{ij}h_j^{(l)} \right) \tag{2.2.1}
  • GATでは、各ノード $i$ が自分自身とその隣接ノード $j$ との関連性を計算します。この関連性は、各ノードの特徴量を用いて計算され、どの隣接ノードが重要かを示す注意の重み $a_{ij}$ として表現されます。
  • ノード $i$ は、計算された注意の重みを用いて、自身と隣接ノードの特徴量を重み付けし、新しい特徴量$h_j^{(l)}$を生成します。これにより、重要な隣接ノードの情報が強調され、重要でないノードの情報は抑制されます。
詳細
  1. 特徴量の重み付け

    • 各ノード $j$ の特徴量 $\mathbf{z}_j^{(l)}$ は、重み $W^{(l+1)}$ を用いて変換され、新たな特徴量 $h_j^{(l)}$ を生成します。
    • $h_j^{(l)} = W^{(l+1)}\mathbf{z}_j^{(l)}$

  2. 注意の重みの計算

    • ノード $i$ と隣接ノード $j$ との関連性を表す注意の重み $a_{ij}^{(l)}$ を計算します。これは、ノード $i$ とノード $j$ の変換された特徴量 $h_i^{(l)}$ と $h_j^{(l)}$ に基づいて計算されます。
    • $a_{ij}^{(l)} = \text{softmax}[a(h_i^{(l)}, h_j^{(l)})]$

  3. 特徴量の集約

    • ノード $i$ の新しい特徴量
    • $\mathbf{z}_i^{(l+1)}$は、隣接ノード$j$の注意の重みをかけて集約したものです。
    • これにより、ノード $i$ は自身とその隣接ノードから重要な情報を取り入れ、新しい特徴量ベクトルを生成します。

2.3. Graph Isomorphism Network (GIN)

 こちらの論文では、前述のGCNやGATなどのGNNが本質的にはグラフ上の平滑化計算にすぎないことやそれにより、更新回数(=層の数)が多くなると高周波成分が失われ、場合によっては性能が低下する傾向があるということが指摘されています(詳細は原論文をご覧ください)。

GCNにおける高周波成分の喪失の原因

平滑化効果:

  • GCNの基本的な操作は、隣接ノードの特徴を平均化することです。このプロセスは、各ノードの特徴がその隣接ノードの特徴と混ざり合い、平均化されることで、ノードごとの特徴の差異が次第に小さくなることを意味します。
  • この平滑化操作は、特に多層GCNで顕著であり、層を重ねるごとにノードの特徴量が均一化され、細かな変動(高周波成分)が減少します。

過度な情報拡散:

  • GCNでは、各層でノードの特徴がその隣接ノードに伝播します。これにより、ノード間の情報が広がり、特に層が深くなるほど、広範囲のノード間で情報が共有されるようになります。
  • しかし、これにより、ノードの個別の特徴が埋もれてしまい、全体の特徴が均一化されます。これが、高周波成分が失われる原因の一つです。

特定の周波数成分の減衰:

  • GCNの設計は、グラフ信号処理における低周波成分(ノード間のゆっくりとした変動)を強調しやすく、高周波成分(急激な変動)を減衰させる傾向があります。これは、GCNのフィルタがスペクトルドメインでの低周波成分を強調し、高周波成分を減衰させるためです。
  • このため、層が深くなるほど高周波成分が失われ、結果的にグラフの詳細な構造情報が失われることになります。
GATにおける高周波成分の喪失の原因

隣接ノードの情報の平均化効果

  • 重み付け平均:
    • GATでは、各隣接ノードの特徴が注意重み $a_{ij}$ によって重み付けされます。この重み付け平均は、注意スコアに基づいて情報を平均化する効果を持ちます。特に、注意重みが一様に近い場合、結果的に隣接ノードの特徴の平均値に近づきます。

    • 注意スコアが大きく異ならない場合、これは単純な平均と似た効果を持ち、ノード間の差異を平滑化します。


  • 注意スコアの制約:
    • 注意スコアは、ノード間の特徴の違いに基づいて計算されますが、実際には注意重みが完全に異なることは稀です。多くの場合、注意重みはある程度一様に分布するため、重み付け平均が実質的に平滑化として機能します。

  1. 平滑化効果:

    • GCNはノードの特徴を隣接ノードの特徴と平均化します。このため、各層でノード間の特徴の差が小さくなり、詳細な情報(高周波成分)が失われます。

  2. 情報の拡散:

    • 各層でノードの特徴が隣接ノードに広がるため、個々のノードの特徴が薄まり、全体的に均一化します。

  3. 低周波成分の強調:

    • GCNは低周波成分(ゆっくりとした変動)を強調しやすく、高周波成分(急激な変動)を減衰させます。これにより、詳細な構造情報が失われます。

これを改善するためのGNNとして、グラフ同型ネットワーク (GIN)が提案されました。更新式は以下で表されます。

z_i^{(l+1)} = MLP^{(l+1)} \left( \left(1 + \epsilon^{(l+1)}\right) z_i^{(l)} + \sum_{j \in N_i} z_j^{(l)} \right) \tag{2.3.1}

ここで MLP はいくつかの全結合層を重ねたもので、ここに学習対象となる重みがあります。また $\epsilon^{(l+1)}$も学習で決定します。

高周波成分の喪失を回避できる理由

  1. 非線形変換を使用:

    • GINは、全結合層(MLP)を用いて非線形変換を行います。この非線形変換は、ノード特徴の複雑なパターンを捉える能力があり、高周波成分の保持に貢献します。

  2. 加重平均ではなく総和を使用:

    • GINは、ノードの特徴を加重平均するのではなく、単純に総和します。これにより、各ノードの特徴が均一化されることなく保持され、高周波成分が失われにくくなります。
    • 平均化はノード特徴の平滑化を引き起こしやすいが、総和はその効果を軽減します。

  3. 学習可能なパラメータ $\epsilon$:

    • GINでは、学習可能なパラメータ $\epsilon$ を用いることで、自己ループの強度を調整し、ノード自身の特徴の重み付けを調整します。これにより、ノード間の特徴の差異をより効果的に保持します。
    • 数式中の $\left(1 + \epsilon^{(l+1)}\right)$ は、各ノードの自己ループの重要性を動的に調整する役割を果たします。

3. グラフを用いて行う推論

グラフを対象に行われる様々な形の推論があります。本章ではその種類を示し、代表的な問題を紹介します。ほとんどの場合において、1つのグラフとその各ノードの特徴 $\mathbf{x}_i$ が与えられたとき、GNNによって $\mathbf{x}_i$ を初期値にしてノードの状態 $\mathbf{z}_i^{(l)}$ を繰り返し更新し、ノードの埋め込み $\mathbf{z}_i^{(L)}$ を得ます。その後、$\mathbf{z}_i^{(L)}$ を使って問題ごとに異なる計算を行い、特定の形の出力を得ます。

3.1. ノードのクラス分類 (Node Classification)

 グラフの各ノードを、あらかじめ定められたクラスの1つに分類する問題です。多くの場合、半教師学習として定式化され、1つの大きなグラフが与えられ、そのノードの一部に正解クラスのラベルが付与されます。この条件のもとで、他のノードのクラスを予測します。

 ベンチマークタスクの一例として、論文の属性と論文間の引用関係を基に、研究分野などの論文の属性を予測する問題があります。具体的には、論文をノードとして、引用の有無をエッジとしてグラフを構成します。引用関係は有向である場合がありますが、向きを無視して無向グラフとすることもあります。推論では、GNNでノードの状態を更新した後の各ノードの状態 $\mathbf{z}_i^{(L)}$ を元に、全結合層やソフトマックス関数を適用して各クラスのスコアを得ます。学習過程では、交差エントロピーを最小化してGNN内のパラメータを最適化します。

3.2. ノードのクラスタリング (Node Clustering)

 グラフを構成するノードを、その埋め込み $\mathbf{z}_i^{(L)}$ を用いてクラスタリングする手法です。ノードの特徴や関係性に基づいてクラスタリングを行い、類似したノードをグループ化します。

3.3. グラフのクラス分類 (Graph Classification)

 与えられた1つのグラフをクラス分類します。最も単純な方法では、グラフの全ノードにわたってノードの埋め込み (z_i^{(L)}) を集約(プーリング)し、1つの特徴ベクトルを得て、これを元にグラフを分類します。集約には平均を使います:

z_G = \frac{1}{M} \sum_{i=1}^M z_i^{(L)}

ただし、こうして得られた特徴にはグラフの構造が(ノードの埋め込みに反映されているものを除き)反映されません。
 この推論を行う具体的な例の1つにタンパク質をグラフとして表現し、その機能を分類する問題があります。例えば、タンパク質の構造をグラフとして捉え、各ノードはアミノ酸を表し、エッジはアミノ酸間の結合を表します。これを基に、タンパク質が酵素であるか否か、またはどの種類の酵素であるかを分類します。

3.4. 接続予測 (Link Prediction)

 グラフの指定したノード間にエッジがあるかどうかを予測する問題です。不完全なグラフ、つまりノード間のエッジの情報が部分的に欠損している場合や、グラフそのものが時間とともに成長していくような場合を対象とします。用途は、商品、広告、SNSでのユーザ間のつながりなどのレコメンデーションや、知識ベースの構築などです。

3.5. グラフ自己符号化器 (Graph Autoencoder)

 グラフ自己符号化器は、グラフ構造を低次元空間にマッピング(データを圧縮してより少ない次元で表現)し、その低次元表現から元のグラフ構造を再構成するための手法です。デコーダは、得られたノードの埋め込み $\mathbf{z}_i^{(L)}$ を用いて、元のグラフ構造を再構成します。具体的には、2つのノード間の類似度を計算し、接続有無の確率を推定します:

p(A_{ij} = 1) = \sigma(\mathbf{z}_i^{(L)} \cdot \mathbf{z}_j^{(L)})

ここで、$\sigma$ はシグモイド関数を示し、ノード $i$ とノード $j$ 間のエッジの存在確率を計算します。これにより、ノード間の類似度に基づいてエッジの有無を予測します。

応用例

  • リンク予測:
    • ソーシャルネットワークや推薦システムにおいて、新たな友達関係や商品推薦を行うために用いられます。
  • ネットワーク再構成:
    • 不完全なネットワークデータの補完や、異常検知のためにネットワークを再構成します。

3.6. 2部グラフの接続予測 (Bipartite Graph Link Prediction)

 接続予測のよくある応用の1つにレコメンデーション、例えば商品を潜在的な購入者に推薦する問題があります。ユーザと商品をそれぞれ、個別のノードで表し、ユーザによる商品の購入有無をエッジとして表現します。得られるグラフは2部グラフです。

3.7. マルチグラフに対する推論 (Multigraph Inference)

 マルチグラフは、ノード間に複数のエッジが存在する(し得る)グラフです。これは、複数の種類の関係性を持つデータを表現するのに適しています。例えば、ソーシャルネットワークでは、ユーザ間の友人関係、フォロー関係、メッセージ送信関係などを一つのマルチグラフで表現できます。

 ノード間の複雑な関係性をモデル化し、それに基づいて推論を行う手法を 統計的関係学習(statistical relational learning) といいます。マルチグラフにおける統計的関係学習では、異なる種類のエッジを持つノード間の関係性を定量的に評価し、その関係性に基づいて推論を行います。例えば、SNSでは、ユーザー情報(所在、行動等)、フォロー関係、メッセージ送信関係などをノードとする一つのマルチグラフで表現できます。これにより友人関係(AさんとBさんのノード間にはエッジがある)や興味関心(旅行に関するメッセージを頻繁に送信している場合、同じジャンルの新しいコンテンツ(旅行に関するブログ記事や旅行パッケージ)に興味を持つ可能性が高い)を予推論を行うことができます。

Appendix

 実装に挑戦したい!という場合は、PyTorch Geometricなどのライブラリを使うととても便利です。下記のようなシンプルなコードでGNNを実際に動かすことができます。

GNNに関して知ることができる素晴らしい書籍が存在します。(サポートも充実している)
下記コードについてはその書籍のサポートページから引用したものです。

書籍:

コード:

gnn.py
import numpy as np
import random
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.datasets import Planetoid
import networkx as nx
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE

torch.manual_seed(0)
np.random.seed(0)
random.seed(0)
dataset = Planetoid(root='/tmp/Cora', name='Cora')
n = dataset[0].num_nodes


# グラフ畳み込みネットワークの定義
class GCN(torch.nn.Module):
   def __init__(self, in_d, mid_d, out_d):
       super().__init__()
       self.conv1 = GCNConv(in_d, mid_d)
       self.conv2 = GCNConv(mid_d, out_d)

   def forward(self, data):
       x, edge_index = data.x, data.edge_index

       x = self.conv1(x, edge_index)
       emb = x.detach()
       x = F.relu(x)
       x = self.conv2(x, edge_index)

       return F.log_softmax(x, dim=1), emb

       
model = GCN(dataset.num_node_features, 16, dataset.num_classes)
data = dataset[0]
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, weight_decay=1e-4)

def train(epoch):
   model.train()
   for epoch in range(epoch):
       optimizer.zero_grad()
       out = model(data)[0]
       loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])
       loss.backward()
       optimizer.step()

train(500)

(GCN embedding 可視化)

ノード分類

image.png

details

例: Cora データセット
Coraデータセットは、学術論文の引用関係を表すグラフデータセットです。このデータセットは、機械学習コミュニティで広く使用されており、ノード分類タスクの標準ベンチマークとして利用されます。

頂点 (ノード): 論文(テキスト)
各頂点(ノード)は個々の論文を表します。これらの論文は、テキストデータとして表現され、特定の研究分野に属しています。

頂点特徴量: 論文要旨の bag of words
各ノードには、その論文の要旨を表す特徴ベクトルが付与されています。この特徴ベクトルは、bag of words モデルを用いて生成され、論文内で使用されている単語の出現頻度を表します。

辺 (エッジ): 論文 A から B への辺は引用を表す
エッジは、論文間の引用関係を表します。具体的には、論文 A が論文 B を引用している場合、ノード A とノード B の間にエッジが存在します。このエッジは、論文間の関係性を示します。

頂点ラベル: 論文のカテゴリ
各ノードには、論文の研究分野を示すクラスラベルが付与されています。これらのラベルは、ノード分類タスクのターゲットとなるものであり、例えば、論文が「機械学習」、「情報検索」、「データマイニング」などのカテゴリに属しているかを示します。

タスク: 論文のカテゴリを予測する(テキスト分類)
Coraデータセットにおける主なタスクは、各論文(ノード)がどの研究分野(カテゴリ)に属するかを予測することです。このタスクはテキスト分類問題に似ていますが、論文間の引用関係(グラフ構造)を活用することで、単なるテキスト分類以上の精度を達成することができます。

5
10
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
5
10

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?