0
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

Transformerでグラフデータを扱ってみる

Last updated at Posted at 2022-02-02

#概要
 Transformerエンコーダを使ってグラフデータを扱います.非自己回帰型のエンコーダでグラフのエッジに合わせた特殊なmaskを行います.

ちなみにこちらで解説されている論文ではもっと複雑に扱っており,SelfAttentionのQKVにエッジ情報を加えたQKVEという入力に拡張しています.

今回考えてみた手法は単純である反面,エッジの特徴を埋め込めない欠点があります.できても単方向or両方向くらいです.

#準備
 今回はGoogle Colabで実行しました,ランタイムはGPUに設定しておきましょう.こちらを参考にtorch-geometricをインストールします.

#コード

  • Transformerエンコーダレイヤーを用意.
  • これは前回の記事からのコピーです.
import copy
import math
import numpy as np

import torch
torch.manual_seed(41)
import torch.optim as optim
import torch.nn as nn
from torch.nn import functional as F

class SoftmaxAttention(nn.Module):
    def __init__(self, head_dim):
        super().__init__()
        self.head_dim = head_dim

    def forward(self, Q, K, V, mask=None):
        logit = torch.einsum("bhld,bhmd->bhlm",Q,K)/math.sqrt(self.head_dim)

        if mask!=None:
            logit = logit + mask[:,None,:,:]

        attention_weight = F.softmax(logit, dim=-1)
        X = torch.einsum("bhlm,bhmd->bhld",attention_weight,V)

        return X

class SelfAttention(nn.Module):
    def __init__(self, dim, head_dim, num_head):
        super().__init__()

        self.dim = dim
        self.head_dim = head_dim
        self.num_head = num_head

        assert dim == head_dim*num_head,print("ASSERT #dim = head_dim * num_head")

        self.W_q = nn.Linear(self.dim, self.dim)
        self.W_k = nn.Linear(self.dim, self.dim)
        self.W_v = nn.Linear(self.dim, self.dim)

        self.attn = SoftmaxAttention(head_dim)

    def forward(self, X, mask):
        Q = self.split_heads(self.W_q(X))
        K = self.split_heads(self.W_k(X))
        V = self.split_heads(self.W_v(X))

        attn_out = self.attn(Q.float(), K.float(), V.float(), mask.float())
        attn_out = self.combine_heads(attn_out)
        return attn_out

    #[batch,head,len,head_dim]->[batch,len,dim]
    def combine_heads(self, X):
        X = X.transpose(1, 2)
        X = X.reshape(X.size(0), X.size(1), self.num_head * self.head_dim)
        return X

    #[batch,len,dim]->[batch,head,len,head_dim]
    def split_heads(self, X):
        X = X.reshape(X.size(0), X.size(1), self.num_head, self.head_dim)
        X = X.transpose(1, 2)
        return X

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        pe[:, 0, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe.transpose(0,1))

    #[batch,len,dim]->[batch,len,dim]
    def forward(self, x):
        x = x + self.pe[:,:x.size(1),:]
        return self.dropout(x)


class TransformerEncoderLayer(nn.Module):
    def __init__(self, attn_layer, ff_layer, norm_layer, drop_layer):
        super().__init__()
        self.attn_layer = attn_layer
        self.ff_layer = ff_layer
        self.norm_layer = norm_layer
        self.drop_layer = drop_layer

    def forward(self, x, mask):
        x = self.drop_layer(self.attn_layer(x,mask)) + x
        x = self.norm_layer(x)

        x = self.drop_layer(self.ff_layer(x)) + x
        x = self.norm_layer(x)

        return x
  • グラフ用のTransformerです.

・xにノードの特徴量[nodes,dim]を入れてバイアス無しの線形変換で埋め込んでいます.
・maskを使ってエッジ情報を入れます.node_iとnode_jにエッジが無い場合はmask[i,j]=-infに設定します.他は0です.

Input: x[batch,nodes,dim] adj[batch,nodes,nodes]

from torch.nn.modules.container import ModuleList
from torch.nn.modules.normalization import LayerNorm

class TransformerGraph(nn.Module):
    def __init__(self, num_layer, dim, ff_dim, head_dim, num_head, feature_size=None, num_classes=None,drop_p=0.1):
        super().__init__()

        self.node_emb = nn.Linear(feature_size,dim,bias=False)
        self.output = nn.Linear(dim,num_classes)
        
        self.encoders = ModuleList([TransformerEncoderLayer(SelfAttention(dim, head_dim, num_head),\
                            nn.Sequential(nn.Linear(dim,ff_dim),nn.ReLU(),nn.Linear(ff_dim,dim)),\
                                LayerNorm(dim, eps=1e-5),\
                                    nn.Dropout(drop_p)) for i in range(num_layer)])
    
    def forward(self, x, adj):
        mask = torch.zeros(adj.shape).to(x.device).masked_fill_(adj==0,float("-inf"))
        x = self.node_emb(x.float())
        for layer in self.encoders:
            x = layer(x,mask)
        return self.output(x)
  • データセットを用意
#edge情報から隣接行列を生成
def edge_adjacency_matrix(edge,node_num):
    print(node_num)
    matrix = torch.zeros(node_num,node_num)
    for i,j in zip(edge[0,:],edge[1,:]):
        matrix[i,j]=1
    return matrix

#不必要なノードのラベルを-1にする
def mask_label(label,mask):
    label = label.masked_fill(mask,-1)
    return label

from torch_geometric.datasets import Planetoid
dataset = Planetoid("./", "Cora", split="public")
data = dataset[0]

nodes_feature = data['x'] #[2708, 1433]
nodes_label = data['y'] #[2708]
adjacency_matrix = edge_adjacency_matrix(data['edge_index'],data.num_nodes) #[2, 10556]

train_mask = data["train_mask"] #[2708]
val_mask = data["val_mask"] #[2708]
test_mask = data["test_mask"] #[2708]

train_label = mask_label(nodes_label,train_mask)
val_label = mask_label(nodes_label,val_mask)
test_label = mask_label(nodes_label,test_mask)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = TransformerGraph(num_layer=3, dim=128, ff_dim=256, head_dim=32, num_head=4, feature_size=1433, num_classes=7).to(device)

import torch.optim as optim
optimizer = optim.Adam(model.parameters(),lr=1e-4)
criterion = nn.CrossEntropyLoss(ignore_index=-1) #train~test_maskで-1にしたlabelを無視
  • 訓練ループを書きます.
min_loss = float("inf")

nodes_feature = nodes_feature.unsqueeze(0).to(device)
adjacency_matrix = adjacency_matrix.unsqueeze(0).to(device)

for epoch in range(1000):
    model.train()
    optimizer.zero_grad()
    logits = model(nodes_feature,adjacency_matrix)
    loss = criterion(logits.reshape(-1, logits.size(-1)), train_label.reshape(-1).to(device))
    loss.backward()
    optimizer.step()
    print("train_loss:",loss.item())

    model.eval()
    with torch.no_grad():
        logits = model(nodes_feature,adjacency_matrix)
    loss = criterion(logits.reshape(-1, logits.size(-1)), val_label.reshape(-1).to(device))
    print("val_loss:",loss.item())
    if min_loss > loss.item():
        torch.save(model.state_dict(),"./model.bin")
        print("save!")
  • テストデータで評価
model.load_state_dict(torch.load("./model.bin"))
model.eval()

with torch.no_grad():
    logits = model(nodes_feature,adjacency_matrix)
preds = logits.argmax(-1)

accuracy = []
for pred,label,mask in zip(preds.squeeze(),nodes_label,test_mask):
    if mask.item():  #test_maskがFalseのラベルはスキップ
        if pred.item()==label.item():
            accuracy.append(1)
        else:
            accuracy.append(0)
print(sum(accuracy)/len(accuracy))

 90%以上出てます,コードが間違ってなければよいのですが….

#備考
 Transformerはattention-weight [batch,length,length]の不要な部分が0になるようにmaskしますが,本モデルはattention-weight [batch,nodes,nodes]のエッジの無い組にmaskしています.

しかし,この構造ではエッジの種類(単・共有結合など)に関して埋め込むことが出来ないので注意してください.

エッジのある部分に1次元Embeddingか何かで埋め込むのも一手かもしれません.

#まとめ
 Transformerエンコーダ(非自己回帰==No Causal)でグラフデータを扱いました.

#最後に
 誤っている部分等ございましたら,コメント等で優しく指摘して頂けると嬉しいです.(気付かなければ申し訳ありません)

0
2
1

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?