LoginSignup
6
1

More than 1 year has passed since last update.

ASEでGNNを用いたCalculatorを自作する

Posted at

はじめに

この記事ではASEで自作のCalculatorを実装する方法について解説します。材料分野でよく使われるGNNモデルを用いて結晶構造のエネルギーを算出するCalculatorを例とします。

必要なパッケージ

  • ase
  • torch
  • ocpmodels

コード

いきなりですが、コードは以下になります。

from __future__ import annotations

from ase import Atoms
from ase.calculators.calculator import Calculator, all_changes
from ocpmodels.models import BaseModel
from ocpmodels.preprocessing.atoms_to_graphs import AtomsToGraphs
import torch

ALL_CHANGES = tuple(all_changes)


class NNPCalculator(Calculator):
    """
    Calculator using a neural network potential model.
    
    Args:
        model: OCP model
        cutoff: Cutoff radius to get a graph.
        device: Calculation device
        **kwargs: Keyward arguments passed to :meth:`Calculator.__init__`
    """
    implemented_properties = ("energy", "forces")
    
    def __init__(
        self,
        model: BaseModel,
        cutoff: float = 6.0,
        device: str = "cpu",
        **kwargs
    ):
        super().__init__(**kwargs)
        
        self.model = model
        for n, p in self.model.named_parameters():
            p.requires_grad_(False)
        self.model.eval()
        self.device = device
        self.model.to(self.device)
        self.atg = AtomsToGraphs(radius=cutoff)
    
    def calculate(
        self,
        atoms: Atoms | None = None,
        properties: tuple[str] = ("energy", "forces"),
        system_changes: tuple[str] = ALL_CHANGES
    ) -> None:
        """
        Args:
            atoms: Atoms object
            properties: Properties to be calculated.
            system_changes: Parameters to detect if Atoms have been changed.
        """
        super().calculate(atoms, properties, system_changes)
        
        data = self.atg.convert(atoms)
        data.neighbors = torch.tensor(data.edge_index.size(1))
        data.to(self.device)
        
        if "forces" in properties:
            self.model.regress_forces = True
            with torch.no_grad():
                energy, forces = self.model(data)
            self.results["forces"] = forces.cpu().numpy()
        else:
            self.model.regress_forces = False
            with torch.no_grad():
                energy = self.model(data)
        self.results["energy"] = energy.item()

コードの説明

順にコードを説明します。

__init__

まずは__init__内で、super().__init__()を呼び出します。これはCalculatorを継承したときに必須の操作で、諸々の設定を行ってくれます。
今回はmodelを引数に渡しています。これはocpmodelsで利用可能なモデルを渡すことができます。予測時にモデルのパラメータ更新は不要ですので、各パラメータが計算グラフを作成しないように設定します。また、model.eval()で予測モードにしておきます。

calculate

calculatorのコアの部分です。この部分にenergyやforcesを計算するプロセスを記述します。計算可能なpropertyはimplemented_propertyで設定しておく必要があります。
calculateの引数にはatoms(計算するAtomsオブジェクト), properties(計算する物性), system_changes(Atomsオブジェクトが変更されたか判断するための量)があります。Calculatorはキャッシュをうまく使っており、計算対象のAtomsが前回計算したものと同じである場合、計算はスキップされキャッシュを使用します。この同一判断にsystem_changesが用いられます。
なお、system_changesの引数にはASEの実装ではall_changesというものが使われています。これはリストになっているのですが、関数のデフォルト値としてリストを渡すのは非推奨です(内部で意図せず変更された場合にデフォルト値も変わってしまうため)。上記実装では代わりにtupleを渡しています。
super().calculate()で上記の処理を行います。

キャッシュが使用できない場合、計算に進みます。BaseModelの入力はtorch_geometricのDataBatchオブジェクトである必要があるので、まずはAtomsを変換します。これを計算するデバイスに移動させます。

その後はBaseModelにデータを渡してenergy, forcesを計算します。ocpmodelsの実装では、modelのattributeにregress_forcesというものがあります。これがTrueだと(energy, forces)のtupleが返され、Falseだとenergyのみ返されます。
なお、forcesはenergyをpositionで微分するので、torch.no_grad()では計算グラフが作られず、energyの自動微分が出来ないと思われるかもしれませんが、ocpmodelsでは内部でうまく処理していて、regress_forcesがTrueになっていると、計算グラフを作るモードに自動的に変更されます。

あとは計算結果をself.resultsに保存すれば完了です。numpy配列にする必要があることに注意してください。

Calculatorの使い方

上記の実装で、通常のCalculatorと同様に使用可能です。例えば、以下でエネルギーを取得できます。

from ocpmodels.models import SchNet

model = SchNet(None, None, 1)
calc = NNPCalculator(model)
energy = calc.get_potential_energy(atoms)

まとめ

このようにCalculatorの実装方法を勉強すれば、自作のNNPモデルで各種計算が可能です。例えばstressも実装することで、結晶構造の構造緩和も可能になり、DFT計算よりも圧倒的に高速に実行することができます。
色々応用範囲も広いので、他の例も追々投稿したいと思います。

6
1
6

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
6
1