6
7

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

結晶・分子グラフの効率的なDataLoader実装を考える

Last updated at Posted at 2022-09-10

はじめに

最近は結晶構造や分子構造をグラフ構造で表現し、機械学習を行うことが流行っています。特に、DFT計算でのエネルギーなどの出力をグラフニューラルネットワーク(GNN)で学習する、ニューラルネットワークポテンシャル(NNP)が非常に注目されています。ハイスループット計算により大規模なDFTデータセットの作成が可能になりつつあり、学習を行う際には結晶構造をグラフ化するコストも考える必要があります。

そこでこの記事では簡便かつ高速にpytorchのDataLoaderグラフデータを取得する方法を考えます。

必要なライブラリ

  • ase
  • lmdb
  • ocpmodels (グラフ化のために、AtomsToGraphのみ使用)
  • torch
  • torch-geometric

実装方法

以下の3通りを試してみます。

  • aseのデータベースから読み込み
  • LMDBから読み込み
  • cifファイルから読み込み

ase

データベースを作成

まずaseのデータベースを用いてAtomsオブジェクトを保存します。今回はAl構造に適当に変位を与えて、EMTでエネルギー等を計算して保存しておきます。

from ase.calculators.emt import EMT
from ase.db import connect
from ase.io import read
import numpy as np

atoms = read("Al.cif")
atoms *= (2,2,2) # make supercell

base_pos = atoms.get_positions()
calc = EMT()

with connect("ase_toy_data.db") as conn:
    for i in range(100):
        atoms_rattle = atoms.copy()
        atoms_rattle.set_positions(
            base_pos + np.random.normal(scale=0.1, size=base_positions.shape)
        )
        calc.calculate(atoms_rattle)
        conn.write(atoms_rattle, data=calc.results)

pytochのデータセット作成

pytorchのDatasetを定義します。

from __future__ import annotations

from collections.abc import Callable, Iterable

from ocpmodels.preprocessing import AtomsToGraphs
import torch
from torch.utils.data import BatchSampler, Dataset, Sampler, SequentialSampler 
from torch_geometric.data import Batch, Data
from torch_geometric.loader import DataLoader


class AseDataset(Dataset):
    """
    Dataset of ase database.
    
    Args:
        db_path: Path to the ase database.
        convert_fn: Function to convert :class:`.Atoms` to PyG :class:`.Data`.
    """
    def __init__(
        self,
        db_path: str,
        convert_fn: Callable[[Atoms], Data] = None
    ):
        self.db_path = db_path  
        if convert_fn is None:
            convert_fn = AtomsToGraphs().convert       
        self.convert_fn = convert_fn
    
    def __len__(self) -> int:
        with connect(self.db_path) as conn:
            n = conn.count()
        return n

    
class AseSingleDataset(AseDataset):
    """
    Dataset to get a single graph.
    """
    def __getitem__(self, idx: int) -> Data:
        with connect(self.db_path) as conn:
            atoms = conn.get_atoms(idx+1, add_additional_information=True)
        graph = self.convert_fn(atoms)
        graph = add_info(graph, atoms.info["data"])            
        return graph


class AseBatchedDataset(AseDataset):
    """
    Dataset to get a batched graph.
    """
    def __getitem__(self, idx: Iterable) -> list[Data]:
        datas = []
        with connect(self.db_path) as conn:
            for i in idx:
                atoms = conn.get_atoms(i+1, add_additional_information=True)
                graph = self.convert_fn(atoms)
                graph = add_info(graph, atoms.info["data"])     
                datas.append(graph)
        return datas
    
    
def add_info(data: Data, info: dict) -> Data:
    for k, v in info.items():
        if isinstance(v, (int, float, list, np.ndarray)):
            data[k] = torch.tensor(v)
    return data

普通の実装ではAseSingleDatasetのようにデータを一つずつ取得しますが、データベースへの接続が過多になるので、AseBatchedDatasetも作成しました。一度の接続で複数のAtomsを取得します。なお、今回の実装ではaseのデータベースはsqlite3を用いていますが、for文ではなく一度に取得するクエリを書いても速度はむしろ遅くなってしまったので今回は記載していません。

速度検証

まず、速度検証用の関数を定義します。データセット全てのデータ取得を10周します。

from time import perf_counter

def get_fetch_time(
    dataset: Dataset,
    loader_cls: type[torch.utils.loader.DataLoader],
    collate_fn: Callable[[...], Batch] | None = None,
    sampler: Sampler | None = None,
    batch_size: int = 16,
    n_iter: int = 10
) -> float:
    """
    Measure the time required for data fetch.
    
    Args:
        dataset: Dataset
        loader_cls: DataLoader class (torch or torch_geometric)
        collate_fn: Callable function to get a batch.
        sampler: :class:`.Sampler` class (for :class:`AseBatchedDataset`).
            If this is not None (:class:`.BatchSampler` is assumed),
            ``batch_size`` will be replaced by ``1``.
        batch_size: Batch size
        n_iter: The number of iterations.
    """
    if sampler is not None:
        _batch_size = 1
    else:
        _batch_size = batch_size

    loader = loader_cls(
        dataset, batch_size=_batch_size, collate_fn=collate_fn,
        sampler=sampler, shuffle=False
    )
    start = perf_counter()
    
    for i in range(n_iter):
        for data in loader:
            pass
    end = perf_counter()
    return end - start

以下のように、AseSingleDatasetAseBatchedDatasetでデータ取得する時間を計測します。

def collate(datas: list[list[Datas]]) -> Batch:
    """
    DataLoader with ``batch_size=1`` generates ``datas`` of
    (1, batch_size) shape list. Collate function transforms it
    to :class:`.Batch` object.
    """
    return Batch.from_data_list(datas[0])

dataset_single = AseSingleDataset("ase_toy_data.db")

dataset_batch = AseBatchedDataset("ase_toy_data.db")
sampler = BatchSampler(SequentialSampler(dataset_batch), 16, False)

t1 = get_fetch_time(dataset_single, DataLoader)
t2 = get_fetch_time(
    dataset_batch,
    torch.utils.data.DataLoader,
    collate_fn=collate,
    sampler=sampler
)
print(t1, t2)
output
single: 3.5297205930000928, batch: 2.992043105999983

バッチで取得することで若干速くなりました。

LMDB

データベースの作成

次にLMDBでAtomsを保存します。

import pickle

import lmdb

with connect("ase_toy_data.db") as conn:
    toy_datas = [
        conn.get_atoms(i+1, add_additional_information=True)
        for i in range(conn.count())
    ]
    
db = lmdb.open(
    "lmdb_toy_data.lmdb", subdir=False, meminit=False, map_async=True
)

for i, toy_data in enumerate(toy_datas):
    txn = db.begin(write=True)
    txn.put(
        f"{i}".encode("ascii"),
        pickle.dumps(toy_data, protocol=-1),
    )
    txn.commit()
    
txn = db.begin(write=True)
txn.put("length".encode("ascii"), pickle.dumps(i, protocol=-1))
txn.commit()
db.sync()
db.close()

pytochのデータセット作成

class LmdbDataset(Dataset):
    """
    LMDB dataset.
    
    Args:
        db_path: Path to the ase database.
        convert_fn: Function to convert :class:`.Atoms` to PyG :class:`.Data`.
    """
    def __init__(
        self,
        db_path: str,
        convert_fn: Callable[[Atoms], Data] = None,
        is_graph: bool = False
    ):
        if convert_fn is None:
            convert_fn = AtomsToGraphs().convert       
        self.convert_fn = convert_fn
        
        self.env = lmdb.open(
            db_path, subdir=False, readonly=True, lock=False,
            readahead=False, meminit=False, max_readers=1,
        )
        self._keys = [
            f"{j}".encode("ascii")
            for j in range(self.env.stat()["entries"] - 1)
        ]
        self.n_datas = len(self._keys)
        self.is_graph = is_graph

    def close(self) -> None:
        """Close a database."""
        self.env.close()
            
    def __len__(self):
        return self.n_datas

    def __getitem__(self, idx):
        data = self.env.begin().get(self._keys[idx])
        data = pickle.loads(data)
        if not self.is_graph:
            graph = self.convert_fn(data)
            graph = add_info(graph, data.info["data"])
        else:
            graph = data
        return graph

速度検証

dataset = LmdbDataset("lmdb_toy_data.lmdb")
print(get_fetch_time(dataset, DataLoader))
output
2.4930231989996173

aseよりも少し速くなりました。単純なkey-value型なので高速に読み込むことが可能であるためと思われます。
なお、LMDBであればグラフ化済みのデータを保存しておくこともできます。

dataset = LmdbDataset("lmdb_toy_data.lmdb")

db = lmdb.open(
    "lmdb_toy_graph_data.lmdb", subdir=False, meminit=False, map_async=True
)
for i, graph in enumerate(dataset):
    txn = db.begin(write=True)
    txn.put(
        f"{i}".encode("ascii"),
        pickle.dumps(graph, protocol=-1),
    )
    txn.commit()
    
txn = db.begin(write=True)
txn.put("length".encode("ascii"), pickle.dumps(i, protocol=-1))
txn.commit()
db.sync()
db.close()

dataset_graph = LmdbDataset("lmdb_toy_graph_data.lmdb", is_graph=True)
print(get_fetch_time(dataset_graph, DataLoader))
outut
0.7136490009997942

グラフ化が省略できるのでかなり高速化されました。一方で、今回のケースでは、100構造のデータベース容量は、Atoms保存では432 KBであるのに対しグラフ保存では7.3 MBになってしまいます。ストレージに余裕があるのであればあらかじめグラフ化しておくのが良いでしょう。

cif

最後に、cifファイルのリストから読み込む例を考えます。

データベースの作成

import os

from ase.io import write

os.mkdir("structures")
properties = {}
for i, data in enumerate(toy_datas):
    write(f"structures/{i}.cif", data)
    properties[i] = data.info["data"]

with open("structures/properties.pkl", "wb") as f:
    pickle.dump(properties, f)

pytochのデータセット作成

from pathlib import Path

class CifDataset(Dataset):
    """
    Cif files dataset.
    
    Args:
        directory: Path to the directory which stores cif files.
        convert_fn: Function to convert :class:`.Atoms` to PyG :class:`.Data`.
    """
    def __init__(
        self,
        directory: str,
        convert_fn: Callable[[Atoms], Data] = None,
        is_graph: bool = False
    ):
        if convert_fn is None:
            convert_fn = AtomsToGraphs().convert       
        self.convert_fn = convert_fn
        
        files = sorted([str(p) for p in Path(directory).glob("*.cif")])
        self.files = files
        self.n_datas = len(self.files)
        
        with open(Path(directory) / "properties.pkl", "rb") as f:
            properties = pickle.load(f)
        self.properties = properties
            
    def __len__(self) -> int:
        return self.n_datas

    def __getitem__(self, idx: int) -> Data:
        atoms = read(self.files[idx])
        props = self.properties[idx]
        graph = self.convert_fn(atoms)
        graph = add_info(graph, props)
        return graph

大量データを想定しているのでメモリにAtomsオブジェクトはのせない実装にしています。

速度検証

dataset_cif = CifDataset("structures/")
print(get_fetch_time(dataset_cif, DataLoader))
output
9.648674191999817

cifファイルを毎回読み込むため、かなり速度が遅くなってしまいます。特別な理由がない限り避けた方がいいでしょう。

まとめ

各実装でデータのロードにかかった時間は以下のとおりです。

手法 所要時間
ase (single) 3.530
ase (batched) 2.992
LMDB (Atoms) 2.493
LMDB (graph) 0.714
cif 9.649

基本的にはLMDBなどのkey-value型のデータベースを使えばよさそうです。pickle化できるものは保存でき、使い勝手も良いので実装の自由度も高いと思います。
グラフ化するときのカットオフ半径はハイパーパラメータなので、LMDB(graph)はカットオフ半径ごとにデータベースを作る必要があり、グラフの容量も大きいので大規模データでは厳しい場面が多そうです。高速なのでストレージに余裕があれば積極的に使いたいところです。

今回の実装した中では、ストレージに余裕があればLMDB(graph)、そうでなければLMDB(Atoms)での実装が良いという結果になりました。

6
7
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
6
7

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?