#はじめに
DeepChem の GraphConvLayer を Pytorch のカスタムレイヤーで実装してみた。
#環境
- DeepChem 2.3
- PyTorch 1.7.0
#ソース
前回作成したDataSet, DataLorderセットを使ってミニバッチを取り出し、GraphConvに食わせて出力してみた。
import torch
from torch.utils import data
from deepchem.feat.graph_features import ConvMolFeaturizer
from deepchem.feat.mol_graphs import ConvMol
import torch.nn as nn
import numpy as np
class GraphConv(nn.Module):
def __init__(self,
in_channel,
out_channel,
min_deg=0,
max_deg=10,
activation=lambda x: x
):
super().__init__()
self.in_channel = in_channel
self.out_channel = out_channel
self.min_degree = min_deg
self.max_degree = max_deg
num_deg = 2 * self.max_degree + (1 - self.min_degree)
self.W_list = [
nn.Parameter(torch.Tensor(
np.random.normal(size=(in_channel, out_channel))).double())
for k in range(num_deg)]
self.b_list = [
nn.Parameter(torch.Tensor(np.zeros(out_channel)).double()) for k in range(num_deg)]
def forward(self, atom_features, deg_slice, deg_adj_lists):
#print("deg_adj_list")
print(deg_adj_lists)
W = iter(self.W_list)
b = iter(self.b_list)
# Sum all neighbors using adjacency matrix
deg_summed = self.sum_neigh(atom_features, deg_adj_lists)
# Get collection of modified atom features
new_rel_atoms_collection = (self.max_degree + 1 - self.min_degree) * [None]
for deg in range(1, self.max_degree + 1):
# Obtain relevant atoms for this degree
rel_atoms = deg_summed[deg - 1]
# Get self atoms
begin = deg_slice[deg - self.min_degree, 0]
size = deg_slice[deg - self.min_degree, 1]
self_atoms = torch.narrow(atom_features, 0, int(begin), int(size))
# Apply hidden affine to relevant atoms and append
rel_out = torch.matmul(rel_atoms, next(W)) + next(b)
self_out = torch.matmul(self_atoms, next(W)) + next(b)
out = rel_out + self_out
new_rel_atoms_collection[deg - self.min_degree] = out
# Determine the min_deg=0 case
if self.min_degree == 0:
deg = 0
begin = deg_slice[deg - self.min_degree, 0]
size = deg_slice[deg - self.min_degree, 1]
self_atoms = torch.narrow(atom_features, 0, int(begin), int(size))
# Only use the self layer
out = torch.matmul(self_atoms, next(W)) + next(b)
new_rel_atoms_collection[deg - self.min_degree] = out
# Combine all atoms back into the list
#print(new_rel_atoms_collection)
atom_features = torch.cat(new_rel_atoms_collection, 0)
return atom_features
def sum_neigh(self, atoms, deg_adj_lists):
"""Store the summed atoms by degree"""
deg_summed = self.max_degree * [None]
for deg in range(1, self.max_degree + 1):
index = torch.tensor(deg_adj_lists[deg - 1], dtype=torch.int64)
gathered_atoms = atoms[index]
# Sum along neighbors as well as self, and store
summed_atoms = torch.sum(gathered_atoms, 1)
deg_summed[deg - 1] = summed_atoms
return deg_summed
class GCNDataset(data.Dataset):
def __init__(self, smiles_list, label_list):
self.smiles_list = smiles_list
self.label_list = label_list
def __len__(self):
return len(self.smiles_list)
def __getitem__(self, index):
return self.smiles_list[index], self.label_list[index]
def gcn_collate_fn(batch):
from rdkit import Chem
cmf = ConvMolFeaturizer()
mols = []
labels = []
for sample, label in batch:
mols.append(Chem.MolFromSmiles(sample))
labels.append(torch.tensor(label))
conv_mols = cmf.featurize(mols)
multiConvMol = ConvMol.agglomerate_mols(conv_mols)
atom_feature = torch.tensor(multiConvMol.get_atom_features(), dtype=torch.float64)
deg_slice = torch.tensor(multiConvMol.deg_slice, dtype=torch.float64)
membership = torch.tensor(multiConvMol.membership, dtype=torch.float64)
deg_adj_lists = []
for i in range(1, len(multiConvMol.get_deg_adjacency_lists())):
deg_adj_lists.append(multiConvMol.get_deg_adjacency_lists()[i])
return atom_feature, deg_slice, membership, deg_adj_lists, labels
def main():
dataset = GCNDataset(["CCC", "CCCC", "CCCCC"], [1, 0, 1])
dataloader = data.DataLoader(dataset, batch_size=3, shuffle=False, collate_fn =gcn_collate_fn)
model = GraphConv(75, 20)
for atom_feature, deg_slice, membership, deg_adj_lists, labels in dataloader:
print("atom_feature")
print(atom_feature)
print("deg_slice")
print(deg_slice)
print("membership")
print(membership)
print("result")
print(model(atom_feature, deg_slice, deg_adj_lists))
if __name__ == "__main__":
main()
###結果
はい、どん。
とりあえず、結果の形状は、原子数 x 20次元(75次元から畳み込みで圧縮した次元)であってるようだ。
atom_feature
tensor([[1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.,
0., 1., 0.],
[1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.,
0., 1., 0.],
[1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.,
0., 1., 0.],
[1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.,
0., 1., 0.],
[1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.,
0., 1., 0.],
[1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.,
0., 1., 0.],
[1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.,
1., 0., 0.],
[1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.,
1., 0., 0.],
[1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.,
1., 0., 0.],
[1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.,
1., 0., 0.],
[1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.,
1., 0., 0.],
[1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.,
1., 0., 0.]], dtype=torch.float64)
deg_slice
tensor([[ 0., 0.],
[ 0., 6.],
[ 6., 6.],
[12., 0.],
[12., 0.],
[12., 0.],
[12., 0.],
[12., 0.],
[12., 0.],
[12., 0.],
[12., 0.]], dtype=torch.float64)
membership
tensor([0., 0., 1., 1., 2., 2., 0., 1., 1., 2., 2., 2.], dtype=torch.float64)
result
tensor([[-0.2910, 2.2571, 1.6459, -4.0687, -3.3893, 4.3271, 1.5363, 1.2956,
-1.1717, 0.8923, -0.9046, -3.9463, 4.2884, -3.5612, -9.7249, 1.9113,
1.7882, 1.6279, -3.7770, -6.3691],
[-0.2910, 2.2571, 1.6459, -4.0687, -3.3893, 4.3271, 1.5363, 1.2956,
-1.1717, 0.8923, -0.9046, -3.9463, 4.2884, -3.5612, -9.7249, 1.9113,
1.7882, 1.6279, -3.7770, -6.3691],
[-0.2910, 2.2571, 1.6459, -4.0687, -3.3893, 4.3271, 1.5363, 1.2956,
-1.1717, 0.8923, -0.9046, -3.9463, 4.2884, -3.5612, -9.7249, 1.9113,
1.7882, 1.6279, -3.7770, -6.3691],
[-0.2910, 2.2571, 1.6459, -4.0687, -3.3893, 4.3271, 1.5363, 1.2956,
-1.1717, 0.8923, -0.9046, -3.9463, 4.2884, -3.5612, -9.7249, 1.9113,
1.7882, 1.6279, -3.7770, -6.3691],
[-0.2910, 2.2571, 1.6459, -4.0687, -3.3893, 4.3271, 1.5363, 1.2956,
-1.1717, 0.8923, -0.9046, -3.9463, 4.2884, -3.5612, -9.7249, 1.9113,
1.7882, 1.6279, -3.7770, -6.3691],
[-0.2910, 2.2571, 1.6459, -4.0687, -3.3893, 4.3271, 1.5363, 1.2956,
-1.1717, 0.8923, -0.9046, -3.9463, 4.2884, -3.5612, -9.7249, 1.9113,
1.7882, 1.6279, -3.7770, -6.3691],
[-1.6645, 6.3024, 0.6540, -0.7638, 5.3761, -6.3710, -0.3202, 1.3862,
6.6121, -0.5707, -8.2441, -5.8404, 4.4354, 0.8659, -2.3474, -4.8642,
8.3175, 0.1378, -4.6038, -3.9733],
[-0.3320, 1.6265, -0.2117, -0.5792, 5.7710, 0.5828, -0.7252, 3.6408,
7.6525, -0.3339, -6.1131, -2.3356, 3.6018, 1.5834, -2.7556, -4.1401,
1.4335, -0.4723, -1.7117, -3.6721],
[-0.3320, 1.6265, -0.2117, -0.5792, 5.7710, 0.5828, -0.7252, 3.6408,
7.6525, -0.3339, -6.1131, -2.3356, 3.6018, 1.5834, -2.7556, -4.1401,
1.4335, -0.4723, -1.7117, -3.6721],
[-0.3320, 1.6265, -0.2117, -0.5792, 5.7710, 0.5828, -0.7252, 3.6408,
7.6525, -0.3339, -6.1131, -2.3356, 3.6018, 1.5834, -2.7556, -4.1401,
1.4335, -0.4723, -1.7117, -3.6721],
[ 1.0006, -3.0494, -1.0774, -0.3946, 6.1658, 7.5366, -1.1302, 5.8955,
8.6929, -0.0971, -3.9820, 1.1691, 2.7682, 2.3009, -3.1638, -3.4160,
-5.4505, -1.0824, 1.1805, -3.3708],
[-0.3320, 1.6265, -0.2117, -0.5792, 5.7710, 0.5828, -0.7252, 3.6408,
7.6525, -0.3339, -6.1131, -2.3356, 3.6018, 1.5834, -2.7556, -4.1401,
1.4335, -0.4723, -1.7117, -3.6721]], dtype=torch.float64,
grad_fn=<CatBackward>)