#Introduction
Following up on yesterday's GraphConvLayer, I implemented DeepChem's GraphPoolLayer as a custom PyTorch layer.
#Environment
- DeepChem 2.3
- PyTorch 1.7.0
#Source
I ported DeepChem's GraphPoolLayer to PyTorch and fed the output of the previous article's GraphConvLayer into the newly created GraphPoolLayer.
```python
import torch
from torch.utils import data
from deepchem.feat.graph_features import ConvMolFeaturizer
from deepchem.feat.mol_graphs import ConvMol
import torch.nn as nn
import numpy as np


class GraphConv(nn.Module):

    def __init__(self,
                 in_channel,
                 out_channel,
                 min_deg=0,
                 max_deg=10,
                 activation=lambda x: x
                 ):
        super().__init__()
        self.in_channel = in_channel
        self.out_channel = out_channel
        self.min_degree = min_deg
        self.max_degree = max_deg

        num_deg = 2 * self.max_degree + (1 - self.min_degree)
        # Register the per-degree weights and biases as ParameterLists so that
        # they show up in parameters() and can be trained later.
        self.W_list = nn.ParameterList([
            nn.Parameter(torch.Tensor(
                np.random.normal(size=(in_channel, out_channel))).double())
            for k in range(num_deg)])
        self.b_list = nn.ParameterList([
            nn.Parameter(torch.Tensor(np.zeros(out_channel)).double())
            for k in range(num_deg)])

    def forward(self, atom_features, deg_slice, deg_adj_lists):
        #print("deg_adj_list")
        #print(deg_adj_lists)

        W = iter(self.W_list)
        b = iter(self.b_list)

        # Sum all neighbors using adjacency matrix
        deg_summed = self.sum_neigh(atom_features, deg_adj_lists)

        # Get collection of modified atom features
        new_rel_atoms_collection = (self.max_degree + 1 - self.min_degree) * [None]

        for deg in range(1, self.max_degree + 1):
            # Obtain relevant atoms for this degree
            rel_atoms = deg_summed[deg - 1]

            # Get self atoms
            begin = deg_slice[deg - self.min_degree, 0]
            size = deg_slice[deg - self.min_degree, 1]
            self_atoms = torch.narrow(atom_features, 0, int(begin), int(size))

            # Apply hidden affine to relevant atoms and append
            rel_out = torch.matmul(rel_atoms, next(W)) + next(b)
            self_out = torch.matmul(self_atoms, next(W)) + next(b)
            out = rel_out + self_out
            new_rel_atoms_collection[deg - self.min_degree] = out

        # Determine the min_deg=0 case
        if self.min_degree == 0:
            deg = 0
            begin = deg_slice[deg - self.min_degree, 0]
            size = deg_slice[deg - self.min_degree, 1]
            self_atoms = torch.narrow(atom_features, 0, int(begin), int(size))

            # Only use the self layer
            out = torch.matmul(self_atoms, next(W)) + next(b)
            new_rel_atoms_collection[deg - self.min_degree] = out

        # Combine all atoms back into the list
        #print(new_rel_atoms_collection)
        atom_features = torch.cat(new_rel_atoms_collection, 0)
        return atom_features

    def sum_neigh(self, atoms, deg_adj_lists):
        """Store the summed atoms by degree"""
        deg_summed = self.max_degree * [None]

        for deg in range(1, self.max_degree + 1):
            index = torch.tensor(deg_adj_lists[deg - 1], dtype=torch.int64)
            gathered_atoms = atoms[index]
            # Sum along neighbors as well as self, and store
            summed_atoms = torch.sum(gathered_atoms, 1)
            deg_summed[deg - 1] = summed_atoms

        return deg_summed


class GraphPool(nn.Module):

    def __init__(self, min_degree=0, max_degree=10):
        super().__init__()
        self.min_degree = min_degree
        self.max_degree = max_degree

    def forward(self, atom_features, deg_slice, deg_adj_lists):
        # Perform the mol gather
        deg_maxed = (self.max_degree + 1 - self.min_degree) * [None]

        # Tensorflow correctly processes empty lists when using concat
        for deg in range(1, self.max_degree + 1):
            # Get self atoms
            begin = deg_slice[deg - self.min_degree, 0]
            size = deg_slice[deg - self.min_degree, 1]
            self_atoms = torch.narrow(atom_features, 0, int(begin), int(size))

            # Expand dims
            self_atoms = torch.unsqueeze(self_atoms, 1)

            # always deg-1 for deg_adj_lists
            index = torch.tensor(deg_adj_lists[deg - 1], dtype=torch.int64)
            gathered_atoms = atom_features[index]
            gathered_atoms = torch.cat([self_atoms, gathered_atoms], 1)

            if gathered_atoms.shape[0] > 0:
                maxed_atoms = torch.max(gathered_atoms, 1)[0]
            else:
                maxed_atoms = torch.Tensor([])
            deg_maxed[deg - self.min_degree] = maxed_atoms

        if self.min_degree == 0:
            begin = deg_slice[0, 0]
            size = deg_slice[0, 1]
            self_atoms = torch.narrow(atom_features, 0, int(begin), int(size))
            deg_maxed[0] = self_atoms

        return torch.cat(deg_maxed, 0)


class GCNDataset(data.Dataset):

    def __init__(self, smiles_list, label_list):
        self.smiles_list = smiles_list
        self.label_list = label_list

    def __len__(self):
        return len(self.smiles_list)

    def __getitem__(self, index):
        return self.smiles_list[index], self.label_list[index]


def gcn_collate_fn(batch):
    from rdkit import Chem

    cmf = ConvMolFeaturizer()
    mols = []
    labels = []
    for sample, label in batch:
        mols.append(Chem.MolFromSmiles(sample))
        labels.append(torch.tensor(label))

    conv_mols = cmf.featurize(mols)
    multiConvMol = ConvMol.agglomerate_mols(conv_mols)
    atom_feature = torch.tensor(multiConvMol.get_atom_features(), dtype=torch.float64)
    deg_slice = torch.tensor(multiConvMol.deg_slice, dtype=torch.float64)
    membership = torch.tensor(multiConvMol.membership, dtype=torch.float64)

    deg_adj_lists = []
    for i in range(1, len(multiConvMol.get_deg_adjacency_lists())):
        deg_adj_lists.append(multiConvMol.get_deg_adjacency_lists()[i])

    return atom_feature, deg_slice, membership, deg_adj_lists, labels


def main():
    dataset = GCNDataset(["CCC", "CCCC", "CCCCC"], [1, 0, 1])
    dataloader = data.DataLoader(dataset, batch_size=3, shuffle=False, collate_fn=gcn_collate_fn)

    gc = GraphConv(75, 20)
    gp = GraphPool()

    for atom_feature, deg_slice, membership, deg_adj_lists, labels in dataloader:
        print("atom_feature")
        print(atom_feature)
        print("deg_slice")
        print(deg_slice)
        print("membership")
        print(membership)
        print("result")
        gc_out = gc(atom_feature, deg_slice, deg_adj_lists)
        gp_out = gp(gc_out, deg_slice, deg_adj_lists)
        print(gp_out)


if __name__ == "__main__":
    main()
```
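
To see in isolation what GraphPool computes, here is a tiny hand-made example of my own. The 3-atom path graph A-B-C and its deg_slice / deg_adj_lists are built by hand to mimic the degree-sorted layout the layers above expect (degree-1 atoms first, then the degree-2 atom); it is only a sketch, and it relies on the same empty-degree handling that the real run below already exercises.

```python
import numpy as np
import torch

# Toy path graph A-B-C with 2-dimensional "atom features",
# laid out by degree: A, C (degree 1) first, then B (degree 2).
atom_features = torch.tensor([[1., 4.],   # A (degree 1)
                              [2., 0.],   # C (degree 1)
                              [3., 1.]],  # B (degree 2)
                             dtype=torch.float64)

# deg_slice[d] = (start row, atom count) of the degree-d block, d = 0..10
deg_slice = torch.zeros(11, 2, dtype=torch.int64)
deg_slice[1] = torch.tensor([0, 2])  # two degree-1 atoms at rows 0-1
deg_slice[2] = torch.tensor([2, 1])  # one degree-2 atom at row 2

# deg_adj_lists[d-1]: (atom count, d) array of neighbor row indices for degree d
deg_adj_lists = [np.zeros((0, d), dtype=np.int64) for d in range(1, 11)]
deg_adj_lists[0] = np.array([[2], [2]])  # A and C are each bonded to B (row 2)
deg_adj_lists[1] = np.array([[0, 1]])    # B is bonded to A and C (rows 0, 1)

gp = GraphPool()
print(gp(atom_features, deg_slice, deg_adj_lists))
# Each row is the element-wise max over the atom itself and its neighbors:
# A -> max(A, B) = [3., 4.],  C -> max(C, B) = [3., 1.],  B -> max(B, A, C) = [3., 4.]
```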
#Result
And here it is, ta-da.
For now, the output shape is (number of atoms) x 20, i.e. it keeps the feature dimension produced by the GraphConvLayer, so it seems to be working correctly.
As always, I like this white-box feel (yes, this comment is copied verbatim from the previous post; lazy, I know).
That said, the operations differ subtly from TensorFlow's, so it takes a bit of effort here and there to look up the PyTorch equivalents.
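
To back up the shape claim, one could add an assertion at the end of the loop in main(). This is only a sketch of my own (not part of the DeepChem port), reusing the atom_feature, gc, and gp_out variables from the code above:

```python
# Sanity-check sketch: GraphPool only takes an element-wise max over each atom
# and its neighbors, so the atom count and the 20-dim feature width produced by
# GraphConv should both be preserved.
assert gp_out.shape[0] == atom_feature.shape[0]  # one row per atom (3 + 4 + 5 = 12 here)
assert gp_out.shape[1] == gc.out_channel         # still 20 features per atom
```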
```
tensor([[ 1.8113e+00,  1.1862e+00,  1.3068e+00,  1.8266e+00,  6.0706e-03,
          7.2303e+00, -8.7022e-01,  1.1336e+00, -5.1411e+00, -3.3319e-02,
          1.8048e+00,  4.7143e+00,  3.8385e+00,  1.7524e+00,  5.2120e+00,
          2.8675e+00,  4.8746e+00, -2.5079e+00,  8.1260e+00,  7.8020e+00],
        [ 1.8113e+00,  1.1862e+00,  1.3068e+00,  1.8266e+00,  6.0706e-03,
          7.2303e+00, -8.7022e-01,  1.1336e+00, -5.1411e+00, -3.3319e-02,
          1.8048e+00,  4.7143e+00,  3.8385e+00,  1.7524e+00,  5.2120e+00,
          2.8675e+00,  4.8746e+00, -2.5079e+00,  8.1260e+00,  7.8020e+00],
        [ 3.0749e+00,  2.2618e+00,  8.2658e-02,  3.1331e+00,  6.0706e-03,
          4.5357e+00, -8.7022e-01,  1.1336e+00, -5.9143e+00, -3.3319e-02,
          1.8048e+00,  4.7143e+00,  5.9190e+00,  1.7524e+00,  5.2120e+00,
          1.5569e+00,  3.0329e+00, -2.5079e+00,  4.3327e+00,  4.7906e+00],
        [ 3.0749e+00,  2.2618e+00,  8.2658e-02,  3.1331e+00,  6.0706e-03,
          4.5357e+00, -8.7022e-01,  1.1336e+00, -5.9143e+00, -3.3319e-02,
          1.8048e+00,  4.7143e+00,  5.9190e+00,  1.7524e+00,  5.2120e+00,
          1.5569e+00,  3.0329e+00, -2.5079e+00,  4.3327e+00,  4.7906e+00],
        [ 3.0749e+00,  2.2618e+00,  8.2658e-02,  3.1331e+00,  6.0706e-03,
          4.5357e+00, -8.7022e-01,  1.1336e+00, -5.9143e+00, -3.3319e-02,
          1.8048e+00,  4.7143e+00,  5.9190e+00,  1.7524e+00,  5.2120e+00,
          1.5569e+00,  3.0329e+00, -2.5079e+00,  4.3327e+00,  4.7906e+00],
        [ 3.0749e+00,  2.2618e+00,  8.2658e-02,  3.1331e+00,  6.0706e-03,
          4.5357e+00, -8.7022e-01,  1.1336e+00, -5.9143e+00, -3.3319e-02,
          1.8048e+00,  4.7143e+00,  5.9190e+00,  1.7524e+00,  5.2120e+00,
          1.5569e+00,  3.0329e+00, -2.5079e+00,  4.3327e+00,  4.7906e+00]],
       dtype=torch.float64, grad_fn=<MaxBackward0>)
```
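
For reference, these are the rough correspondences I kept looking up while porting. This is only my own summary of the kinds of TensorFlow ops the original DeepChem layers lean on, and the mapping is approximate:

```python
import torch

x = torch.arange(12, dtype=torch.float64).reshape(4, 3)
idx = torch.tensor([[0, 1], [2, 3]])

# tf.gather(x, idx)                    -> plain (advanced) indexing
gathered = x[idx]                        # shape (2, 2, 3)

# tf.slice(x, [begin, 0], [size, -1])  -> torch.narrow(x, dim, begin, size)
sliced = torch.narrow(x, 0, 1, 2)        # rows 1..2

# tf.reduce_max(x, axis=1) returns values only;
# torch.max(x, 1) returns a (values, indices) pair, hence the [0] in GraphPool
maxed = torch.max(gathered, 1)[0]        # shape (2, 3)

# tf.reduce_sum / tf.concat map almost one-to-one
summed = torch.sum(gathered, 1)          # shape (2, 3)
combined = torch.cat([sliced, summed], 0)
print(combined.shape)                    # torch.Size([4, 3])
```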