LoginSignup
0
1

More than 1 year has passed since last update.

ニューラルネットワークポテンシャルの実装 DeepPot-SE (Atom Type Embedding) ②

Posted at

高速化のためのカスタム関数

①の続きで、高速に学習、推論を行うために、カスタム関数を作ります。
作る部分は、$R_i$→$\hat{R_i}$の部分です。
この部分をpytorchを用いて、forward, backwardを実装します。

前提

$x$ : $x_{ij}$の省略, 原子iから見た原子jの相対x座標
$y$ : $y_{ij}$の省略, 原子iから見た原子jの相対y座標
$z$ : $z_{ij}$の省略, 原子iから見た原子jの相対z座標
$r$ : $r_{ij}$の省略, 原子iと原子jの距離
$\hat{x}$ : $\hat{x_{ij}}$の省略, 原子iから見た原子jの一般化された相対x座標
$y$ : $\hat{y_{ij}}$の省略, 原子iから見た原子jの一般化された相対y座標
$z$ : $\hat{z_{ij}}$の省略, 原子iから見た原子jの一般化された相対z座標
$E$ : total potential energy

$\text{switch}()$ : switch関数は
$u = (r - r_{\text{cutoff_smth}})/(r_{\text{cutoff}} - r_{\text{cutoff_smth}})$とすると、
$r < r_{\text{cutoff_smth}}$のとき
$\text{switch}(u) = 1 $
$r_{\text{cutoff_smth}} \leqq r \leqq r_{\text{cutoff}}$のとき
$\text{switch}(u) = u^3 (-6u^2+15u-10) + 1$
$r_{\text{cutoff}} < r$のとき
$\text{switch}(u) = 0 $
となるような関数です。

$s(r)$ : $s(r_{ij})$の省略, $s(r) = \text{switch}(u) / r_{ij} $

方針

行列全体を考えると頭がバグるので、$R_i$→$\hat{R_i}$の行のみを考えます。
forwardでは、$[x, y, z]$を受け取って、$[s, \text{switch}・x/(r^2), \text{switch}・y/(r^2), \text{switch}・z/(r^2)]$を返せば良いです。
image.png

backwardではEを$[s, \hat{x}, \hat{y}, \hat{z}]$で微分したものを受け取って、
Eを$[x, y, z]$で微分したものを返せば良いです。
image.png

そのためには、Eを$[s, \hat{x}, \hat{y}, \hat{z}]$で微分した行列に何かの変換行列をかけて、Eを$[x, y, z]$で微分したものを返せば良いです。
image.png

変換行列は下のような行列です。この変換行列を手計算していきます。
対称性から、$S$を$x$で偏微分したもの、$\hat{x}$を$x$で偏微分したもの、$\hat{y}$を$x$で偏微分したものの3つ計算すれば良いです。
image.png

手計算

手計算で偏微分していくと、以下のようになります。
image.png

image.png

変換行列の、S以外の部分はきれいにまとめられます。
image.png

実装

これまでで手計算したものを、ひたすら実装します。


class custom_relative_coords_to_generalized_coords(torch.autograd.Function):
    """custom function of relative_coords_to_generalized_coords
    """
    @staticmethod
    def forward(ctx,
            relative_coords:torch.Tensor, 
            r_cutoff:float,
    ):
        relative_coords = relative_coords.reshape(-1, 3)
        generalized_coords = torch.full((relative_coords.shape[0], 4), 1e2)
        r_ij_norm = torch.linalg.norm(relative_coords, dim=1)
        s_vec = smooth_cut_s_function(
            r_ij_norm=r_ij_norm,
            r_cutoff=r_cutoff
        )
        generalized_coords[:,0] = s_vec
        generalized_coords[:,1:] = relative_coords * s_vec.view(-1, 1) / r_ij_norm.view(-1, 1)
        
        ctx.save_for_backward(relative_coords, r_ij_norm, torch.Tensor(r_cutoff), s_vec)

        return generalized_coords

    @staticmethod
    def backward(ctx, d_E_d_generalized_coords):  # d_E_d_generalized_coords = d_E/d_generalized_coords
        relative_coords, r_ij_norm, r_cutoff, s_vec = ctx.saved_tensors

        d_generalized_coords_d_relative_coords = torch.zeros((relative_coords.shape[0], 4, 3))

        u = (r_ij_norm - 0.1) / (r_cutoff - 0.1)
        d_switch_d_u = 3*u*u*(-6*u*u + 15*u - 10) + u*u*u*(-12*u +15)
        d_switch_d_u[r_ij_norm > r_cutoff] = 0
        d_u_d_r = 1 / (r_cutoff - 0.1)

        r_ji_norm_2_inv = (1/r_ij_norm) * (1/r_ij_norm)
        g = d_switch_d_u * d_u_d_r * r_ji_norm_2_inv - s_vec * r_ji_norm_2_inv
        # d_S/d_x, d_S/d_y, d_S/d_z
        d_generalized_coords_d_relative_coords[:, 0, :] = g.view(-1, 1) * relative_coords

        h = d_switch_d_u * d_u_d_r * r_ji_norm_2_inv * (1/r_ij_norm) - 2 * s_vec * (1/r_ij_norm) * r_ji_norm_2_inv
        
        d_generalized_coords_d_relative_coords[:, 1:, :] = \
            torch.bmm(h.view(-1, 1, 1)*relative_coords[:].reshape(-1, 3, 1), relative_coords.reshape(-1, 1, 3)) 
        
        s_vec_div_r_ij_norm = s_vec / r_ij_norm

        d_generalized_coords_d_relative_coords[:, 1, 0] += s_vec_div_r_ij_norm
        d_generalized_coords_d_relative_coords[:, 2, 1] += s_vec_div_r_ij_norm
        d_generalized_coords_d_relative_coords[:, 3, 2] += s_vec_div_r_ij_norm

        d_E_d_relative_coords = torch.bmm(d_E_d_generalized_coords.view(-1, 1, 4), d_generalized_coords_d_relative_coords)
        d_E_d_relative_coords.squeeze_()
        return d_E_d_relative_coords, None

確認

一応、custom backwardとpytorchの自動のbackwardで、forceが一致するか確認しますが、ほぼ一致しました。

速度の比較

最後に

カスタム関数を作ってからわかったのですが、カスタム関数を作った場合はtorch scriptに変換できません。
次回予告、全部C++, Libtorchで書き直します。

0
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
1