はじめに
最近は結晶構造や分子構造をグラフ構造で表現し、機械学習を行うことが流行っています。特に、DFT計算でのエネルギーなどの出力をグラフニューラルネットワーク(GNN)で学習する、ニューラルネットワークポテンシャル(NNP)が非常に注目されています。ハイスループット計算により大規模なDFTデータセットの作成が可能になりつつあり、学習を行う際には結晶構造をグラフ化するコストも考える必要があります。
そこでこの記事では簡便かつ高速にpytorchのDataLoader
グラフデータを取得する方法を考えます。
必要なライブラリ
- ase
- lmdb
- ocpmodels (グラフ化のために、
AtomsToGraph
のみ使用) - torch
- torch-geometric
実装方法
以下の3通りを試してみます。
- aseのデータベースから読み込み
- LMDBから読み込み
- cifファイルから読み込み
ase
データベースを作成
まずaseのデータベースを用いてAtoms
オブジェクトを保存します。今回はAl構造に適当に変位を与えて、EMTでエネルギー等を計算して保存しておきます。
from ase.calculators.emt import EMT
from ase.db import connect
from ase.io import read
import numpy as np
atoms = read("Al.cif")
atoms *= (2,2,2) # make supercell
base_pos = atoms.get_positions()
calc = EMT()
with connect("ase_toy_data.db") as conn:
for i in range(100):
atoms_rattle = atoms.copy()
atoms_rattle.set_positions(
base_pos + np.random.normal(scale=0.1, size=base_positions.shape)
)
calc.calculate(atoms_rattle)
conn.write(atoms_rattle, data=calc.results)
pytochのデータセット作成
pytorchのDataset
を定義します。
from __future__ import annotations
from collections.abc import Callable, Iterable
from ocpmodels.preprocessing import AtomsToGraphs
import torch
from torch.utils.data import BatchSampler, Dataset, Sampler, SequentialSampler
from torch_geometric.data import Batch, Data
from torch_geometric.loader import DataLoader
class AseDataset(Dataset):
"""
Dataset of ase database.
Args:
db_path: Path to the ase database.
convert_fn: Function to convert :class:`.Atoms` to PyG :class:`.Data`.
"""
def __init__(
self,
db_path: str,
convert_fn: Callable[[Atoms], Data] = None
):
self.db_path = db_path
if convert_fn is None:
convert_fn = AtomsToGraphs().convert
self.convert_fn = convert_fn
def __len__(self) -> int:
with connect(self.db_path) as conn:
n = conn.count()
return n
class AseSingleDataset(AseDataset):
"""
Dataset to get a single graph.
"""
def __getitem__(self, idx: int) -> Data:
with connect(self.db_path) as conn:
atoms = conn.get_atoms(idx+1, add_additional_information=True)
graph = self.convert_fn(atoms)
graph = add_info(graph, atoms.info["data"])
return graph
class AseBatchedDataset(AseDataset):
"""
Dataset to get a batched graph.
"""
def __getitem__(self, idx: Iterable) -> list[Data]:
datas = []
with connect(self.db_path) as conn:
for i in idx:
atoms = conn.get_atoms(i+1, add_additional_information=True)
graph = self.convert_fn(atoms)
graph = add_info(graph, atoms.info["data"])
datas.append(graph)
return datas
def add_info(data: Data, info: dict) -> Data:
for k, v in info.items():
if isinstance(v, (int, float, list, np.ndarray)):
data[k] = torch.tensor(v)
return data
普通の実装ではAseSingleDataset
のようにデータを一つずつ取得しますが、データベースへの接続が過多になるので、AseBatchedDataset
も作成しました。一度の接続で複数のAtoms
を取得します。なお、今回の実装ではaseのデータベースはsqlite3を用いていますが、for文ではなく一度に取得するクエリを書いても速度はむしろ遅くなってしまったので今回は記載していません。
速度検証
まず、速度検証用の関数を定義します。データセット全てのデータ取得を10周します。
from time import perf_counter
def get_fetch_time(
dataset: Dataset,
loader_cls: type[torch.utils.loader.DataLoader],
collate_fn: Callable[[...], Batch] | None = None,
sampler: Sampler | None = None,
batch_size: int = 16,
n_iter: int = 10
) -> float:
"""
Measure the time required for data fetch.
Args:
dataset: Dataset
loader_cls: DataLoader class (torch or torch_geometric)
collate_fn: Callable function to get a batch.
sampler: :class:`.Sampler` class (for :class:`AseBatchedDataset`).
If this is not None (:class:`.BatchSampler` is assumed),
``batch_size`` will be replaced by ``1``.
batch_size: Batch size
n_iter: The number of iterations.
"""
if sampler is not None:
_batch_size = 1
else:
_batch_size = batch_size
loader = loader_cls(
dataset, batch_size=_batch_size, collate_fn=collate_fn,
sampler=sampler, shuffle=False
)
start = perf_counter()
for i in range(n_iter):
for data in loader:
pass
end = perf_counter()
return end - start
以下のように、AseSingleDataset
とAseBatchedDataset
でデータ取得する時間を計測します。
def collate(datas: list[list[Datas]]) -> Batch:
"""
DataLoader with ``batch_size=1`` generates ``datas`` of
(1, batch_size) shape list. Collate function transforms it
to :class:`.Batch` object.
"""
return Batch.from_data_list(datas[0])
dataset_single = AseSingleDataset("ase_toy_data.db")
dataset_batch = AseBatchedDataset("ase_toy_data.db")
sampler = BatchSampler(SequentialSampler(dataset_batch), 16, False)
t1 = get_fetch_time(dataset_single, DataLoader)
t2 = get_fetch_time(
dataset_batch,
torch.utils.data.DataLoader,
collate_fn=collate,
sampler=sampler
)
print(t1, t2)
single: 3.5297205930000928, batch: 2.992043105999983
バッチで取得することで若干速くなりました。
LMDB
データベースの作成
次にLMDBでAtoms
を保存します。
import pickle
import lmdb
with connect("ase_toy_data.db") as conn:
toy_datas = [
conn.get_atoms(i+1, add_additional_information=True)
for i in range(conn.count())
]
db = lmdb.open(
"lmdb_toy_data.lmdb", subdir=False, meminit=False, map_async=True
)
for i, toy_data in enumerate(toy_datas):
txn = db.begin(write=True)
txn.put(
f"{i}".encode("ascii"),
pickle.dumps(toy_data, protocol=-1),
)
txn.commit()
txn = db.begin(write=True)
txn.put("length".encode("ascii"), pickle.dumps(i, protocol=-1))
txn.commit()
db.sync()
db.close()
pytochのデータセット作成
class LmdbDataset(Dataset):
"""
LMDB dataset.
Args:
db_path: Path to the ase database.
convert_fn: Function to convert :class:`.Atoms` to PyG :class:`.Data`.
"""
def __init__(
self,
db_path: str,
convert_fn: Callable[[Atoms], Data] = None,
is_graph: bool = False
):
if convert_fn is None:
convert_fn = AtomsToGraphs().convert
self.convert_fn = convert_fn
self.env = lmdb.open(
db_path, subdir=False, readonly=True, lock=False,
readahead=False, meminit=False, max_readers=1,
)
self._keys = [
f"{j}".encode("ascii")
for j in range(self.env.stat()["entries"] - 1)
]
self.n_datas = len(self._keys)
self.is_graph = is_graph
def close(self) -> None:
"""Close a database."""
self.env.close()
def __len__(self):
return self.n_datas
def __getitem__(self, idx):
data = self.env.begin().get(self._keys[idx])
data = pickle.loads(data)
if not self.is_graph:
graph = self.convert_fn(data)
graph = add_info(graph, data.info["data"])
else:
graph = data
return graph
速度検証
dataset = LmdbDataset("lmdb_toy_data.lmdb")
print(get_fetch_time(dataset, DataLoader))
2.4930231989996173
aseよりも少し速くなりました。単純なkey-value型なので高速に読み込むことが可能であるためと思われます。
なお、LMDBであればグラフ化済みのデータを保存しておくこともできます。
dataset = LmdbDataset("lmdb_toy_data.lmdb")
db = lmdb.open(
"lmdb_toy_graph_data.lmdb", subdir=False, meminit=False, map_async=True
)
for i, graph in enumerate(dataset):
txn = db.begin(write=True)
txn.put(
f"{i}".encode("ascii"),
pickle.dumps(graph, protocol=-1),
)
txn.commit()
txn = db.begin(write=True)
txn.put("length".encode("ascii"), pickle.dumps(i, protocol=-1))
txn.commit()
db.sync()
db.close()
dataset_graph = LmdbDataset("lmdb_toy_graph_data.lmdb", is_graph=True)
print(get_fetch_time(dataset_graph, DataLoader))
0.7136490009997942
グラフ化が省略できるのでかなり高速化されました。一方で、今回のケースでは、100構造のデータベース容量は、Atoms
保存では432 KBであるのに対しグラフ保存では7.3 MBになってしまいます。ストレージに余裕があるのであればあらかじめグラフ化しておくのが良いでしょう。
cif
最後に、cifファイルのリストから読み込む例を考えます。
データベースの作成
import os
from ase.io import write
os.mkdir("structures")
properties = {}
for i, data in enumerate(toy_datas):
write(f"structures/{i}.cif", data)
properties[i] = data.info["data"]
with open("structures/properties.pkl", "wb") as f:
pickle.dump(properties, f)
pytochのデータセット作成
from pathlib import Path
class CifDataset(Dataset):
"""
Cif files dataset.
Args:
directory: Path to the directory which stores cif files.
convert_fn: Function to convert :class:`.Atoms` to PyG :class:`.Data`.
"""
def __init__(
self,
directory: str,
convert_fn: Callable[[Atoms], Data] = None,
is_graph: bool = False
):
if convert_fn is None:
convert_fn = AtomsToGraphs().convert
self.convert_fn = convert_fn
files = sorted([str(p) for p in Path(directory).glob("*.cif")])
self.files = files
self.n_datas = len(self.files)
with open(Path(directory) / "properties.pkl", "rb") as f:
properties = pickle.load(f)
self.properties = properties
def __len__(self) -> int:
return self.n_datas
def __getitem__(self, idx: int) -> Data:
atoms = read(self.files[idx])
props = self.properties[idx]
graph = self.convert_fn(atoms)
graph = add_info(graph, props)
return graph
大量データを想定しているのでメモリにAtoms
オブジェクトはのせない実装にしています。
速度検証
dataset_cif = CifDataset("structures/")
print(get_fetch_time(dataset_cif, DataLoader))
9.648674191999817
cifファイルを毎回読み込むため、かなり速度が遅くなってしまいます。特別な理由がない限り避けた方がいいでしょう。
まとめ
各実装でデータのロードにかかった時間は以下のとおりです。
手法 | 所要時間 |
---|---|
ase (single) | 3.530 |
ase (batched) | 2.992 |
LMDB (Atoms ) |
2.493 |
LMDB (graph) | 0.714 |
cif | 9.649 |
基本的にはLMDBなどのkey-value型のデータベースを使えばよさそうです。pickle化できるものは保存でき、使い勝手も良いので実装の自由度も高いと思います。
グラフ化するときのカットオフ半径はハイパーパラメータなので、LMDB(graph)はカットオフ半径ごとにデータベースを作る必要があり、グラフの容量も大きいので大規模データでは厳しい場面が多そうです。高速なのでストレージに余裕があれば積極的に使いたいところです。
今回の実装した中では、ストレージに余裕があればLMDB(graph)、そうでなければLMDB(Atoms
)での実装が良いという結果になりました。