17
17

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

最近話題のGraphConvolutionを作って溶解度予測をしてみた

Last updated at Posted at 2019-02-24

最近みんなGraph Convolutionalをしている気がしたので、その風潮に乗って作ってみました。

#Datasets
化学情報の分野でよく使われるデータセット。
solubility.train.sdf(1025分子)とsolubility.test.sdf(257分子)
data.png

今回予測するのはSOLの値です。

#使用した特徴量
ノード特徴量として原子のone_hot_vector(11次元)、原子の重み、原子がNOかそれ以外かという13次元のベクトルを使いました。
隣接行列は分子の結合を元に作成しました。

#結果
Screenshot from 2019-02-28 13-20-00.png

Rdkitの記述子を利用したlightgbmとほぼ同程度の精度が得られました。簡単なネットワークと少ない記述子を使うだけでこの精度が得られたのでGCNはやはり有望な手法だと思います。

#コード

GraphConvolution.py
import torch
from torch.nn.parameter import Parameter
from torch.nn.modules.module import Module

class GraphConvolution(Module):
    def __init__(self,in_features,out_features):
        super(GraphConvolution,self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = Parameter(torch.FloatTensor(in_features,out_features))
        self.weight2 = Parameter(torch.FloatTensor(in_features,out_features))

    def forward(self,input,adj):
        out = []
        for i in range(len(input)):
            support = torch.mm(input[i].view(input.shape[1],-1),self.weight)
            output = torch.spmm(adj[i],support) + torch.mm(input[i].view(input.shape[1],-1),self.weight2)
            out.append(output)
        out = torch.stack(out,dim=0)
        return out

隣接行列の重みと自己ループの重みを使って畳み込んでます。

Readout.py
class Readout(Module):

    def __init__(self,in_features):
        super(Readout,self).__init__()
        self.in_features = in_features
        self.weight = Parameter(torch.FloatTensor(in_features,in_features))
        self.reset_parameters()

    def reset_parameters(self):
        stdv=1. / math.sqrt(self.weight.size(1))
        self.weight.data.uniform_(-stdv,stdv)
    
    def forward(self,input):
        input = torch.sum(input,1)
        return torch.mm(input,self.weight)

node数の影響をなくすためにnodeの次元で和を取ってます。

GCN.py
import torch.nn as nn
import torch.nn.functional as F
from layer import GraphConvolution,Readout

class GCN(nn.Module):
    def __init__(self,nfeat,nnode,nhid,dropout):
        super(GCN,self).__init__()

        self.gc1 = GraphConvolution(nfeat,nhid)
        self.gc2 = GraphConvolution(nhid,nhid)
        self.dropout = dropout
        self.rd = Readout(nhid)
        self.fc1 = nn.Linear(nhid,1)
        nn.init.kaiming_normal_(self.fc1.weight)
    
    def forward(self,x,adj):
        x = F.relu(self.gc1(x,adj))
        x = F.dropout(x,self.dropout,training=self.training)
        x = F.relu(self.gc2(x,adj))
        x = F.dropout(x,self.dropout,training=self.training)
        x = F.relu(self.rd(x))
        x = self.fc1(x)
        return x

今回使用したモデルです。

#まとめ
Graph Convolutionは今ホットな分野らしいので触ってみることをおすすめします。
今回使用したコードはこれです。

#参考サイト
香川大学農学部 ケミカルバイオロジー研究室

17
17
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
17
17

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?