LoginSignup
0
3

はじめに

GCN,ちゃんと理解していますか?

グラフ深層学習屋さんなら誰でも知っている叩き台ことGCNですが,有名すぎてネットのそこら中に実装が転がっていて,コピペすれば動いてしまいます.

なんとなく何をしているかは知っていても,具体的な実装を見たことはない……そんな状況を解消するため,GCNの実装を細かいところまで見ます.

GCNの論文
Semi-Supervised Classification with Graph Convolutional Networks

基本

GCNは隣接ノード間で属性情報を伝播し,畳み込んでノードの埋め込みとする.

何度か繰り返すと,2-hop, 3-hop先の情報を取れる.

image.png

元コード

pygcn のリポジトリでGCNが実装されているので,これを見ていきます.

import torch.nn as nn
import torch.nn.functional as F
from pygcn.layers import GraphConvolution

class GCN(nn.Module):
    def __init__(self, nfeat, nhid, nclass, dropout):
        super(GCN, self).__init__()

        self.gc1 = GraphConvolution(nfeat, nhid)
        self.gc2 = GraphConvolution(nhid, nclass)
        self.dropout = dropout

    def forward(self, x, adj):
        x = F.relu(self.gc1(x, adj))
        x = F.dropout(x, self.dropout, training=self.training)
        x = self.gc2(x, adj)
        return F.log_softmax(x, dim=1)

GraphConvolutionとは?

class GraphConvolution(Module):
    """
    Simple GCN layer, similar to https://arxiv.org/abs/1609.02907
    """

    def __init__(self, in_features, out_features, bias=True):
        super(GraphConvolution, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = Parameter(torch.FloatTensor(in_features, out_features))
        if bias:
            self.bias = Parameter(torch.FloatTensor(out_features))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        # 中略

    def forward(self, input, adj):
        support = torch.mm(input, self.weight)
        output = torch.spmm(adj, support)
        if self.bias is not None:
            return output + self.bias
        else:
            return output

解説

GraphConvolutionは何をしている?

学習可能パラメータはウェイトとバイアスです.

input, adjが与えられていますが,これは特徴量と隣接行列です.

support = torch.mm(input, self.weight)

特徴量に重み行列をかけています.
重み行列はin_features × out_featuresの行列なので,各ノードの特徴量はout_features次元に変換されます.

よって,畳み込み前と畳み込み後は特徴量の各次元が表す概念が違います.なんと.

output = torch.spmm(adj, support)

隣接行列に先ほどの中間表現をかけています.

あるノードについて,隣接行列の対応する値がnon-zeroのもの,すなわち隣接ノード全ての中間表現の合計が格納されます.

行列演算の例.ノード1について,隣接ノード0, 2の特徴量が伝播され,その和がノード1の新たな埋め込みとして格納されている.
image.png

ちなみに,GCNに使用するグラフは自己ループが追加されていた気がします.

最後に,バイアスを追加する部分.バイアスは学習可能パラメータです.

if self.bias is not None:
    return output + self.bias
else:
    return output

つまりGraphConvolutionとは,特徴量にウェイトをかけて扱いやすい大きさの空間に埋め込んだ後,隣接ノードの中間表現の和を取る操作ということができる.

GCNは何をしている?

GraphConvolutionを一回かけ,reluで活性化している.
その後dropoutをはさみ,2回目のGraphConvolutionをかけた後,log softmaxで活性化している.

x = F.relu(self.gc1(x, adj))
x = F.dropout(x, self.dropout, training=self.training)
x = self.gc2(x, adj)
return F.log_softmax(x, dim=1)

畳み込みと活性化関数を交互にかけるという基本が実装されています.

さいごに

グラフ畳み込みのコードを読みました.

複雑なコードかと思っていましたが,隣接行列と特徴量の乗算で畳み込みが実装できるんですね……
言われてみれば当然ですが,これに最初に気づいた人はすごい.

0
3
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
3