Implementing DeepChem's GraphGatherLayer as a PyTorch Custom Layer

Introduction

Following on from GraphConvLayer and GraphPoolLayer, I have implemented DeepChem's GraphGatherLayer as a PyTorch custom layer.

Environment

  • DeepChem 2.3
  • PyTorch 1.7.0

Source

I ported DeepChem's GraphGatherLayer to PyTorch and fed the output of the GraphPoolLayer from the previous post into the new GraphGatherLayer.
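Before the full listing, here is a minimal sketch (a toy example of my own, not taken from the DeepChem source) of the two segment operations GraphGather is built on: a per-molecule sum via scatter_add and a per-molecule max via torch_scatter's scatter_max. The membership vector assigns each atom row to a molecule index.

import torch
from torch_scatter import scatter_max

# Toy batch: 5 "atoms" with 2 features each, belonging to 2 "molecules"
atom_features = torch.tensor([[1., 2.],
                              [3., 4.],
                              [5., 6.],
                              [7., 8.],
                              [9., 10.]])
membership = torch.tensor([0, 0, 0, 1, 1])  # atom index -> molecule index
num_mols = 2

# Per-molecule sum (the role of TensorFlow's unsorted_segment_sum)
index = membership.unsqueeze(1).expand_as(atom_features)
sum_reps = torch.zeros(num_mols, atom_features.shape[1]).scatter_add(0, index, atom_features)

# Per-molecule max (the role of unsorted_segment_max), via torch_scatter
max_reps, _ = scatter_max(atom_features, membership, dim=0)

# GraphGather concatenates both, doubling the feature dimension
mol_features = torch.cat([sum_reps, max_reps], dim=1)
print(mol_features.shape)  # torch.Size([2, 4])

The full source, with GraphConv and GraphPool carried over from the previous posts, follows.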

import torch
from torch.utils import data
from deepchem.feat.graph_features import ConvMolFeaturizer
from deepchem.feat.mol_graphs import ConvMol
import torch.nn as nn
import numpy as np
from torch_scatter import scatter_max


def unsorted_segment_sum(data, segment_ids, num_segments):
    """PyTorch port of TensorFlow's tf.math.unsorted_segment_sum.

    Sums the rows of `data` that share the same index in `segment_ids`,
    producing a tensor with `num_segments` rows.
    """
    # segment_ids is a 1-D tensor; repeat it so it has the same shape as data
    if len(segment_ids.shape) == 1:
        s = torch.prod(torch.tensor(data.shape[1:])).long()
        segment_ids = segment_ids.repeat_interleave(s).view(segment_ids.shape[0], *data.shape[1:])

    # Accumulate each row into its segment slot, then restore the input dtype
    shape = [num_segments] + list(data.shape[1:])
    tensor = torch.zeros(*shape).scatter_add(0, segment_ids, data.float())
    tensor = tensor.type(data.dtype)
    return tensor

class GraphConv(nn.Module):

    def __init__(self,
               in_channel,
               out_channel,
               min_deg=0,
               max_deg=10,
               activation=lambda x: x
               ):

        super().__init__()
        self.in_channel = in_channel
        self.out_channel = out_channel
        self.min_degree = min_deg
        self.max_degree = max_deg

        num_deg = 2 * self.max_degree + (1 - self.min_degree)

        # Register one weight matrix and one bias per degree as module parameters
        # (a plain Python list of nn.Parameter would not be registered)
        self.W_list = nn.ParameterList([
            nn.Parameter(torch.Tensor(
                np.random.normal(size=(in_channel, out_channel))).double())
            for k in range(num_deg)])

        self.b_list = nn.ParameterList([
            nn.Parameter(torch.Tensor(np.zeros(out_channel)).double()) for k in range(num_deg)])

    def forward(self, atom_features, deg_slice, deg_adj_lists):

        #print("deg_adj_list")
        #print(deg_adj_lists)

        W = iter(self.W_list)
        b = iter(self.b_list)

        # Sum all neighbors using adjacency matrix
        deg_summed = self.sum_neigh(atom_features, deg_adj_lists)

        # Get collection of modified atom features
        new_rel_atoms_collection = (self.max_degree + 1 - self.min_degree) * [None]

        for deg in range(1, self.max_degree + 1):
            # Obtain relevant atoms for this degree
            rel_atoms = deg_summed[deg - 1]

            # Get self atoms
            begin = deg_slice[deg - self.min_degree, 0]
            size = deg_slice[deg - self.min_degree, 1]

            self_atoms = torch.narrow(atom_features, 0, int(begin), int(size))

            # Apply hidden affine to relevant atoms and append
            rel_out = torch.matmul(rel_atoms, next(W)) + next(b)
            self_out = torch.matmul(self_atoms, next(W)) + next(b)

            out = rel_out + self_out
            new_rel_atoms_collection[deg - self.min_degree] = out

        # Determine the min_deg=0 case
        if self.min_degree == 0:
            deg = 0

            begin = deg_slice[deg - self.min_degree, 0]
            size = deg_slice[deg - self.min_degree, 1]
            self_atoms = torch.narrow(atom_features, 0, int(begin), int(size))

            # Only use the self layer
            out = torch.matmul(self_atoms, next(W)) + next(b)

            new_rel_atoms_collection[deg - self.min_degree] = out

        # Combine all atoms back into the list
        #print(new_rel_atoms_collection)
        atom_features = torch.cat(new_rel_atoms_collection, 0)

        return atom_features


    def sum_neigh(self, atoms, deg_adj_lists):
        """Store the summed atoms by degree"""
        deg_summed = self.max_degree * [None]

        for deg in range(1, self.max_degree + 1):
            index = torch.tensor(deg_adj_lists[deg - 1], dtype=torch.int64)
            gathered_atoms = atoms[index]

            # Sum along neighbors as well as self, and store
            summed_atoms = torch.sum(gathered_atoms, 1)
            deg_summed[deg - 1] = summed_atoms

        return deg_summed


class GraphPool(nn.Module):

    def __init__(self, min_degree=0, max_degree=10):
        super().__init__()
        self.min_degree = min_degree
        self.max_degree = max_degree


    def forward(self, atom_features, deg_slice, deg_adj_lists):

        # Perform the mol gather
        deg_maxed = (self.max_degree + 1 - self.min_degree) * [None]

        # Tensorflow correctly processes empty lists when using concat
        for deg in range(1, self.max_degree + 1):
            # Get self atoms
            begin = deg_slice[deg - self.min_degree, 0]
            size = deg_slice[deg - self.min_degree, 1]
            self_atoms = torch.narrow(atom_features, 0, int(begin), int(size))

            # Expand dims
            self_atoms = torch.unsqueeze(self_atoms, 1)

            # always deg-1 for deg_adj_lists
            index = torch.tensor(deg_adj_lists[deg - 1], dtype=torch.int64)

            gathered_atoms = atom_features[index]
            gathered_atoms = torch.cat([self_atoms, gathered_atoms], 1)

            if gathered_atoms.shape[0] > 0:
                maxed_atoms = torch.max(gathered_atoms, 1)[0]
            else:
                maxed_atoms = torch.Tensor([])

            deg_maxed[deg - self.min_degree] = maxed_atoms

        if self.min_degree == 0:
            begin = deg_slice[0, 0]
            size = deg_slice[0, 1]
            self_atoms = torch.narrow(atom_features, 0, int(begin), int(size))
            deg_maxed[0] = self_atoms

        return torch.cat(deg_maxed, 0)


class GraphGather(nn.Module):

    def __init__(self, batch_size):
        super().__init__()
        self.batch_size = batch_size

    def forward(self, atom_features, membership):

        assert self.batch_size > 1, "graph_gather requires batches larger than 1"

        # Per-molecule sum of atom features (port of tf.math.unsorted_segment_sum)
        sparse_reps = unsorted_segment_sum(atom_features, membership, self.batch_size)
        # Per-molecule max of atom features (port of tf.math.unsorted_segment_max,
        # done here with torch_scatter's scatter_max); [0] keeps the values, not the argmax
        max_reps = scatter_max(atom_features, membership, dim=0)
        # Concatenate sum- and max-pooled features -> (num_molecules, 2 * num_features)
        mol_features = torch.cat([sparse_reps, max_reps[0]], 1)
        return mol_features


class GCNDataset(data.Dataset):

    def __init__(self, smiles_list, label_list):
        self.smiles_list = smiles_list
        self.label_list = label_list

    def __len__(self):
        return len(self.smiles_list)

    def __getitem__(self, index):
        return self.smiles_list[index], self.label_list[index]


def gcn_collate_fn(batch):
    from rdkit import Chem
    cmf = ConvMolFeaturizer()

    mols = []
    labels = []

    for sample, label in batch:
        mols.append(Chem.MolFromSmiles(sample))
        labels.append(torch.tensor(label))

    conv_mols = cmf.featurize(mols)
    multiConvMol = ConvMol.agglomerate_mols(conv_mols)

    atom_feature = torch.tensor(multiConvMol.get_atom_features(), dtype=torch.float64)
    deg_slice = torch.tensor(multiConvMol.deg_slice, dtype=torch.int64)
    membership = torch.tensor(multiConvMol.membership, dtype=torch.int64)
    deg_adj_lists = []

    for i in range(1, len(multiConvMol.get_deg_adjacency_lists())):
        deg_adj_lists.append(multiConvMol.get_deg_adjacency_lists()[i])

    return atom_feature, deg_slice, membership, deg_adj_lists, labels


def main():
    dataset = GCNDataset(["CCC", "CCCC", "CCCCC"], [1, 0, 1])
    dataloader = data.DataLoader(dataset, batch_size=3, shuffle=False, collate_fn=gcn_collate_fn)

    gc = GraphConv(75, 20)
    gp = GraphPool()
    gt = GraphGather(3)
    for atom_feature, deg_slice, membership, deg_adj_lists, labels in dataloader:
        print("atom_feature")
        print(atom_feature)
        print("deg_slice")
        print(deg_slice)
        print("membership")
        print(membership)
        print("result")
        gc_out = gc(atom_feature, deg_slice, deg_adj_lists)
        gp_out = gp(gc_out, deg_slice, deg_adj_lists)
        #print(gp_out)
        gt_out = gt(gp_out, membership)
        print(gt_out)


if __name__ == "__main__":
    main()

Results

Here it is.
The output has shape (number of molecules) x 40, so the atom features have indeed been aggregated into per-molecule features. The 40 dimensions are the concatenation of the 20-dimensional per-molecule sum and the 20-dimensional per-molecule max (GraphConv was built with out_channel=20).
As always, I like this white-box feel (and yes, I lazily reuse the exact same comment every time).
This time, porting TensorFlow's unsorted_segment_sum and unsorted_segment_max operations was a real struggle. Proper verification is still to come; a minimal spot check is sketched after the output below.

tensor([[ 7.7457,  2.1970, 22.1151,  1.8238,  7.5860, 15.5079, -1.3865,  5.3634,
          0.3872, 24.7713, 30.9865, 13.0032,  5.8331, 12.8195,  9.2520, 16.4660,
         -8.8977, 10.5881, 16.8875,  3.6356,  2.5819,  0.7323,  7.3717,  0.6079,
          2.5287,  5.1693, -0.4622,  1.7878,  0.1291,  8.2571, 10.3288,  4.3344,
          1.9444,  4.2732,  3.0840,  5.4887, -2.9659,  3.5294,  5.6292,  1.2119],
        [12.4624, 16.9705, 26.8321,  4.3047, 17.4027, 23.3370, -1.8487,  7.1511,
          0.2538, 23.2520, 25.0874, 17.3375,  7.7775,  9.7369,  8.3362, 20.8373,
         -4.3081, 14.1175, 17.6781,  6.4011,  3.1156,  4.2426,  6.7080,  1.0762,
          4.3507,  5.8342, -0.4622,  1.7878,  0.0634,  5.8130,  6.2718,  4.3344,
          1.9444,  2.4342,  2.0840,  5.2093, -1.0770,  3.5294,  4.4195,  1.6003],
        [17.1790, 31.7441, 33.5401,  8.6282, 27.2195, 31.1660, -4.6301,  4.2145,
         -1.0452, 29.0650, 31.3592, 15.0395, 14.6857, 12.1711, 10.4202, 26.0466,
          3.5187, 10.4842, 22.0976,  9.1667,  3.6493,  7.7530,  6.7080,  2.1586,
          6.1727,  6.4992, -0.4622,  1.7878,  0.0634,  5.8130,  6.2718,  4.3344,
          3.5990,  2.4342,  2.0840,  5.2093,  1.8909,  3.5294,  4.4195,  1.9887]],
       dtype=torch.float64, grad_fn=<CatBackward>)
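
Since proper verification is still pending, here is the kind of minimal spot check I have in mind for the ported unsorted_segment_sum (this assumes the listing above has been run or imported; naive_segment_sum is a helper name of my own):

import torch

def naive_segment_sum(data, segment_ids, num_segments):
    # Reference implementation: an explicit Python loop over segments
    out = torch.zeros(num_segments, *data.shape[1:], dtype=data.dtype)
    for i in range(num_segments):
        mask = segment_ids == i
        if mask.any():
            out[i] = data[mask].sum(dim=0)
    return out

# Random "atom features" and an unsorted membership vector
data = torch.randn(7, 4, dtype=torch.float64)
segment_ids = torch.tensor([2, 0, 1, 0, 2, 1, 0])

expected = naive_segment_sum(data, segment_ids, 3)
actual = unsorted_segment_sum(data, segment_ids, 3)  # from the listing above
# Default allclose tolerances absorb the float32 round-trip inside the ported function
print(torch.allclose(expected, actual))  # should print True

A similar check against an explicit per-segment max would cover the scatter_max path as well.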
