最近みんなGraph Convolutionalをしている気がしたので、その風潮に乗って作ってみました。
#Datasets
化学情報の分野でよく使われるデータセット。
solubility.train.sdf(1025分子)とsolubility.test.sdf(257分子)
今回予測するのはSOLの値です。
#使用した特徴量
ノード特徴量として原子のone_hot_vector(11次元)、原子の重み、原子がNOかそれ以外かという13次元のベクトルを使いました。
隣接行列は分子の結合を元に作成しました。
Rdkitの記述子を利用したlightgbmとほぼ同程度の精度が得られました。簡単なネットワークと少ない記述子を使うだけでこの精度が得られたのでGCNはやはり有望な手法だと思います。
#コード
GraphConvolution.py
import torch
from torch.nn.parameter import Parameter
from torch.nn.modules.module import Module
class GraphConvolution(Module):
def __init__(self,in_features,out_features):
super(GraphConvolution,self).__init__()
self.in_features = in_features
self.out_features = out_features
self.weight = Parameter(torch.FloatTensor(in_features,out_features))
self.weight2 = Parameter(torch.FloatTensor(in_features,out_features))
def forward(self,input,adj):
out = []
for i in range(len(input)):
support = torch.mm(input[i].view(input.shape[1],-1),self.weight)
output = torch.spmm(adj[i],support) + torch.mm(input[i].view(input.shape[1],-1),self.weight2)
out.append(output)
out = torch.stack(out,dim=0)
return out
隣接行列の重みと自己ループの重みを使って畳み込んでます。
Readout.py
class Readout(Module):
def __init__(self,in_features):
super(Readout,self).__init__()
self.in_features = in_features
self.weight = Parameter(torch.FloatTensor(in_features,in_features))
self.reset_parameters()
def reset_parameters(self):
stdv=1. / math.sqrt(self.weight.size(1))
self.weight.data.uniform_(-stdv,stdv)
def forward(self,input):
input = torch.sum(input,1)
return torch.mm(input,self.weight)
node数の影響をなくすためにnodeの次元で和を取ってます。
GCN.py
import torch.nn as nn
import torch.nn.functional as F
from layer import GraphConvolution,Readout
class GCN(nn.Module):
def __init__(self,nfeat,nnode,nhid,dropout):
super(GCN,self).__init__()
self.gc1 = GraphConvolution(nfeat,nhid)
self.gc2 = GraphConvolution(nhid,nhid)
self.dropout = dropout
self.rd = Readout(nhid)
self.fc1 = nn.Linear(nhid,1)
nn.init.kaiming_normal_(self.fc1.weight)
def forward(self,x,adj):
x = F.relu(self.gc1(x,adj))
x = F.dropout(x,self.dropout,training=self.training)
x = F.relu(self.gc2(x,adj))
x = F.dropout(x,self.dropout,training=self.training)
x = F.relu(self.rd(x))
x = self.fc1(x)
return x
今回使用したモデルです。
#まとめ
Graph Convolutionは今ホットな分野らしいので触ってみることをおすすめします。
今回使用したコードはこれです。
#参考サイト
香川大学農学部 ケミカルバイオロジー研究室