Help us understand the problem. What is going on with this article?

最近話題のGraphConvolutionを作って溶解度予測をしてみた

More than 1 year has passed since last update.

最近みんなGraph Convolutionalをしている気がしたので、その風潮に乗って作ってみました。

Datasets

化学情報の分野でよく使われるデータセット。
solubility.train.sdf(1025分子)とsolubility.test.sdf(257分子)
data.png

今回予測するのはSOLの値です。

使用した特徴量

ノード特徴量として原子のone_hot_vector(11次元)、原子の重み、原子がNOかそれ以外かという13次元のベクトルを使いました。
隣接行列は分子の結合を元に作成しました。

結果

Screenshot from 2019-02-28 13-20-00.png

Rdkitの記述子を利用したlightgbmとほぼ同程度の精度が得られました。簡単なネットワークと少ない記述子を使うだけでこの精度が得られたのでGCNはやはり有望な手法だと思います。

コード

GraphConvolution.py
import torch
from torch.nn.parameter import Parameter
from torch.nn.modules.module import Module

class GraphConvolution(Module):
    def __init__(self,in_features,out_features):
        super(GraphConvolution,self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = Parameter(torch.FloatTensor(in_features,out_features))
        self.weight2 = Parameter(torch.FloatTensor(in_features,out_features))

    def forward(self,input,adj):
        out = []
        for i in range(len(input)):
            support = torch.mm(input[i].view(input.shape[1],-1),self.weight)
            output = torch.spmm(adj[i],support) + torch.mm(input[i].view(input.shape[1],-1),self.weight2)
            out.append(output)
        out = torch.stack(out,dim=0)
        return out

隣接行列の重みと自己ループの重みを使って畳み込んでます。

Readout.py
class Readout(Module):

    def __init__(self,in_features):
        super(Readout,self).__init__()
        self.in_features = in_features
        self.weight = Parameter(torch.FloatTensor(in_features,in_features))
        self.reset_parameters()

    def reset_parameters(self):
        stdv=1. / math.sqrt(self.weight.size(1))
        self.weight.data.uniform_(-stdv,stdv)

    def forward(self,input):
        input = torch.sum(input,1)
        return torch.mm(input,self.weight)

node数の影響をなくすためにnodeの次元で和を取ってます。

GCN.py
import torch.nn as nn
import torch.nn.functional as F
from layer import GraphConvolution,Readout

class GCN(nn.Module):
    def __init__(self,nfeat,nnode,nhid,dropout):
        super(GCN,self).__init__()

        self.gc1 = GraphConvolution(nfeat,nhid)
        self.gc2 = GraphConvolution(nhid,nhid)
        self.dropout = dropout
        self.rd = Readout(nhid)
        self.fc1 = nn.Linear(nhid,1)
        nn.init.kaiming_normal_(self.fc1.weight)

    def forward(self,x,adj):
        x = F.relu(self.gc1(x,adj))
        x = F.dropout(x,self.dropout,training=self.training)
        x = F.relu(self.gc2(x,adj))
        x = F.dropout(x,self.dropout,training=self.training)
        x = F.relu(self.rd(x))
        x = self.fc1(x)
        return x

今回使用したモデルです。

まとめ

Graph Convolutionは今ホットな分野らしいので触ってみることをおすすめします。
今回使用したコードはこれです。

参考サイト

香川大学農学部 ケミカルバイオロジー研究室

rookzeno
Why not register and get more from Qiita?
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away