LoginSignup
0
0

More than 3 years have passed since last update.

DeepChemのConvMolFeaturizerでfeaturizeしたミニバッチを返すPyTorchのDataLoaderを実装する

Last updated at Posted at 2020-11-08

はじめに

Keras, TensorflowからPyTorchに乗り換えることとした。
そして、PyTorchを使って化合物によるGraph Convolutional Network(GCN)を実装することとした。
まずは、SMILESで表される化合物を学習に利用できる形に変換する必要がある。
これらの処理を自前で実装してもよいが、Keras, Tensorflow ベースのライブラリである DeepChem の前処理を流用すれば楽ができると考えた。
そこで SMILES を DeepChem の ConvMolFeaturizer で feturize し、それを Pytorch の DataLoaderで使えるようにしてみた。
これにより、面倒な化合物のハンドリング処理を自前で実装することなく、GCNの実装に集中することができると目論んでいる。

環境

  • PyTorch 1.7.0
  • DeepChem 2.3

実装方法

  • Datasetは、単純にSMILESと正解データのリストを保持するものとした。
  • ミニバッチ毎に、ミニバッチ内の全化合物をグラフに変換し、結合次数行列や、隣接行列を生成する必要があるため、collate_fn を独自に実装し、DataLoader の引数に与えることとした。
  • collate_fn では、DeepChem の ConvMolFeaturizer によりSMILES を featulize しリスト化したものを、ConvMolクラスのagglomerate_molsメソッドに与える。これによりミニバッチ内の全化合物の結合次数行列、隣接行列が生成されるため、それぞれPyTorchのテンソル形式に変換し、正解データと共に返却している。

ソース

import torch
from torch.utils import data
from deepchem.feat.graph_features import ConvMolFeaturizer
from deepchem.feat.mol_graphs import ConvMol

class GCNDataset(data.Dataset):

    def __init__(self, smiles_list, label_list):
        self.smiles_list = smiles_list
        self.label_list = label_list

    def __len__(self):
        return len(self.smiles_list)

    def __getitem__(self, index):
        return self.smiles_list[index], self.label_list[index]


def gcn_collate_fn(batch):
    from rdkit import Chem
    cmf = ConvMolFeaturizer()

    mols = []
    labels = []

    for sample, label in batch:
        mols.append(Chem.MolFromSmiles(sample))
        labels.append(torch.tensor(label))

    conv_mols = cmf.featurize(mols)
    multiConvMol = ConvMol.agglomerate_mols(conv_mols)

    atom_feature = torch.tensor(multiConvMol.get_atom_features(), dtype=torch.float64)
    deg_slice = torch.tensor(multiConvMol.deg_slice, dtype=torch.float64)
    membership = torch.tensor(multiConvMol.membership, dtype=torch.float64)
    return atom_feature, deg_slice, membership, labels


def main():
    dataset = GCNDataset(["CCC", "CCCC", "CCCCC"], [1, 0, 1])
    dataloader = data.DataLoader(dataset, batch_size=3, shuffle=False, collate_fn =gcn_collate_fn)
    for atom_feature, deg_slice, membership, labels in dataloader:
        print(atom_feature)
        print(deg_slice)
        print(membership)

if __name__ == "__main__":
    main()

実行結果

3化合物によるミニバッチは以下の通りとなる。
3化合物内の12原子の特徴および、結合次数行列、隣接行列が生成される。
これらについては別の機会に説明する。

tensor([[1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.,
         0., 1., 0.],
        [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.,
         0., 1., 0.],
        [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.,
         0., 1., 0.],
        [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.,
         0., 1., 0.],
        [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.,
         0., 1., 0.],
        [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.,
         0., 1., 0.],
        [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.,
         1., 0., 0.],
        [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.,
         1., 0., 0.],
        [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.,
         1., 0., 0.],
        [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.,
         1., 0., 0.],
        [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.,
         1., 0., 0.],
        [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.,
         1., 0., 0.]], dtype=torch.float64)
tensor([[ 0.,  0.],
        [ 0.,  6.],
        [ 6.,  6.],
        [12.,  0.],
        [12.,  0.],
        [12.,  0.],
        [12.,  0.],
        [12.,  0.],
        [12.,  0.],
        [12.,  0.],
        [12.,  0.]], dtype=torch.float64)
tensor([0., 0., 1., 1., 2., 2., 0., 1., 1., 2., 2., 2.], dtype=torch.float64)

今後

今後は、GCNモデルのコード、および今回のDataLoaderを用いて学習を行うコードを書いていく。

感想

Kerasの窮屈感、Tensorflowのそっけなさに比べ、PyTorchの丁度良さがすごく心地いい(今のところ)。

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0