#概要
Transformerエンコーダを使ってグラフデータを扱います.非自己回帰型のエンコーダでグラフのエッジに合わせた特殊なmaskを行います.
ちなみにこちらで解説されている論文ではもっと複雑に扱っており,SelfAttentionのQKVにエッジ情報を加えたQKVEという入力に拡張しています.
今回考えてみた手法は単純である反面,エッジの特徴を埋め込めない欠点があります.できても単方向or両方向くらいです.
#準備
今回はGoogle Colabで実行しました,ランタイムはGPUに設定しておきましょう.こちらを参考にtorch-geometricをインストールします.
#コード
- Transformerエンコーダレイヤーを用意.
- これは前回の記事からのコピーです.
import copy
import math
import numpy as np
import torch
torch.manual_seed(41)
import torch.optim as optim
import torch.nn as nn
from torch.nn import functional as F
class SoftmaxAttention(nn.Module):
def __init__(self, head_dim):
super().__init__()
self.head_dim = head_dim
def forward(self, Q, K, V, mask=None):
logit = torch.einsum("bhld,bhmd->bhlm",Q,K)/math.sqrt(self.head_dim)
if mask!=None:
logit = logit + mask[:,None,:,:]
attention_weight = F.softmax(logit, dim=-1)
X = torch.einsum("bhlm,bhmd->bhld",attention_weight,V)
return X
class SelfAttention(nn.Module):
def __init__(self, dim, head_dim, num_head):
super().__init__()
self.dim = dim
self.head_dim = head_dim
self.num_head = num_head
assert dim == head_dim*num_head,print("ASSERT #dim = head_dim * num_head")
self.W_q = nn.Linear(self.dim, self.dim)
self.W_k = nn.Linear(self.dim, self.dim)
self.W_v = nn.Linear(self.dim, self.dim)
self.attn = SoftmaxAttention(head_dim)
def forward(self, X, mask):
Q = self.split_heads(self.W_q(X))
K = self.split_heads(self.W_k(X))
V = self.split_heads(self.W_v(X))
attn_out = self.attn(Q.float(), K.float(), V.float(), mask.float())
attn_out = self.combine_heads(attn_out)
return attn_out
#[batch,head,len,head_dim]->[batch,len,dim]
def combine_heads(self, X):
X = X.transpose(1, 2)
X = X.reshape(X.size(0), X.size(1), self.num_head * self.head_dim)
return X
#[batch,len,dim]->[batch,head,len,head_dim]
def split_heads(self, X):
X = X.reshape(X.size(0), X.size(1), self.num_head, self.head_dim)
X = X.transpose(1, 2)
return X
class PositionalEncoding(nn.Module):
def __init__(self, d_model, dropout=0.1, max_len=5000):
super().__init__()
self.dropout = nn.Dropout(p=dropout)
position = torch.arange(max_len).unsqueeze(1)
div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
pe = torch.zeros(max_len, 1, d_model)
pe[:, 0, 0::2] = torch.sin(position * div_term)
pe[:, 0, 1::2] = torch.cos(position * div_term)
self.register_buffer('pe', pe.transpose(0,1))
#[batch,len,dim]->[batch,len,dim]
def forward(self, x):
x = x + self.pe[:,:x.size(1),:]
return self.dropout(x)
class TransformerEncoderLayer(nn.Module):
def __init__(self, attn_layer, ff_layer, norm_layer, drop_layer):
super().__init__()
self.attn_layer = attn_layer
self.ff_layer = ff_layer
self.norm_layer = norm_layer
self.drop_layer = drop_layer
def forward(self, x, mask):
x = self.drop_layer(self.attn_layer(x,mask)) + x
x = self.norm_layer(x)
x = self.drop_layer(self.ff_layer(x)) + x
x = self.norm_layer(x)
return x
- グラフ用のTransformerです.
・xにノードの特徴量[nodes,dim]を入れてバイアス無しの線形変換で埋め込んでいます.
・maskを使ってエッジ情報を入れます.node_iとnode_jにエッジが無い場合はmask[i,j]=-infに設定します.他は0です.
Input: x[batch,nodes,dim] adj[batch,nodes,nodes]
from torch.nn.modules.container import ModuleList
from torch.nn.modules.normalization import LayerNorm
class TransformerGraph(nn.Module):
def __init__(self, num_layer, dim, ff_dim, head_dim, num_head, feature_size=None, num_classes=None,drop_p=0.1):
super().__init__()
self.node_emb = nn.Linear(feature_size,dim,bias=False)
self.output = nn.Linear(dim,num_classes)
self.encoders = ModuleList([TransformerEncoderLayer(SelfAttention(dim, head_dim, num_head),\
nn.Sequential(nn.Linear(dim,ff_dim),nn.ReLU(),nn.Linear(ff_dim,dim)),\
LayerNorm(dim, eps=1e-5),\
nn.Dropout(drop_p)) for i in range(num_layer)])
def forward(self, x, adj):
mask = torch.zeros(adj.shape).to(x.device).masked_fill_(adj==0,float("-inf"))
x = self.node_emb(x.float())
for layer in self.encoders:
x = layer(x,mask)
return self.output(x)
- データセットを用意
#edge情報から隣接行列を生成
def edge_adjacency_matrix(edge,node_num):
print(node_num)
matrix = torch.zeros(node_num,node_num)
for i,j in zip(edge[0,:],edge[1,:]):
matrix[i,j]=1
return matrix
#不必要なノードのラベルを-1にする
def mask_label(label,mask):
label = label.masked_fill(mask,-1)
return label
from torch_geometric.datasets import Planetoid
dataset = Planetoid("./", "Cora", split="public")
data = dataset[0]
nodes_feature = data['x'] #[2708, 1433]
nodes_label = data['y'] #[2708]
adjacency_matrix = edge_adjacency_matrix(data['edge_index'],data.num_nodes) #[2, 10556]
train_mask = data["train_mask"] #[2708]
val_mask = data["val_mask"] #[2708]
test_mask = data["test_mask"] #[2708]
train_label = mask_label(nodes_label,train_mask)
val_label = mask_label(nodes_label,val_mask)
test_label = mask_label(nodes_label,test_mask)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = TransformerGraph(num_layer=3, dim=128, ff_dim=256, head_dim=32, num_head=4, feature_size=1433, num_classes=7).to(device)
import torch.optim as optim
optimizer = optim.Adam(model.parameters(),lr=1e-4)
criterion = nn.CrossEntropyLoss(ignore_index=-1) #train~test_maskで-1にしたlabelを無視
- 訓練ループを書きます.
min_loss = float("inf")
nodes_feature = nodes_feature.unsqueeze(0).to(device)
adjacency_matrix = adjacency_matrix.unsqueeze(0).to(device)
for epoch in range(1000):
model.train()
optimizer.zero_grad()
logits = model(nodes_feature,adjacency_matrix)
loss = criterion(logits.reshape(-1, logits.size(-1)), train_label.reshape(-1).to(device))
loss.backward()
optimizer.step()
print("train_loss:",loss.item())
model.eval()
with torch.no_grad():
logits = model(nodes_feature,adjacency_matrix)
loss = criterion(logits.reshape(-1, logits.size(-1)), val_label.reshape(-1).to(device))
print("val_loss:",loss.item())
if min_loss > loss.item():
torch.save(model.state_dict(),"./model.bin")
print("save!")
- テストデータで評価
model.load_state_dict(torch.load("./model.bin"))
model.eval()
with torch.no_grad():
logits = model(nodes_feature,adjacency_matrix)
preds = logits.argmax(-1)
accuracy = []
for pred,label,mask in zip(preds.squeeze(),nodes_label,test_mask):
if mask.item(): #test_maskがFalseのラベルはスキップ
if pred.item()==label.item():
accuracy.append(1)
else:
accuracy.append(0)
print(sum(accuracy)/len(accuracy))
90%以上出てます,コードが間違ってなければよいのですが….
#備考
Transformerはattention-weight [batch,length,length]
の不要な部分が0になるようにmaskしますが,本モデルはattention-weight [batch,nodes,nodes]
のエッジの無い組にmaskしています.
しかし,この構造ではエッジの種類(単・共有結合など)に関して埋め込むことが出来ないので注意してください.
エッジのある部分に1次元Embeddingか何かで埋め込むのも一手かもしれません.
#まとめ
Transformerエンコーダ(非自己回帰==No Causal)でグラフデータを扱いました.
#最後に
誤っている部分等ございましたら,コメント等で優しく指摘して頂けると嬉しいです.(気付かなければ申し訳ありません)