LoginSignup
30
29

More than 1 year has passed since last update.

機械学習ポテンシャル実装入門

Posted at

はじめに

機械学習ポテンシャル(NNP)は最近注目の技術ですが、実装について解説している記事が無く個人的に苦労したので、初学者用の解説記事を作りました。今回はGNNを用いたNNPについて取り上げます。torchや torch_geometric、材料系でよく使われるaseなどの基本的なライブラリ以外は、解説を加えながらスクラッチで実装していきます。

※ 速度を多少犠牲にして、分かりやすさを重視した実装にしています。特にグラフ化の部分は大いに効率化できますが、別記事で後々解説したいと思います。

※ pytorchの基本的な理解はあるものとしています。pytorchの解説は他に良記事が多くあるのでここでは省略します。

データ型の定義

GNNに流すデータはグラフデータです。これを便利に扱えるtorch_geometric.data.Dataを用いて実装していきます。
ます、今回必要なモジュールをimportしておきます。

from __future__ import annotations

from ase.io import read
from ase.neighborlist import neighbor_list
import torch
import torch.nn as nn
from torch.utils.data import Dataset
from torch_geometric.data import Data, Batch
from torch_geometric.loader import DataLoader
from torch_geometric.nn.models.schnet import GaussianSmearing
from torch_scatter import scatter

次に、ase.Atomsから、torch_geometric.data.Dataに変換する関数を定義します。ネットワークに流す際に必要な情報を渡しておきます。

def get_pyg_data(
    atoms: Atoms, cutoff: float = 6.0, device: str = "cpu"
) -> Data:
    """
    Get :class:`torch_geometric.data.Data` class.
    
    Args:
        atoms: Atoms object
        cutoff: Cutoff radius to gather neighbors
        device: Calculation device
    """
    idx_i, idx_j, cell_offsets = neighbor_list(
        "ijS", atoms, cutoff, self_interaction=False
    )
    idx_i = torch.tensor(idx_i, device=device)
    idx_j = torch.tensor(idx_j, device=device)
    cell_offsets = torch.tensor(cell_offsets, device=device)
    
    edge_index = torch.stack([idx_j, idx_i])
    data_dict = {
        "pos": torch.Tensor(atoms.get_positions(wrap=True), device=device),
        "cell": torch.Tensor(atoms.get_cell(), device=device).unsqueeze(0),
        "atomic_numbers": torch.tensor(atoms.get_atomic_numbers(), device=device),
        "edge_index": edge_index,
        "cell_offsets": cell_offsets,
        "neighbors": torch.LongTensor([edge_index.size(1)])
    }
    data = Data(**data_dict)
    return data

まず、グラフを作る上で、エッジの情報が必要になります。一般的にあるカットオフ半径以内の原子を集めてくるのですが、結晶構造の場合は周期境界条件を考慮する必要があります。これはase.neighborlist.neighbor_listで簡単に取得できます。これで中心原子(idx_i)、隣接原子(idx_j)と周期境界のオフセット(cell_offsets)が取得できます。

edge_indexは
[[1,4,3,...], [0,0,0,...]]
のようなサイトのインデックス形式になっており、0-1サイト間、0-4サイト間、0-3サイト間にエッジが存在しているというようなデータの持ち方になっています。

cell_offsetsは
[0,0,1]→c軸方向に1つずれたセルの原子
のような形式になっています。このcell_offsetsは、後ほど周期境界を考慮した原子間距離を取得するのに必要です。

Dataの形式は以下のようになっています。

>>> atoms = read("LiCl_mp-22905_conventional_standard.cif")
>>> data = get_pyg_data(atoms)
>>> data
Data(edge_index=[2, 448], pos=[8, 3], cell=[1, 3, 3], atomic_numbers=[8], cell_offsets=[448, 3], neighbors=[1])

ここで、cellの次元を増やしたことには意味があります。後ほどこのDataをバッチ化するのですが、[3,3]のshapeの場合、例えば4つの構造をバッチ化した際に[12,3]となってしまいます。これだと扱いにくいのであらかじめ[1,3,3]のshapeにしておきます。こうすることでバッチ化した際に[4,3,3]のshapeになり、構造数との対応が明確になり扱いやすくなります。

テスト用のデータセットを作って、DataLoaderで取り出してみましょう。

class ToyDataset(Dataset):
    def __init__(self, atoms, n, cutoff=6.0):
        datas = []
        for i in range(n):
            atoms_temp = atoms.copy()
            atoms_temp.rattle(seed=i)
            datas.append(get_pyg_data(atoms_temp, cutoff))
        self.datas = datas
        
    def __len__(self):
        return len(self.datas)
    
    def __getitem__(self, i):
        return self.datas[i]

__getitem__Dataを返せれば何でもいいです。これをtorch_geometric.loader.DataLoaderでバッチ化すると

>>> dataset = ToyDataset(atoms, 10)
>>> loader = DataLoader(dataset, batch_size=4)
>>> batch_data = next(iter(loader))
>>> batch_data
DataBatch(edge_index=[2, 1792], pos=[32, 3], cell=[4, 3, 3], atomic_numbers=[32], cell_offsets=[1792, 3], neighbors=[4], batch=[32], ptr=[5])

となります。ノードの情報(atomic_numbersなど)は構造ごとの次元がないことに注意してください。
torch_geometricでは、グラフを複数持つという思想では無く、一つの大きなグラフを作っています。このようなバッチ化により、効率的に計算することが可能になっています(後ほど説明します)。どのatomic_numbersがどの構造に対応しているかを確認するにはbatchを参照します。

>>> batch_data.batch
tensor([0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2,
        3, 3, 3, 3, 3, 3, 3, 3])

atomic_numbersについて、0-7番目は構造0、8-15番目は構造1、・・・のように対応しているということです。
GNNに流すデータが定義できたので、次はネットワークを定義します。

ネットワークの定義

今回は簡単のためCGCNNを実装します。GNNは一般的に隣接原子からmessageを受け取り、中心原子の情報をアップデートします。このとき、全ての隣接原子の情報を集計する必要があります。これはtorch_scatterで便利に実装できます。
まず、CGCNNのmessageを確認しましょう。

\mathbf{x}^{\prime}_i = \mathbf{x}_i + \sum_{j \in \mathcal{N}(i)}
sigmoid \left( \mathbf{z}_{i,j} \mathbf{W}_f + \mathbf{b}_f \right)
\odot softplus \left( \mathbf{z}_{i,j} \mathbf{W}_s + \mathbf{b}_s  \right) \tag{1} \\
\mathbf{z}_{i,j} = [ \mathbf{x}_i, \mathbf{x}_j, \mathbf{e}_{i,j} ]

これは、各ノード特徴量 $\mathbf{x}$ とエッジ特徴量 $\mathbf{e}_{i,j}$ を結合し、2種の線形変換を施したのち、それぞれsingmoid, softplusで非線型変換したものの要素積をメッセージとし、それらを合計し、中心原子の元の特徴量に加えるという操作です。角度情報などが明示的に取り込まれていないので精度は低い傾向にありますが、単純なので実装の理解には良いでしょう。
それではネットワークを実装します。まず、Batch (or DataBatch)を入力にして、周期境界を考慮した原子間距離を返す関数を定義します。

def get_distances(batch: Batch) -> torch.Tensor:
    neighbor, center = batch.edge_index
    distance_vectors = batch.pos[neighbor] - batch.pos[center]
    cell = torch.repeat_interleave(batch.cell, batch.neighbors, dim=0)
    offsets = batch.cell_offsets.float().view(-1, 1, 3).bmm(cell).view(-1, 3)
    distance_vectors = distance_vectors + offsets
    distances = torch.linalg.norm(distance_vectors, dim=-1)
    return distances
>>> distances = get_distances(batch_data)
tensor([5.7609, 5.7609, 5.7609,  ..., 4.4624, 5.7609, 5.7609])

このように各エッジの距離が取得できます。cell_offsetsは周期境界の情報を持っており、ここで必要になっています。cellとの行列積をとることでユニットセル間の遷移ベクトルに変換しています。これをオフセットとして足せば周期境界条件下での距離が取得できます。

では、CGCNNを実装していきます。今回はエッジの特徴量として、距離をガウシアンで展開したもの、ノードの特徴量として単純なembeddingを使用します。CGCNNのオリジナル実装ではノード特徴量は元素の情報を明示的に組み込んだものになっていますが、今回は簡単のため単純なembeddingにしました。また、オリジナルの実装では上記の式と少し異なっており、メッセージの合計に対しバッチ正規化を施したのち元のノード特徴量に加え、さらにsoftplusで変換しています。今回の実装でもそのようにしていきます。

class CGCNNLayer(nn.Module):
    def __init__(
        self,
        node_dim: int,
        edge_dim: int,
        cutoff: float = 6.0,
    ):
        super().__init__()
        input_dim = node_dim * 2 + edge_dim
        self.lin1 = nn.Linear(input_dim, node_dim)
        self.lin2 = nn.Linear(input_dim, node_dim)    
        self._initialize_params(self.lin1)
        self._initialize_params(self.lin2)
        self.bn = nn.BatchNorm1d(node_dim)

    def _initialize_params(self, module):
        torch.nn.init.xavier_uniform_(module.weight)
        module.bias.data.fill_(0)
        
    def forward(self, batch):
        neighbor, center = batch.edge_index
        
        # z shape: [n_nodes, (node_dim*2 + edge_dim)]
        z = torch.cat(
            [batch.x[center], batch.x[neighbor], batch.edge_attr],
            dim=1
        )
        z1 = self.lin1(z).sigmoid()
        z2 = nn.functional.softplus(self.lin2(z))
        total_message = scatter(
            z1*z2, index=center, dim=0, reduce="sum"
        )
        x_updated = nn.functional.softplus(batch.x + self.bn(total_message))
        return x_updated


class CGCNN(nn.Module):
    def __init__(
        self,
        node_dim: int,
        edge_dim: int,
        cutoff: float = 6.0,
        n_layers: int = 3,
        graph_reduce: str = "mean",
        return_forces: bool = True
    ):
        super().__init__()
        input_dim = node_dim * 2 + edge_dim
        self.lin_out = nn.Linear(node_dim, 1)
        self.edge_featurizer = GaussianSmearing(0.0, cutoff, edge_dim)
        self.node_embedding = nn.Embedding(118, node_dim)
        self.graph_reduce = graph_reduce
        self.return_forces = return_forces
        
        layers = []
        for i in range(n_layers):
            layers.append(CGCNNLayer(node_dim, edge_dim, cutoff))
        self.layers = nn.ModuleList(layers)

    def get_energy(self, batch: Batch) -> torch.Tensor:
        node_features = self.node_embedding(batch.atomic_numbers - 1)
        batch.x = node_features
        
        distances = get_distances(batch)
        edge_attr = self.edge_featurizer(distances)
        batch.edge_attr = edge_attr
        
        for layer in self.layers:
            x_updated = layer(batch)
            batch.x = x_updated
        graph_features = scatter(batch.x, batch.batch, dim=0, reduce=self.graph_reduce)
        energy = self.lin_out(graph_features)
        return energy
    
    def forward(self, batch: Batch) -> dict[str, torch.Tensor]:
        if self.return_forces:
            batch.pos.requires_grad_(True)
            
        energy = self.get_energy(batch).view(-1)
        out = {"energy": energy}
        if self.return_forces:
            forces = -1 * (
                torch.autograd.grad(
                    energy,
                    batch.pos,
                    grad_outputs=torch.ones_like(energy),
                    create_graph=True,
                )[0]
            )
            out["forces"] = forces
        return out

以下ではコードの詳細を解説します。

CGCNNLayer

こちらではグラフ畳み込みの部分を実装しています。まずedge_indexからcenterとneighborのインデックスを取り出します。このインデックスを使ってそれぞれのノード特徴量(x)を取得し、エッジ特徴量と結合します。この特徴量に対し、2種類の変換を行いz1, z2を得ます。
ここでscatterが出てきます。GNNの実装において最も重要な部分です。ここではz1*z2に対しcenterが同じものを足し合わせるという操作になっています。具体的にscatterの実行結果を簡単な例で見てみましょう。

>>> t = torch.arange(10)
>>> t
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
>>> idx = torch.tensor([0,0,0,1,1,1,2,2,2,3])
>>> scatter(t, idx, reduce="sum")
tensor([ 3, 12, 21,  9])

idxが同じ部分をグループとして、tの値をそれぞれのグループで合計するという操作です。例えば0のグループであれば,0+1+2=3, 2のグループであれば6+7+8=21という具合です。

CGCNNLayerに戻ると、centerのノードのインデックスをグループとして、メッセージを合計しています。つまり、式1の$\sum_{j \in \mathcal{N}(i)}$を行う処理になります。
あとはこれをバッチ正規化、softplusを通し、元のノード特徴量に加算します。これで1回のグラフ畳み込みが完了します。これを繰り返すことでより遠くの情報も取り入れることができ、ノード特徴量がリッチになっていきます。

CGCNN

こちらでは、node_embedding(原子embeddingを取得)、edge_featurizer(エッジ特徴量を取得)、layers(CGCNNLayer)を定義します。

get_energy

get_energy内ではノード、エッジ特徴量を取得したのち、ノード特徴量をlayerの数だけ更新します。
最後にscatterでノード特徴量をグラフレベルの特徴量に変換します。ここでbatchアトリビュートが必要になっており、torch_geometric.data.Dataであれば自動的に付与されているので便利です。ここではbatchのグループは各構造に対応するので、グラフレベル(構造ごと)の特徴量を取得できます。今回は各ノード特徴量の平均をとりグラフ特徴量としています。このような仕組みで計算できるため、torch_geometricのBatchは一つの大きなグラフとして実装されています。大きな行列計算ができるのでforループで順番に計算するよりも高速に計算できます。
最後に、グラフレベルの特徴量を線形変換し、エネルギーとして取得しています。
なお、scatterの前にLinear層を加えて変換してからreduceしても構いません。どの程度の表現力が必要かで調整すれば良いです。

forward

NNPではエネルギーに加え、原子に働く力も計算したい場合がほとんどです。力はエネルギーを原子位置で微分することで得られるので、forward内で実装しています。
ここで、batch.pos.requires_grad_(True)がポイントです。これによりposからの計算グラフが作られ、自動微分が可能になります。autogradでエネルギーに対して微分することで力を得ます。学習時はエネルギーと力のそれぞれのロスを組み合わせたロス関数で最適化することが多いので、create_graph=Trueとしておき、forcesに対しても、モデルのパラメータで勾配を計算できるようにしておきます。

最後にデータを流して確認しましょう。

>>> cgcnn = CGCNN(node_dim=100, edge_dim=50)
>>> out = cgcnn(batch_data)
>>> print(out["energy"].shape, out["forces"].shape)
torch.Size([4]) torch.Size([32, 3])

エネルギーと力を計算することができました。今回は8原子を持つ構造を4つまとめてバッチにしたので、エネルギーが4つ、力が32個算出されています。
学習では、得られた予測値と実測値に対し、エネルギーと力でそれぞれロス関数を設定し、重みをつけて合計したロスに対しbackwardを呼べば、モデルパラメータの勾配が得られます。

まとめ

今回NNP実装の入門としてCGCNNを実装しました。複雑なNNPモデルでも基本は同じで、message部分やmessageの集計部分を変えることで実装できます。今後は冒頭に述べた、グラフ構築の高速化についての記事を書こうと思っています。

30
29
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
30
29