2
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

機械学習ポテンシャルにstress計算を実装する

Posted at

はじめに

前回の記事(機械学習ポテンシャル実装入門)では機械学習ポテンシャルを実装する方法について紹介しました。ここではさらに踏み込んで、結晶構造の構造緩和をするために必要なstress計算を実装します。

既存の機械学習ポテンシャルパッケージの多くは、 stress計算を実装していませんが、以下の実装のようにすれば、既存パッケージも最小限の変更でstress計算を実装できます。今回は前回と同様、CGCNNにstress計算を実装します。

コード

前回作成したCGCNNを以下のように変更します。他の部分は使い回しです。

class CGCNN(nn.Module):
    def __init__(
        self,
        node_dim: int,
        edge_dim: int,
        cutoff: float = 6.0,
        n_layers: int = 3,
        graph_reduce: str = "mean",
        properties: tuple = ("energy", "forces")
    ):
        super().__init__()
        input_dim = node_dim * 2 + edge_dim
        self.lin_out = nn.Linear(node_dim, 1)
        self.edge_featurizer = GaussianSmearing(0.0, cutoff, edge_dim)
        self.node_embedding = nn.Embedding(118, node_dim)
        self.graph_reduce = graph_reduce
        self._properties = self._check_properties(properties)
        
        layers = []
        for i in range(n_layers):
            layers.append(CGCNNLayer(node_dim, edge_dim, cutoff))
        self.layers = nn.ModuleList(layers)

    def _check_properties(self, props):
        available_types = (str, list, tuple)
        assert isinstance(props, available_types), (
            f"`props` type must be in {available_types}"
        )
        if isinstance(props, str):
            props = (props, )
        else:
            props = tuple(props)
        
        for p in props:
            assert p in ["energy", "forces", "stress"]
        return props
        
    @property
    def properties(self):
        return self._properties
    
    @properties.setter
    def properties(self, props):
        self._properties = self._check_properties(props)
        
    def get_energy(self, batch: Batch):
        node_features = self.node_embedding(batch.atomic_numbers - 1)
        batch.x = node_features
        
        distances = get_distances(batch)
        edge_attr = self.edge_featurizer(distances)
        batch.edge_attr = edge_attr
        
        for layer in self.layers:
            x_updated = layer(batch)
            batch.x = x_updated
        graph_features = scatter(batch.x, batch.batch, dim=0, reduce=self.graph_reduce)
        energy = self.lin_out(graph_features)
        return energy
    
    def forward(self, batch: Batch):
        self.assign_grad_mode(batch)
            
        energy = self.get_energy(batch).view(-1)
        out = {"energy": energy}
        forces, stress = self.get_gradients(energy, batch)
        out.update({"forces": forces, "stress": stress})
        return out
    
    def assign_grad_mode(self, batch: Batch) -> None:
        """
        Change `requires_grad` parameters for force and stress calculations.

        Args:
            batch: :class:`torch_geometric.data.Batch` object.
        """
        with torch.set_grad_enabled(True):
            if "forces" in self.properties:
                batch.pos.requires_grad_()
            if "stress" in self.properties:
                pos = torch.zeros_like(batch.pos)
                cell_deform = torch.zeros_like(batch.cell, requires_grad=True)
                for i in range(batch.cell.size(0)):
                    idx = torch.where(batch.batch==i)[0]
                    pos[idx] = (
                        batch.pos[idx]
                        + torch.matmul(batch.pos[idx], cell_deform[i])
                    )
                batch.pos = pos
                batch.cell = batch.cell + torch.bmm(batch.cell, cell_deform)
                batch.cell_deform = cell_deform
                
    def get_gradients(self, energies: torch.Tensor, batch: Batch) -> tuple:
        """
        Get energy gradients w.r.t. positions or cell deformation.

        Args:
            energies: Energy of structures.
            batch: Batch object.

        Returns:
            tuple: Forces and stress tensors.
        """
        forces, stress = None, None
        wrt = []
        if "forces" in self.properties:
            wrt.append(batch["pos"])
        if "stress" in self.properties:
            wrt.append(batch["cell_deform"])
        
        if wrt:
            cell_grad = None
            grads = torch.autograd.grad(
                energies,
                wrt,
                grad_outputs=torch.ones_like(energies),
                create_graph=self.training
            )
            if len(wrt) == 1:
                if "forces" in self.properties:
                    forces = -1 * grads[0]
                else:
                    cell_grad = grads[0]
            else:
                forces, cell_grad = grads
                forces.mul_(-1)
            
            if cell_grad is not None:
                volume = torch.abs(batch["cell"].det()).detach().view(-1,1,1)
                stress = cell_grad / volume

        return forces, stress

propertiesという属性を追加しています。これはforceやstressの計算が必要な時と、そうでない時で計算を分岐するために使用します。energyだけで良い場合は微分計算を実行せず、計算時間を短縮できます。

また、stress計算をするために、assign_grad_mode, get_gradientsというメソッドを追加しました。

assign_grad_modeでは計算グラフの構築が必要なテンソルに対しrequires_gradTrueにしています。具体的には、force計算時は原子位置、stress計算時はセルの変位です。

get_gradientsではself.propertiesの値に応じてエネルギーの勾配を計算します。forceとstressを同時に計算する場合は、引数にbatch["forces"]batch["cell_deform"]のリストを渡すことが大事です。別々に計算してしまうと、2回分のbackwardが発生し、計算時間が長くなってしまいます。リストを渡すことでほぼ1回分(+α)の時間で計算できます。

なお、stressの計算はエネルギーを原子の変位ベクトルで微分をとることでも計算可能です。私が調べたかぎり、既存の公開されている機械学習ポテンシャルパッケージはforward内で原子の変位ベクトルを計算しているため、変位ベクトルで微分をとる実装だと、stressを計算するためにコード変更が多く必要ですので、セル変位で微分するのが良いと思います。計算時間はほぼ同じです。

計算は以下のように行います。

dataset = ToyDataset(atoms, 10)
loader = DataLoader(dataset, batch_size=4)
batch = next(iter(loader))

cgcnn = CGCNN(node_dim=100, edge_dim=50, properties=("energy", "forces", "stress"))
out = cgcnn(batch) # 4 structure batch
out.keys()
# dict_keys(['energy', 'forces', 'stress'])

out["stress"].size()
# torch.Size([4, 3, 3])

まとめ

今回は機械学習ポテンシャルモデルにstress計算を実装する方法を紹介しました。構造最適化を行うためには、torchの最適化アルゴリズムでforce, stressを最小化しても良いですし、aseのCalculatorとして実装しても良いです。
時間があればその実装方法についても紹介したいと思います。

2
1
2

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?