はじめに
機械学習ポテンシャル(NNP)は最近注目の技術ですが、実装について解説している記事が無く個人的に苦労したので、初学者用の解説記事を作りました。今回はGNNを用いたNNPについて取り上げます。torch
や torch_geometric
、材料系でよく使われるase
などの基本的なライブラリ以外は、解説を加えながらスクラッチで実装していきます。
※ 速度を多少犠牲にして、分かりやすさを重視した実装にしています。特にグラフ化の部分は大いに効率化できますが、別記事で後々解説したいと思います。
※ pytorchの基本的な理解はあるものとしています。pytorchの解説は他に良記事が多くあるのでここでは省略します。
データ型の定義
GNNに流すデータはグラフデータです。これを便利に扱えるtorch_geometric.data.Data
を用いて実装していきます。
ます、今回必要なモジュールをimportしておきます。
from __future__ import annotations
from ase.io import read
from ase.neighborlist import neighbor_list
import torch
import torch.nn as nn
from torch.utils.data import Dataset
from torch_geometric.data import Data, Batch
from torch_geometric.loader import DataLoader
from torch_geometric.nn.models.schnet import GaussianSmearing
from torch_scatter import scatter
次に、ase.Atoms
から、torch_geometric.data.Data
に変換する関数を定義します。ネットワークに流す際に必要な情報を渡しておきます。
def get_pyg_data(
atoms: Atoms, cutoff: float = 6.0, device: str = "cpu"
) -> Data:
"""
Get :class:`torch_geometric.data.Data` class.
Args:
atoms: Atoms object
cutoff: Cutoff radius to gather neighbors
device: Calculation device
"""
idx_i, idx_j, cell_offsets = neighbor_list(
"ijS", atoms, cutoff, self_interaction=False
)
idx_i = torch.tensor(idx_i, device=device)
idx_j = torch.tensor(idx_j, device=device)
cell_offsets = torch.tensor(cell_offsets, device=device)
edge_index = torch.stack([idx_j, idx_i])
data_dict = {
"pos": torch.Tensor(atoms.get_positions(wrap=True), device=device),
"cell": torch.Tensor(atoms.get_cell(), device=device).unsqueeze(0),
"atomic_numbers": torch.tensor(atoms.get_atomic_numbers(), device=device),
"edge_index": edge_index,
"cell_offsets": cell_offsets,
"neighbors": torch.LongTensor([edge_index.size(1)])
}
data = Data(**data_dict)
return data
まず、グラフを作る上で、エッジの情報が必要になります。一般的にあるカットオフ半径以内の原子を集めてくるのですが、結晶構造の場合は周期境界条件を考慮する必要があります。これはase.neighborlist.neighbor_list
で簡単に取得できます。これで中心原子(idx_i)、隣接原子(idx_j)と周期境界のオフセット(cell_offsets)が取得できます。
edge_indexは
[[1,4,3,...], [0,0,0,...]]
のようなサイトのインデックス形式になっており、0-1サイト間、0-4サイト間、0-3サイト間にエッジが存在しているというようなデータの持ち方になっています。
cell_offsetsは
[0,0,1]→c軸方向に1つずれたセルの原子
のような形式になっています。このcell_offsetsは、後ほど周期境界を考慮した原子間距離を取得するのに必要です。
Data
の形式は以下のようになっています。
>>> atoms = read("LiCl_mp-22905_conventional_standard.cif")
>>> data = get_pyg_data(atoms)
>>> data
Data(edge_index=[2, 448], pos=[8, 3], cell=[1, 3, 3], atomic_numbers=[8], cell_offsets=[448, 3], neighbors=[1])
ここで、cell
の次元を増やしたことには意味があります。後ほどこのData
をバッチ化するのですが、[3,3]
のshapeの場合、例えば4つの構造をバッチ化した際に[12,3]
となってしまいます。これだと扱いにくいのであらかじめ[1,3,3]
のshapeにしておきます。こうすることでバッチ化した際に[4,3,3]
のshapeになり、構造数との対応が明確になり扱いやすくなります。
テスト用のデータセットを作って、DataLoaderで取り出してみましょう。
class ToyDataset(Dataset):
def __init__(self, atoms, n, cutoff=6.0):
datas = []
for i in range(n):
atoms_temp = atoms.copy()
atoms_temp.rattle(seed=i)
datas.append(get_pyg_data(atoms_temp, cutoff))
self.datas = datas
def __len__(self):
return len(self.datas)
def __getitem__(self, i):
return self.datas[i]
__getitem__
でData
を返せれば何でもいいです。これをtorch_geometric.loader.DataLoader
でバッチ化すると
>>> dataset = ToyDataset(atoms, 10)
>>> loader = DataLoader(dataset, batch_size=4)
>>> batch_data = next(iter(loader))
>>> batch_data
DataBatch(edge_index=[2, 1792], pos=[32, 3], cell=[4, 3, 3], atomic_numbers=[32], cell_offsets=[1792, 3], neighbors=[4], batch=[32], ptr=[5])
となります。ノードの情報(atomic_numbersなど)は構造ごとの次元がないことに注意してください。
torch_geometricでは、グラフを複数持つという思想では無く、一つの大きなグラフを作っています。このようなバッチ化により、効率的に計算することが可能になっています(後ほど説明します)。どのatomic_numbersがどの構造に対応しているかを確認するにはbatch
を参照します。
>>> batch_data.batch
tensor([0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2,
3, 3, 3, 3, 3, 3, 3, 3])
atomic_numbersについて、0-7番目は構造0、8-15番目は構造1、・・・のように対応しているということです。
GNNに流すデータが定義できたので、次はネットワークを定義します。
ネットワークの定義
今回は簡単のためCGCNNを実装します。GNNは一般的に隣接原子からmessageを受け取り、中心原子の情報をアップデートします。このとき、全ての隣接原子の情報を集計する必要があります。これはtorch_scatter
で便利に実装できます。
まず、CGCNNのmessageを確認しましょう。
\mathbf{x}^{\prime}_i = \mathbf{x}_i + \sum_{j \in \mathcal{N}(i)}
sigmoid \left( \mathbf{z}_{i,j} \mathbf{W}_f + \mathbf{b}_f \right)
\odot softplus \left( \mathbf{z}_{i,j} \mathbf{W}_s + \mathbf{b}_s \right) \tag{1} \\
\mathbf{z}_{i,j} = [ \mathbf{x}_i, \mathbf{x}_j, \mathbf{e}_{i,j} ]
これは、各ノード特徴量 $\mathbf{x}$ とエッジ特徴量 $\mathbf{e}_{i,j}$ を結合し、2種の線形変換を施したのち、それぞれsingmoid, softplusで非線型変換したものの要素積をメッセージとし、それらを合計し、中心原子の元の特徴量に加えるという操作です。角度情報などが明示的に取り込まれていないので精度は低い傾向にありますが、単純なので実装の理解には良いでしょう。
それではネットワークを実装します。まず、Batch
(or DataBatch
)を入力にして、周期境界を考慮した原子間距離を返す関数を定義します。
def get_distances(batch: Batch) -> torch.Tensor:
neighbor, center = batch.edge_index
distance_vectors = batch.pos[neighbor] - batch.pos[center]
cell = torch.repeat_interleave(batch.cell, batch.neighbors, dim=0)
offsets = batch.cell_offsets.float().view(-1, 1, 3).bmm(cell).view(-1, 3)
distance_vectors = distance_vectors + offsets
distances = torch.linalg.norm(distance_vectors, dim=-1)
return distances
>>> distances = get_distances(batch_data)
tensor([5.7609, 5.7609, 5.7609, ..., 4.4624, 5.7609, 5.7609])
このように各エッジの距離が取得できます。cell_offsetsは周期境界の情報を持っており、ここで必要になっています。cellとの行列積をとることでユニットセル間の遷移ベクトルに変換しています。これをオフセットとして足せば周期境界条件下での距離が取得できます。
では、CGCNNを実装していきます。今回はエッジの特徴量として、距離をガウシアンで展開したもの、ノードの特徴量として単純なembeddingを使用します。CGCNNのオリジナル実装ではノード特徴量は元素の情報を明示的に組み込んだものになっていますが、今回は簡単のため単純なembeddingにしました。また、オリジナルの実装では上記の式と少し異なっており、メッセージの合計に対しバッチ正規化を施したのち元のノード特徴量に加え、さらにsoftplusで変換しています。今回の実装でもそのようにしていきます。
class CGCNNLayer(nn.Module):
def __init__(
self,
node_dim: int,
edge_dim: int,
cutoff: float = 6.0,
):
super().__init__()
input_dim = node_dim * 2 + edge_dim
self.lin1 = nn.Linear(input_dim, node_dim)
self.lin2 = nn.Linear(input_dim, node_dim)
self._initialize_params(self.lin1)
self._initialize_params(self.lin2)
self.bn = nn.BatchNorm1d(node_dim)
def _initialize_params(self, module):
torch.nn.init.xavier_uniform_(module.weight)
module.bias.data.fill_(0)
def forward(self, batch):
neighbor, center = batch.edge_index
# z shape: [n_nodes, (node_dim*2 + edge_dim)]
z = torch.cat(
[batch.x[center], batch.x[neighbor], batch.edge_attr],
dim=1
)
z1 = self.lin1(z).sigmoid()
z2 = nn.functional.softplus(self.lin2(z))
total_message = scatter(
z1*z2, index=center, dim=0, reduce="sum"
)
x_updated = nn.functional.softplus(batch.x + self.bn(total_message))
return x_updated
class CGCNN(nn.Module):
def __init__(
self,
node_dim: int,
edge_dim: int,
cutoff: float = 6.0,
n_layers: int = 3,
graph_reduce: str = "mean",
return_forces: bool = True
):
super().__init__()
input_dim = node_dim * 2 + edge_dim
self.lin_out = nn.Linear(node_dim, 1)
self.edge_featurizer = GaussianSmearing(0.0, cutoff, edge_dim)
self.node_embedding = nn.Embedding(118, node_dim)
self.graph_reduce = graph_reduce
self.return_forces = return_forces
layers = []
for i in range(n_layers):
layers.append(CGCNNLayer(node_dim, edge_dim, cutoff))
self.layers = nn.ModuleList(layers)
def get_energy(self, batch: Batch) -> torch.Tensor:
node_features = self.node_embedding(batch.atomic_numbers - 1)
batch.x = node_features
distances = get_distances(batch)
edge_attr = self.edge_featurizer(distances)
batch.edge_attr = edge_attr
for layer in self.layers:
x_updated = layer(batch)
batch.x = x_updated
graph_features = scatter(batch.x, batch.batch, dim=0, reduce=self.graph_reduce)
energy = self.lin_out(graph_features)
return energy
def forward(self, batch: Batch) -> dict[str, torch.Tensor]:
if self.return_forces:
batch.pos.requires_grad_(True)
energy = self.get_energy(batch).view(-1)
out = {"energy": energy}
if self.return_forces:
forces = -1 * (
torch.autograd.grad(
energy,
batch.pos,
grad_outputs=torch.ones_like(energy),
create_graph=True,
)[0]
)
out["forces"] = forces
return out
以下ではコードの詳細を解説します。
CGCNNLayer
こちらではグラフ畳み込みの部分を実装しています。まずedge_indexからcenterとneighborのインデックスを取り出します。このインデックスを使ってそれぞれのノード特徴量(x)を取得し、エッジ特徴量と結合します。この特徴量に対し、2種類の変換を行いz1, z2を得ます。
ここでscatter
が出てきます。GNNの実装において最も重要な部分です。ここではz1*z2
に対しcenterが同じものを足し合わせるという操作になっています。具体的にscatter
の実行結果を簡単な例で見てみましょう。
>>> t = torch.arange(10)
>>> t
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
>>> idx = torch.tensor([0,0,0,1,1,1,2,2,2,3])
>>> scatter(t, idx, reduce="sum")
tensor([ 3, 12, 21, 9])
idxが同じ部分をグループとして、tの値をそれぞれのグループで合計するという操作です。例えば0
のグループであれば,0+1+2=3
, 2
のグループであれば6+7+8=21
という具合です。
CGCNNLayerに戻ると、centerのノードのインデックスをグループとして、メッセージを合計しています。つまり、式1の$\sum_{j \in \mathcal{N}(i)}$を行う処理になります。
あとはこれをバッチ正規化、softplusを通し、元のノード特徴量に加算します。これで1回のグラフ畳み込みが完了します。これを繰り返すことでより遠くの情報も取り入れることができ、ノード特徴量がリッチになっていきます。
CGCNN
こちらでは、node_embedding(原子embeddingを取得)、edge_featurizer(エッジ特徴量を取得)、layers(CGCNNLayer)を定義します。
get_energy
get_energy
内ではノード、エッジ特徴量を取得したのち、ノード特徴量をlayerの数だけ更新します。
最後にscatter
でノード特徴量をグラフレベルの特徴量に変換します。ここでbatch
アトリビュートが必要になっており、torch_geometric.data.Data
であれば自動的に付与されているので便利です。ここではbatch
のグループは各構造に対応するので、グラフレベル(構造ごと)の特徴量を取得できます。今回は各ノード特徴量の平均をとりグラフ特徴量としています。このような仕組みで計算できるため、torch_geometricのBatch
は一つの大きなグラフとして実装されています。大きな行列計算ができるのでforループで順番に計算するよりも高速に計算できます。
最後に、グラフレベルの特徴量を線形変換し、エネルギーとして取得しています。
なお、scatter
の前にLinear層を加えて変換してからreduceしても構いません。どの程度の表現力が必要かで調整すれば良いです。
forward
NNPではエネルギーに加え、原子に働く力も計算したい場合がほとんどです。力はエネルギーを原子位置で微分することで得られるので、forward
内で実装しています。
ここで、batch.pos.requires_grad_(True)
がポイントです。これによりpos
からの計算グラフが作られ、自動微分が可能になります。autograd
でエネルギーに対して微分することで力を得ます。学習時はエネルギーと力のそれぞれのロスを組み合わせたロス関数で最適化することが多いので、create_graph=True
としておき、forces
に対しても、モデルのパラメータで勾配を計算できるようにしておきます。
最後にデータを流して確認しましょう。
>>> cgcnn = CGCNN(node_dim=100, edge_dim=50)
>>> out = cgcnn(batch_data)
>>> print(out["energy"].shape, out["forces"].shape)
torch.Size([4]) torch.Size([32, 3])
エネルギーと力を計算することができました。今回は8原子を持つ構造を4つまとめてバッチにしたので、エネルギーが4つ、力が32個算出されています。
学習では、得られた予測値と実測値に対し、エネルギーと力でそれぞれロス関数を設定し、重みをつけて合計したロスに対しbackward
を呼べば、モデルパラメータの勾配が得られます。
まとめ
今回NNP実装の入門としてCGCNNを実装しました。複雑なNNPモデルでも基本は同じで、message部分やmessageの集計部分を変えることで実装できます。今後は冒頭に述べた、グラフ構築の高速化についての記事を書こうと思っています。