
Graph Pooling Neural Network with PyG (PyTorch Geometric)

Posted at 2022-08-10

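I tried PyG (PyTorch Geometric), a library for deep learning on graph-structured data, on Google Colaboratory. The theme this time is the Graph Pooling Neural Network. As the example task, I predict a physical property from molecular structure, one of the main themes of chemoinformatics.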

Installing PyG (PyTorch Geometric)

The PyG (PyTorch Geometric) repository is at https://github.com/pyg-team/pytorch_geometric. The code in this article is based on the tutorial documentation at https://pytorch-geometric.readthedocs.io/en/latest/index.html.

import os
import torch

torch.manual_seed(53)
os.environ['TORCH'] = torch.__version__
print(torch.__version__)

!pip install -q torch-scatter -f https://data.pyg.org/whl/torch-${TORCH}.html
!pip install -q torch-sparse -f https://data.pyg.org/whl/torch-${TORCH}.html
!pip install -q git+https://github.com/pyg-team/pytorch_geometric.git
!pip install torch-cluster -f https://data.pyg.org/whl/torch-${TORCH}.html

import torch_cluster
import torch_geometric
1.12.0+cu113
|████████████████████████████████| 7.9 MB 33.1 MB/s
|████████████████████████████████| 3.5 MB 37.6 MB/s
Building wheel for torch-geometric (setup.py) ... done
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in links: https://data.pyg.org/whl/torch-1.12.0+cu113.html
Collecting torch-cluster
  Downloading https://data.pyg.org/whl/torch-1.12.0%2Bcu113/torch_cluster-1.6.0-cp37-cp37m-linux_x86_64.whl (2.4 MB)
|████████████████████████████████| 2.4 MB 42.5 MB/s
Installing collected packages: torch-cluster
Successfully installed torch-cluster-1.6.0
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device
device(type='cuda')

Installing RDKit

Next, install RDKit. It is free software widely used in chemoinformatics for handling chemical structures computationally.

!pip install git+https://github.com/maskot1977/rdkit_installer.git
from rdkit_installer import install
install.from_miniconda(rdkit_version="2020.09.1")
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting git+https://github.com/maskot1977/rdkit_installer.git
  Cloning https://github.com/maskot1977/rdkit_installer.git to /tmp/pip-req-build-a_47g9ih
  Running command git clone -q https://github.com/maskot1977/rdkit_installer.git /tmp/pip-req-build-a_47g9ih
Building wheels for collected packages: rdkit-installer
  Building wheel for rdkit-installer (setup.py) ... done
  Created wheel for rdkit-installer: filename=rdkit_installer-0.2.0-py3-none-any.whl size=7986 sha256=a2903c57a9e859627ed234c67f56f1d813f5e79c0ccae208fc917ca3c8d86293
  Stored in directory: /tmp/pip-ephem-wheel-cache-ttkjmkat/wheels/e6/72/a5/218f5f909a3a87c1ec1ccec03ac61298947fb5f1efa517eefa
Successfully built rdkit-installer
Installing collected packages: rdkit-installer
Successfully installed rdkit-installer-0.2.0


add /root/miniconda/lib/python3.7/site-packages to PYTHONPATH
python version: 3.7.13
fetching installer from https://repo.continuum.io/miniconda/Miniconda3-4.7.12-Linux-x86_64.sh
done
installing miniconda to /root/miniconda
done
installing rdkit
done
rdkit-2020.09.1 installation finished!

Downloading the chemical data

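Since the subject is chemoinformatics, we download two kinds of data: the explanatory variable ("chemical structure") and the target variable ("physical property"). I use part of the data from a database called PCCDB. I can't show all of it, but it looks roughly like the sample below. The "Open Babel SMILES" column is the explanatory variable. It describes each molecular structure and can be represented as a graph.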

import pandas as pd

# Load the data from CSV (`url` points to the CSV file; its value is not shown in this article)
df_reg = pd.read_csv(url)
df_reg.head(3)
| | PCCDB-ID | Open Babel SMILES | HOMO-LUMO gap | HOMO energy | LUMO energy | Dipole moment | Excitation energy (1st) | Oscillator strength (1st) | Excitation energy (2nd) | Oscillator strength (2nd) | Excitation energy (3rd) | Oscillator strength (3rd) | Excitation energy (4th) | Oscillator strength (4th) | Num. of H bond acceptor | Num. of H bond donor | TPSA | logP | Molecular refractivity | Melting point |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 15493 | CN(CCCN1C(=CC(=C[C@@H](C1=O)C)C)C)C | 5.064 | -5.440 | -0.376 | 3.07 | 4.378 | 0.1218 | 4.792 | 0.0021 | 4.912 | 0.0125 | 5.056 | NaN | 3.0 | 0.0 | 23.55 | 2.204 | 76.26 | 43.79 |
| 1 | 20139 | OCc1c(C)cc(cc1C)C | 6.041 | -6.212 | -0.171 | 1.69 | 4.995 | 0.0059 | 5.425 | 0.1014 | 5.859 | 0.0070 | 5.924 | NaN | 1.0 | 1.0 | 20.23 | 2.104 | 47.47 | 40.29 |
| 2 | 7039 | OCc1cc(C)cc(c1O)CO | 5.576 | -5.742 | -0.166 | 4.50 | 4.556 | 0.0443 | 5.059 | 0.0047 | 5.419 | 0.0096 | 5.583 | NaN | 3.0 | 3.0 | 60.69 | 0.685 | 45.69 | 95.04 |

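The target variable is the "HOMO-LUMO gap" column. Predicting it directly is a regression problem. Processing it as follows turns it into a classification problem, with 1 for values above the column mean and 0 otherwise.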

import numpy as np

df_cla = pd.DataFrame(np.where(df_reg > df_reg.mean(), 1, 0), columns=df_reg.columns)
df_cla.head(3)
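| | PCCDB-ID | Open Babel SMILES | HOMO-LUMO gap | HOMO energy | LUMO energy | Dipole moment | Excitation energy (1st) | Oscillator strength (1st) | Excitation energy (2nd) | Oscillator strength (2nd) | Excitation energy (3rd) | Oscillator strength (3rd) | Excitation energy (4th) | Oscillator strength (4th) | Num. of H bond acceptor | Num. of H bond donor | TPSA | logP | Molecular refractivity | Melting point |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 1 | 1 | 1 | 0 | 0 | 1 | 0 | 1 | 0 | 1 | 1 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 1 |
| 1 | 0 | 1 | 1 | 1 | 1 | 0 | 1 | 1 | 0 | 0 | 0 | 1 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 |
| 2 | 1 | 1 | 1 | 1 | 1 | 1 | 1 | 1 | 1 | 0 | 1 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 |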
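One thing in the output above deserves a check. The column labelled "Oscillator strength (4th)" shows 1 in row 0, although that value is NaN in df_reg, and a NaN compared with a mean is False. Also, "Open Babel SMILES" is 1 in every row, even though it is a string column with no mean. This suggests that the labels do not line up with the values.

A plausible cause, which is an assumption because the behaviour depends on the pandas version, runs as follows:

- df_reg.mean() skips the non-numeric SMILES column.
- The comparison therefore realigns the columns.
- np.where(...) then discards pandas' labels.
- columns=df_reg.columns reattaches labels in the original order, which may no longer match the values.

If so, the column read by ys = list(df_cla["HOMO-LUMO gap"]) in the next section may not be the gap itself. The sketch below uses a small made-up frame with hypothetical column names. It binarizes only the numeric columns and lets pandas keep each label attached to its values.

```python
import numpy as np
import pandas as pd

# Hypothetical stand-in for df_reg: one string column plus numeric columns
df = pd.DataFrame({
    "ID": [3, 1, 2],
    "SMILES": ["CCO", "c1ccccc1", "CC(=O)O"],
    "gap": [5.0, 6.0, 7.0],
    "energy": [-5.4, -6.2, np.nan],
})

# Keep only numeric columns, so the frame and its column means share
# exactly the same labels and no realignment happens in the comparison
numeric = df.select_dtypes(include="number")

# 1 where a value is above its column mean, else 0 (NaN compares as False).
# The result is still a DataFrame, so every label stays with its own values.
df_cla_safe = (numeric > numeric.mean()).astype(int)
print(df_cla_safe)
```

Because the result never leaves pandas, there is no step at which labels have to be reattached by position.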

Classification

First, set the explanatory and target variables as follows, and solve the classification problem.

smiles = list(df_reg["Open Babel SMILES"])
ys = list(df_cla["HOMO-LUMO gap"])

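The molecular structure obtained from each SMILES string (the explanatory variable) is converted into a graph that PyTorch Geometric can handle, and then loaded as a dataset.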

from rdkit import Chem
from torch_geometric.data import Data, InMemoryDataset

class MoleculesDataset(InMemoryDataset):
    def __init__(self, smiles, ys, transform = None):
        super().__init__('.', transform)

        # Lookup tables that turn RDKit flags/enums into numbers.
        # Hybridization types not listed here (e.g. SP3D2) raise a KeyError.
        boolean = {True:1, False:0}
        hybridization = {'SP':1, 'SP2':2, 'SP3':3, 'SP3D':3.5}
        bondtype = {'SINGLE':1, 'DOUBLE':2, 'AROMATIC':1.5, 'TRIPLE':3}

        datas = []
        for smile, y in zip(smiles, ys):
            mol = Chem.MolFromSmiles(smile)

            # Node features (5 per atom): atomic number, mass,
            # hybridization, in-ring flag, aromatic flag
            embeddings = []
            for atom in mol.GetAtoms():
                a = []
                a.append(atom.GetAtomicNum())
                a.append(atom.GetMass())
                a.append(hybridization[str(atom.GetHybridization())])
                a.append(boolean[atom.IsInRing()])
                a.append(boolean[atom.GetIsAromatic()])
                embeddings.append(a)
            embeddings = torch.tensor(embeddings)

            # Each bond is stored once, as (begin atom -> end atom).
            # Edge features (3 per bond): bond order, aromatic flag, in-ring flag
            edges = []
            edge_attr = []
            for bond in mol.GetBonds():
                edges.append([bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()])
                b = []
                b.append(bondtype[str(bond.GetBondType())])
                b.append(boolean[bond.GetIsAromatic()])
                b.append(boolean[bond.IsInRing()])
                edge_attr.append(b)
            edges = torch.tensor(edges).T  # shape [2, num_bonds]
            edge_attr = torch.tensor(edge_attr)

            # Integer target: fine for 0/1 labels, but it truncates
            # continuous values such as the HOMO-LUMO gap (5.064 -> 5)
            y = torch.tensor(y, dtype=torch.long)

            data = Data(x=embeddings, edge_index=edges, y=y, edge_attr=edge_attr)
            datas.append(data)

        # Pack all graphs into one storage object plus slice indices
        self.data, self.slices = self.collate(datas)
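MoleculesDataset stores each bond once, from GetBeginAtomIdx() to GetEndAtomIdx(). That makes edge_index, and the dense adjacency built from it later, directed, whereas a chemical bond has no direction. PyG usually represents molecules with each bond in both directions, and it provides torch_geometric.utils.to_undirected for that.

The NumPy sketch below shows the idea on a made-up three-atom chain. It adds the reverse of every edge and duplicates its feature row. The helper to_undirected_np is hypothetical. Applying this inside the class would change the graphs, so the numbers reported in this article would not carry over.

```python
import numpy as np

def to_undirected_np(edge_index, edge_attr):
    """Hypothetical helper: add the reverse of every edge.

    edge_index: int array [2, E] (row 0 = source atom, row 1 = target atom)
    edge_attr:  array [E, F], one feature row per edge
    Returns arrays of shape [2, 2E] and [2E, F].
    """
    reverse = edge_index[::-1]                               # swap source/target rows
    both = np.concatenate([edge_index, reverse], axis=1)
    attrs = np.concatenate([edge_attr, edge_attr], axis=0)   # a bond has the same features both ways
    return both, attrs

# C-C-O chain: bonds (0,1) and (1,2); features [bond order, aromatic, in ring]
edges = np.array([[0, 1], [1, 2]]).T        # [2, num_bonds], same layout as the class
edge_attr = np.array([[1, 0, 0], [1, 0, 0]])

und_edges, und_attr = to_undirected_np(edges, edge_attr)

# The dense adjacency built from the doubled edge list is symmetric
A = np.zeros((3, 3))
A[und_edges[0], und_edges[1]] = 1
print(und_edges)
print((A == A.T).all())
```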

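Applying torch_geometric.transforms.ToDense as the transform represents every molecular graph, whatever its size, as a dense max_nodes x max_nodes adjacency matrix. Because edge features are present, there is one channel per edge feature. Node features are zero-padded to max_nodes rows, and a mask marks the real atoms.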

max_nodes = 128
dataset = MoleculesDataset(smiles, ys, transform=torch_geometric.transforms.ToDense(max_nodes))
dataset.data
Data(x=[8430, 5], edge_index=[2, 8654], edge_attr=[8654, 3], y=[633])
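To make the dense representation concrete, here is a NumPy sketch of what a ToDense-style transform produces for one small graph, based on our reading of the PyG transform:

- node features zero-padded to max_nodes rows;
- an adjacency of shape [max_nodes, max_nodes, edge_dim];
- a boolean mask over the real nodes.

The training code later calls data.adj.sum(axis=3), which collapses the three edge-feature channels into one weighted adjacency per graph. The last lines of the sketch show that step for a single graph. The helper to_dense_np and the toy graph are hypothetical. Unlike the sparse-to-dense conversion, plain index assignment would not sum duplicate edges, but the toy graph has none.

```python
import numpy as np

def to_dense_np(x, edge_index, edge_attr, max_nodes):
    """Hypothetical sketch: pad one graph to max_nodes nodes."""
    num_nodes, num_feats = x.shape
    x_dense = np.zeros((max_nodes, num_feats))
    x_dense[:num_nodes] = x                         # real atoms first, zero rows after

    adj = np.zeros((max_nodes, max_nodes, edge_attr.shape[1]))
    adj[edge_index[0], edge_index[1]] = edge_attr   # one feature vector per (source, target)

    mask = np.zeros(max_nodes, dtype=bool)
    mask[:num_nodes] = True                         # which rows are real atoms
    return x_dense, adj, mask

# Toy graph: 3 atoms with 5 features, 2 directed edges with 3 features each
x = np.ones((3, 5))
edge_index = np.array([[0, 1], [1, 2]])
edge_attr = np.array([[1.0, 0, 0], [1.5, 1, 1]])

x_dense, adj, mask = to_dense_np(x, edge_index, edge_attr, max_nodes=8)

# Per graph the channel axis is 2; with the batch dimension it is axis=3
weighted_adj = adj.sum(axis=2)
print(x_dense.shape, adj.shape, int(mask.sum()))
print(weighted_adj[1, 2])  # 1.5 + 1 + 1
```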

Split the data into train, validation, and test sets.

dataset = dataset.shuffle()
n = (len(dataset) + 9) // 10  # 10% (rounded up) each for test and validation
test_dataset = dataset[:n]
val_dataset = dataset[n:2 * n]
train_dataset = dataset[2 * n:]
# DenseDataLoader lives in torch_geometric.loader (the torch_geometric.data alias is deprecated)
test_loader = torch_geometric.loader.DenseDataLoader(test_dataset, batch_size=32)
val_loader = torch_geometric.loader.DenseDataLoader(val_dataset, batch_size=32)
train_loader = torch_geometric.loader.DenseDataLoader(train_dataset, batch_size=32)
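The split size n is the integer ceiling of 10% of the dataset, written as (len(dataset) + 9) // 10. With the 633 molecules here (y=[633] above), that gives 64 test, 64 validation and 505 training graphs. This agrees with the accuracies printed later, which are multiples of 1/64; for example, 0.7031 is 45/64. A quick check in plain Python:

```python
import math

def split_sizes(num_graphs):
    """Sizes of the test/val/train slices used above."""
    n = (num_graphs + 9) // 10              # integer ceiling of num_graphs / 10
    assert n == math.ceil(num_graphs / 10)
    return n, n, num_graphs - 2 * n         # test, val, train

test_size, val_size, train_size = split_sizes(633)
print(test_size, val_size, train_size)
print(round(45 / test_size, 4))             # an accuracy value from the log
```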

A Graph Neural Network with Graph Pooling (DenseGCNConv layers combined with dense_diff_pool, i.e. DiffPool) is implemented as follows.

from math import ceil

from torch_geometric.nn import DenseGCNConv as GCNConv, dense_diff_pool


class GNN(torch.nn.Module):
    """Three dense GCN layers, each followed by ReLU."""

    def __init__(self, in_channels, hidden_channels, out_channels,
                 normalize=False, lin=True):
        super().__init__()
        # `normalize` and `lin` are accepted for the calls below but unused.
        # DenseGCNConv's third positional parameter is `improved`, so
        # `normalize` is not passed to it.
        self.convs = torch.nn.ModuleList()
        self.bns = torch.nn.ModuleList()  # created, but not applied in forward()
        self.convs.append(GCNConv(in_channels, hidden_channels))
        self.bns.append(torch.nn.BatchNorm1d(hidden_channels))
        self.convs.append(GCNConv(hidden_channels, hidden_channels))
        self.bns.append(torch.nn.BatchNorm1d(hidden_channels))
        self.convs.append(GCNConv(hidden_channels, out_channels))
        self.bns.append(torch.nn.BatchNorm1d(out_channels))

    def forward(self, x, adj, mask=None):
        # x: [batch, nodes, channels], adj: [batch, nodes, nodes]
        for conv in self.convs:
            x = torch.nn.functional.relu(conv(x, adj, mask))
        return x


class DiffPool(torch.nn.Module):
    def __init__(self, regr=False):
        super().__init__()
        # Uses the global `max_nodes` and `dataset` defined above
        num_nodes = ceil(0.25 * max_nodes)  # 128 nodes -> 32 clusters
        self.gnn1_pool = GNN(dataset.num_features, 64, num_nodes)
        self.gnn1_embed = GNN(dataset.num_features, 64, 64)
        num_nodes = ceil(0.25 * num_nodes)  # 32 clusters -> 8 clusters
        self.gnn2_pool = GNN(64, 64, num_nodes)
        self.gnn2_embed = GNN(64, 64, 64, lin=False)
        self.gnn3_embed = GNN(64, 64, 64, lin=False)
        self.lin1 = torch.nn.Linear(64, 64)
        self.lin2 = torch.nn.Linear(64, dataset.num_classes)
        self.regr = regr

    def forward(self, x, adj, mask=None):
        # Stage 1: assignment logits s and node embeddings, then pool
        s = self.gnn1_pool(x, adj, mask)
        x = self.gnn1_embed(x, adj, mask)
        x, adj, l1, e1 = dense_diff_pool(x, adj, s, mask)
        # Stage 2 on the coarsened graph (clusters are not padded, so no mask)
        s = self.gnn2_pool(x, adj)
        x = self.gnn2_embed(x, adj)
        x, adj, l2, e2 = dense_diff_pool(x, adj, s)
        # Final embedding, mean over the remaining clusters, then an MLP head
        x = self.gnn3_embed(x, adj)
        x = x.mean(dim=1)
        x = torch.nn.functional.relu(self.lin1(x))
        x = self.lin2(x)
        # Also returns the summed link-prediction and entropy losses
        if self.regr:
            return x, l1 + l2, e1 + e2
        else:
            return torch.nn.functional.log_softmax(x, dim=-1), l1 + l2, e1 + e2
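The core of the model is dense_diff_pool. As we understand the PyG function, it works in these steps:

1. Turn the assignment logits s into a soft assignment S = softmax(s) over clusters.
2. Return the pooled features S^T X and the pooled adjacency S^T A S.
3. Return a link-prediction loss ||A - S S^T||_F divided by the number of adjacency entries.
4. Return an entropy loss on S.

The NumPy sketch below chains that step with a dense GCN layer, mirroring one pooling stage of DiffPool above. The GCN layer sets self-loops to 1, applies the symmetric D^{-1/2} A D^{-1/2} normalization, applies a linear map, and then applies ReLU. Batching, bias and mask handling are omitted. The weights are random and the helper names are hypothetical.

In the model above, stage 1 pools up to 128 nodes into ceil(0.25 x 128) = 32 clusters, and stage 2 pools those into ceil(0.25 x 32) = 8.

```python
import numpy as np

rng = np.random.default_rng(0)

def dense_gcn(x, adj, weight):
    """Hypothetical dense GCN layer (no batch dim, no bias), followed by ReLU."""
    a = adj.copy()
    np.fill_diagonal(a, 1.0)                     # self-loops
    deg = np.clip(a.sum(axis=-1), 1.0, None)     # degree, at least 1
    d_inv_sqrt = deg ** -0.5
    a_norm = d_inv_sqrt[:, None] * a * d_inv_sqrt[None, :]
    return np.maximum(a_norm @ (x @ weight), 0.0)

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)      # numerical stability
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def diff_pool(x, adj, s_logits, eps=1e-15):
    """Hypothetical helper: pool N nodes into C clusters with a soft assignment."""
    s = softmax(s_logits, axis=-1)               # [N, C], each row sums to 1
    x_pooled = s.T @ x                           # [C, F]
    adj_pooled = s.T @ adj @ s                   # [C, C]
    link_loss = np.linalg.norm(adj - s @ s.T) / adj.size
    ent_loss = (-s * np.log(s + eps)).sum(axis=-1).mean()
    return x_pooled, adj_pooled, link_loss, ent_loss

# Toy undirected graph: a 6-membered ring with 5 input features per node
num_nodes, in_feats, hidden, clusters = 6, 5, 4, 2
adj = np.zeros((num_nodes, num_nodes))
for i in range(num_nodes):
    j = (i + 1) % num_nodes
    adj[i, j] = adj[j, i] = 1.0
x = rng.normal(size=(num_nodes, in_feats))

# Embedding GNN and assignment GNN see the same input with separate weights
z = dense_gcn(x, adj, rng.normal(size=(in_feats, hidden)))
s_logits = dense_gcn(x, adj, rng.normal(size=(in_feats, clusters)))

x_p, adj_p, link_loss, ent_loss = diff_pool(z, adj, s_logits)
print(x_p.shape, adj_p.shape, round(link_loss, 4), round(ent_loss, 4))
```

Because each row of S sums to 1, pooling redistributes node features among clusters without changing their column totals. The pooled adjacency stays symmetric when the input adjacency is symmetric.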

Here is the training and evaluation code.

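def train(epoch):
    model.train()
    loss_all = 0
    for data in train_loader:
        data = data.to(device)
        optimizer.zero_grad()
        # data.adj is [batch, nodes, nodes, 3] (one channel per edge feature);
        # summing the last axis gives one weighted adjacency per graph.
        # The DiffPool auxiliary losses (link, entropy) are returned but not used.
        output, _, _ = model(data.x, data.adj.sum(axis=3), data.mask)
        loss = torch.nn.functional.nll_loss(output, data.y.view(-1))
        loss.backward()
        loss_all += data.y.size(0) * loss.item()  # weight by batch size
        optimizer.step()
    return loss_all / len(train_dataset)  # mean loss over the training set

@torch.no_grad()
def test(loader):
    model.eval()
    correct = 0
    for data in loader:
        data = data.to(device)
        # [0] selects the log-probabilities; max(dim=1)[1] is the predicted class
        pred = model(data.x, data.adj.sum(axis=3), data.mask)[0].max(dim=1)[1]
        correct += pred.eq(data.y.view(-1)).sum().item()
    return correct / len(loader.dataset)  # accuracy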
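For classification, the model returns log_softmax outputs, and train() applies nll_loss to them. Together, these two steps compute the usual cross-entropy. test() takes the arg-max class and counts the matches. The small NumPy sketch below uses made-up logits and labels for the two classes of df_cla. It checks that the log-softmax plus NLL route equals cross-entropy computed directly, and computes accuracy the same way test() does.

```python
import numpy as np

def log_softmax(z):
    z = z - z.max(axis=1, keepdims=True)          # numerical stability
    return z - np.log(np.exp(z).sum(axis=1, keepdims=True))

def nll_loss(log_probs, targets):
    """Mean negative log-probability of the true class."""
    return -log_probs[np.arange(len(targets)), targets].mean()

def cross_entropy(logits, targets):
    """Cross-entropy written directly from softmax probabilities."""
    probs = np.exp(logits) / np.exp(logits).sum(axis=1, keepdims=True)
    return -np.log(probs[np.arange(len(targets)), targets]).mean()

# Made-up logits for 4 graphs and 2 classes (0 = below mean, 1 = above mean)
logits = np.array([[2.0, -1.0], [0.5, 0.3], [-1.0, 1.5], [0.0, 2.0]])
y = np.array([0, 1, 1, 1])

loss = nll_loss(log_softmax(logits), y)
pred = log_softmax(logits).argmax(axis=1)         # like .max(dim=1)[1] in torch
accuracy = (pred == y).sum() / len(y)
print(round(loss, 4), accuracy)
```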

Initialize the model and the optimizer.

model = DiffPool().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
max_epoch = 1001

Run the training.

best_val_acc = test_acc = 0
best_model = None
loss_hist = []
val_hist = []
test_hist = []
for epoch in range(1, max_epoch):  # epochs 1 to 1000
    train_loss = train(epoch)
    val_acc = test(val_loader)
    test_acc = test(test_loader)
    loss_hist.append(train_loss)
    val_hist.append(val_acc)
    test_hist.append(test_acc)
    # Prints whenever test accuracy reaches a new best. The comparison uses
    # test_acc (despite the name best_val_acc), and best_model = model stores
    # a reference to the model, not a copy of its weights.
    if best_model is None or best_val_acc < test_acc:
        best_val_acc = test_acc
        best_model = model
        print(f'Epoch: {epoch:03d}, Train Loss: {train_loss:.4f}, '
              f'Val Acc: {val_acc:.4f}, Test Acc: {test_acc:.4f}')
Epoch: 001, Train Loss: 0.6907, Val Acc: 0.5469, Test Acc: 0.7031
Epoch: 021, Train Loss: 0.5864, Val Acc: 0.6562, Test Acc: 0.7188
Epoch: 038, Train Loss: 0.5106, Val Acc: 0.5938, Test Acc: 0.7344
Epoch: 041, Train Loss: 0.5034, Val Acc: 0.6250, Test Acc: 0.7500
Epoch: 050, Train Loss: 0.4889, Val Acc: 0.5938, Test Acc: 0.7812
Epoch: 077, Train Loss: 0.4940, Val Acc: 0.6562, Test Acc: 0.7969
Epoch: 116, Train Loss: 0.4641, Val Acc: 0.6094, Test Acc: 0.8125
Epoch: 140, Train Loss: 0.4201, Val Acc: 0.6719, Test Acc: 0.8281
Epoch: 831, Train Loss: 0.5051, Val Acc: 0.6250, Test Acc: 0.8594
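In the loop above, the printed "best" epochs are chosen by test accuracy. Also, best_model refers to the same object as model, so after training it holds the final weights rather than those of the best epoch. A common alternative works like this:

- select on validation accuracy;
- snapshot the weights with copy.deepcopy(model.state_dict());
- report the test score of the selected epoch only.

The framework-free sketch below uses made-up accuracy histories (not the results above) and a plain dict standing in for a state_dict. With PyTorch, the snapshot would be restored with model.load_state_dict(best_state). For the regression loop later, the comparison would flip to val_loss < best_val_loss.

```python
import copy

# Made-up per-epoch histories (not the results above)
val_hist = [0.55, 0.66, 0.59, 0.67, 0.63]
test_hist = [0.70, 0.72, 0.73, 0.83, 0.86]

weights = {"w": 0.0}           # stands in for model.state_dict()
best_val, best_epoch, best_state = -1.0, None, None

for epoch, val_acc in enumerate(val_hist, start=1):
    weights["w"] += 1.0        # pretend one epoch of training updated the weights
    if val_acc > best_val:     # select on validation, never on test
        best_val = val_acc
        best_epoch = epoch
        best_state = copy.deepcopy(weights)   # snapshot, not a reference

test_at_best = test_hist[best_epoch - 1]      # report test only for the chosen epoch
print(best_epoch, best_val, test_at_best, best_state, weights)
```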

The learning curves look like this.

import matplotlib.pyplot as plt

plt.plot(loss_hist, label="Train Loss")
plt.legend()
plt.show()
plt.plot(val_hist, label="Val Acc")
plt.plot(test_hist, label="Test Acc")
plt.legend()
plt.show()

RDKit_GraphPooling_26_0.png

RDKit_GraphPooling_26_1.png

Regression

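The procedure is basically the same, but now let's solve a regression problem. The explanatory variable is unchanged. The target becomes the continuous value itself.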

smiles = list(df_reg["Open Babel SMILES"])
ys = list(df_reg["HOMO-LUMO gap"])

Since the target variable has changed, rebuild the dataset (the code is unchanged).

max_nodes = 128
dataset = MoleculesDataset(smiles, ys, transform=torch_geometric.transforms.ToDense(max_nodes))
dataset.data
Data(x=[8430, 5], edge_index=[2, 8654], edge_attr=[8654, 3], y=[633])
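MoleculesDataset converts every target with torch.tensor(y, dtype=torch.long). That suits the 0/1 labels of the classification section. The HOMO-LUMO gap, however, is continuous, and an integer cast truncates it toward zero, so 5.064 becomes 5. As written, the regression below therefore trains on integer-valued targets.

The sketch below shows the effect on the first three gap values from the table above. It also shows one way to choose the dtype by task, using NumPy casts as a stand-in for the torch ones. make_target and its regression flag are hypothetical. In the class, the analogous change would be dtype=torch.float for regression. DiffPool sizes lin2 from dataset.num_classes, which PyG derives from these targets, so a float target may also call for an explicit output size.

```python
import numpy as np

gaps = np.array([5.064, 6.041, 5.576])    # HOMO-LUMO gaps of the first three rows

# What an integer target dtype does to these values: truncation toward zero
as_int = gaps.astype(np.int64)

def make_target(value, regression):
    """Hypothetical helper: float targets for regression, integer labels otherwise."""
    return np.asarray(value, dtype=np.float32 if regression else np.int64)

as_float = make_target(gaps, regression=True)
labels = make_target([1, 1, 0], regression=False)
print(as_int, as_float, labels.dtype)
```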

Rebuild the train, validation, and test splits as well (the code is unchanged).

dataset = dataset.shuffle()
n = (len(dataset) + 9) // 10  # 10% (rounded up) each for test and validation
test_dataset = dataset[:n]
val_dataset = dataset[n:2 * n]
train_dataset = dataset[2 * n:]
# DenseDataLoader lives in torch_geometric.loader (the torch_geometric.data alias is deprecated)
test_loader = torch_geometric.loader.DenseDataLoader(test_dataset, batch_size=32)
val_loader = torch_geometric.loader.DenseDataLoader(val_dataset, batch_size=32)
train_loader = torch_geometric.loader.DenseDataLoader(train_dataset, batch_size=32)

The training and evaluation code needs a few changes.

def train(epoch):
    model.train()
    loss_all = 0

    mae_loss = torch.nn.L1Loss()  # mean absolute error (MAE)
    for data in train_loader:
        data = data.to(device)
        optimizer.zero_grad()
        output, _, _ = model(data.x, data.adj.sum(axis=3), data.mask)
        # The model returns dataset.num_classes values per graph;
        # their mean serves as the single predicted value
        loss = mae_loss(output.mean(dim=-1), data.y.ravel())
        loss.backward()
        loss_all += data.y.size(0) * loss.item()  # weight by batch size
        optimizer.step()
    return loss_all / len(train_dataset)  # mean MAE over the training set


@torch.no_grad()
def test(loader):
    model.eval()
    loss_all = 0
    mae_loss = torch.nn.L1Loss()
    for data in loader:
        data = data.to(device)
        output, _, _ = model(data.x, data.adj.sum(axis=3), data.mask)
        loss = mae_loss(output.mean(dim=-1), data.y.ravel())
        loss_all += data.y.size(0) * loss.item()
    return loss_all / len(loader.dataset)  # MAE over the whole split
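In the regression version, output.mean(dim=-1) averages the model's outputs into one prediction per graph before the L1 (mean absolute error) loss. Both functions then multiply each batch's mean loss by the batch size and divide by the dataset size. This recovers the MAE over the whole split, even when the last batch is smaller than 32. A NumPy sketch with made-up outputs and targets checks that identity:

```python
import numpy as np

rng = np.random.default_rng(1)
num_graphs, num_outputs, batch_size = 70, 7, 32   # batches of 32, 32 and 6

outputs = rng.normal(5.0, 1.0, size=(num_graphs, num_outputs))  # made-up model outputs
targets = rng.normal(5.5, 0.5, size=num_graphs)                 # made-up gaps

preds = outputs.mean(axis=-1)            # like output.mean(dim=-1)

loss_all = 0.0
for start in range(0, num_graphs, batch_size):
    p = preds[start:start + batch_size]
    t = targets[start:start + batch_size]
    batch_mae = np.abs(p - t).mean()      # what L1Loss returns for one batch
    loss_all += len(t) * batch_mae        # weight by batch size, as in train()/test()

dataset_mae = loss_all / num_graphs
print(round(dataset_mae, 4))
```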

Initialize the model and the optimizer. DiffPool is designed so that regr=True makes it return raw outputs for regression instead of log-softmax values.

model = DiffPool(regr=True).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
max_epoch = 1001

Run the training.

best_val_loss = test_loss = None
best_model = None
loss_hist = []
val_hist = []
test_hist = []
for epoch in range(1, max_epoch):  # epochs 1 to 1000
    train_loss = train(epoch)
    val_loss = test(val_loader)
    test_loss = test(test_loader)
    loss_hist.append(train_loss)
    val_hist.append(val_loss)
    test_hist.append(test_loss)
    # Prints whenever the test MAE reaches a new low. As in the classification
    # loop, the comparison uses test_loss and best_model is a reference.
    if best_model is None or best_val_loss > test_loss:
        best_val_loss = test_loss
        best_model = model
        print(f'Epoch: {epoch:03d}, Train Loss: {train_loss:.4f}, '
              f'Val Loss: {val_loss:.4f}, Test Loss: {test_loss:.4f}')
Epoch: 001, Train Loss: 3.6091, Val Loss: 2.2762, Test Loss: 2.3770
Epoch: 002, Train Loss: 1.4374, Val Loss: 1.2515, Test Loss: 1.1452
Epoch: 003, Train Loss: 1.0540, Val Loss: 0.9974, Test Loss: 0.9199
Epoch: 004, Train Loss: 0.9311, Val Loss: 0.9712, Test Loss: 0.9143
Epoch: 005, Train Loss: 0.9169, Val Loss: 0.9186, Test Loss: 0.8573
Epoch: 006, Train Loss: 0.8947, Val Loss: 0.8962, Test Loss: 0.8457
Epoch: 007, Train Loss: 0.8850, Val Loss: 0.8760, Test Loss: 0.8341
Epoch: 008, Train Loss: 0.8621, Val Loss: 0.8560, Test Loss: 0.8233
Epoch: 009, Train Loss: 0.8439, Val Loss: 0.8357, Test Loss: 0.8082
Epoch: 010, Train Loss: 0.8328, Val Loss: 0.8105, Test Loss: 0.7917
Epoch: 011, Train Loss: 0.8109, Val Loss: 0.7803, Test Loss: 0.7712
Epoch: 012, Train Loss: 0.7810, Val Loss: 0.7437, Test Loss: 0.7493
Epoch: 013, Train Loss: 0.7486, Val Loss: 0.7143, Test Loss: 0.7286
Epoch: 014, Train Loss: 0.7175, Val Loss: 0.6732, Test Loss: 0.6987
Epoch: 016, Train Loss: 0.6576, Val Loss: 0.6195, Test Loss: 0.6701
Epoch: 017, Train Loss: 0.6420, Val Loss: 0.6057, Test Loss: 0.6555
Epoch: 018, Train Loss: 0.6107, Val Loss: 0.5897, Test Loss: 0.6229
Epoch: 019, Train Loss: 0.6087, Val Loss: 0.5664, Test Loss: 0.6090
Epoch: 020, Train Loss: 0.6106, Val Loss: 0.5814, Test Loss: 0.6009
Epoch: 022, Train Loss: 0.5955, Val Loss: 0.5603, Test Loss: 0.5702
Epoch: 025, Train Loss: 0.5763, Val Loss: 0.5616, Test Loss: 0.5589
Epoch: 027, Train Loss: 0.5779, Val Loss: 0.5476, Test Loss: 0.5375
Epoch: 028, Train Loss: 0.5589, Val Loss: 0.5455, Test Loss: 0.5359
Epoch: 029, Train Loss: 0.6167, Val Loss: 0.5357, Test Loss: 0.5314
Epoch: 030, Train Loss: 0.5480, Val Loss: 0.5298, Test Loss: 0.5110
Epoch: 046, Train Loss: 0.5906, Val Loss: 0.5245, Test Loss: 0.5015
Epoch: 053, Train Loss: 0.5741, Val Loss: 0.5223, Test Loss: 0.5003
Epoch: 067, Train Loss: 0.5547, Val Loss: 0.5186, Test Loss: 0.4949
Epoch: 082, Train Loss: 0.5357, Val Loss: 0.5049, Test Loss: 0.4876
Epoch: 143, Train Loss: 0.5167, Val Loss: 0.4952, Test Loss: 0.4818
Epoch: 148, Train Loss: 0.5100, Val Loss: 0.5130, Test Loss: 0.4788
Epoch: 154, Train Loss: 0.5031, Val Loss: 0.5070, Test Loss: 0.4718
Epoch: 156, Train Loss: 0.5088, Val Loss: 0.5221, Test Loss: 0.4616
Epoch: 158, Train Loss: 0.5043, Val Loss: 0.5169, Test Loss: 0.4599
Epoch: 178, Train Loss: 0.4858, Val Loss: 0.5008, Test Loss: 0.4584
Epoch: 193, Train Loss: 0.4740, Val Loss: 0.4940, Test Loss: 0.4536
Epoch: 197, Train Loss: 0.4656, Val Loss: 0.4834, Test Loss: 0.4399
Epoch: 221, Train Loss: 0.4521, Val Loss: 0.4836, Test Loss: 0.4396
Epoch: 226, Train Loss: 0.4412, Val Loss: 0.4616, Test Loss: 0.4303
Epoch: 238, Train Loss: 0.4474, Val Loss: 0.4597, Test Loss: 0.4253
Epoch: 245, Train Loss: 0.4372, Val Loss: 0.4638, Test Loss: 0.4215
Epoch: 247, Train Loss: 0.4174, Val Loss: 0.4485, Test Loss: 0.4198
Epoch: 299, Train Loss: 0.4182, Val Loss: 0.4168, Test Loss: 0.4150
Epoch: 314, Train Loss: 0.4331, Val Loss: 0.4220, Test Loss: 0.4080
Epoch: 682, Train Loss: 0.2717, Val Loss: 0.3798, Test Loss: 0.4003
Epoch: 717, Train Loss: 0.2514, Val Loss: 0.3735, Test Loss: 0.3986
Epoch: 819, Train Loss: 0.2344, Val Loss: 0.3724, Test Loss: 0.3842

The learning curves look like this.

import matplotlib.pyplot as plt

plt.plot(loss_hist, label="Train Loss")
plt.plot(val_hist, label="Val Loss")
plt.plot(test_hist, label="Test Loss")
plt.yscale('log')
plt.legend()
plt.show()

RDKit_GraphPooling_40_0.png

Closing remarks

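With a Graph Pooling GNN, I solved both a classification problem and a regression problem that use a graph property as the target variable. There is surely plenty of room for improvement, but for now I have something that works.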
