高速化のためのカスタム関数
①の続きで、高速に学習、推論を行うために、カスタム関数を作ります。
作る部分は、$R_i$→$\hat{R_i}$の部分です。
この部分をpytorchを用いて、forward, backwardを実装します。
前提
$x$ : $x_{ij}$の省略, 原子iから見た原子jの相対x座標
$y$ : $y_{ij}$の省略, 原子iから見た原子jの相対y座標
$z$ : $z_{ij}$の省略, 原子iから見た原子jの相対z座標
$r$ : $r_{ij}$の省略, 原子iと原子jの距離
$\hat{x}$ : $\hat{x_{ij}}$の省略, 原子iから見た原子jの一般化された相対x座標
$y$ : $\hat{y_{ij}}$の省略, 原子iから見た原子jの一般化された相対y座標
$z$ : $\hat{z_{ij}}$の省略, 原子iから見た原子jの一般化された相対z座標
$E$ : total potential energy
$\text{switch}()$ : switch関数は
$u = (r - r_{\text{cutoff_smth}})/(r_{\text{cutoff}} - r_{\text{cutoff_smth}})$とすると、
$r < r_{\text{cutoff_smth}}$のとき
$\text{switch}(u) = 1 $
$r_{\text{cutoff_smth}} \leqq r \leqq r_{\text{cutoff}}$のとき
$\text{switch}(u) = u^3 (-6u^2+15u-10) + 1$
$r_{\text{cutoff}} < r$のとき
$\text{switch}(u) = 0 $
となるような関数です。
$s(r)$ : $s(r_{ij})$の省略, $s(r) = \text{switch}(u) / r_{ij} $
方針
行列全体を考えると頭がバグるので、$R_i$→$\hat{R_i}$の行のみを考えます。
forwardでは、$[x, y, z]$を受け取って、$[s, \text{switch}・x/(r^2), \text{switch}・y/(r^2), \text{switch}・z/(r^2)]$を返せば良いです。
backwardではEを$[s, \hat{x}, \hat{y}, \hat{z}]$で微分したものを受け取って、
Eを$[x, y, z]$で微分したものを返せば良いです。
そのためには、Eを$[s, \hat{x}, \hat{y}, \hat{z}]$で微分した行列に何かの変換行列をかけて、Eを$[x, y, z]$で微分したものを返せば良いです。
変換行列は下のような行列です。この変換行列を手計算していきます。
対称性から、$S$を$x$で偏微分したもの、$\hat{x}$を$x$で偏微分したもの、$\hat{y}$を$x$で偏微分したものの3つ計算すれば良いです。
手計算
実装
これまでで手計算したものを、ひたすら実装します。
class custom_relative_coords_to_generalized_coords(torch.autograd.Function):
"""custom function of relative_coords_to_generalized_coords
"""
@staticmethod
def forward(ctx,
relative_coords:torch.Tensor,
r_cutoff:float,
):
relative_coords = relative_coords.reshape(-1, 3)
generalized_coords = torch.full((relative_coords.shape[0], 4), 1e2)
r_ij_norm = torch.linalg.norm(relative_coords, dim=1)
s_vec = smooth_cut_s_function(
r_ij_norm=r_ij_norm,
r_cutoff=r_cutoff
)
generalized_coords[:,0] = s_vec
generalized_coords[:,1:] = relative_coords * s_vec.view(-1, 1) / r_ij_norm.view(-1, 1)
ctx.save_for_backward(relative_coords, r_ij_norm, torch.Tensor(r_cutoff), s_vec)
return generalized_coords
@staticmethod
def backward(ctx, d_E_d_generalized_coords): # d_E_d_generalized_coords = d_E/d_generalized_coords
relative_coords, r_ij_norm, r_cutoff, s_vec = ctx.saved_tensors
d_generalized_coords_d_relative_coords = torch.zeros((relative_coords.shape[0], 4, 3))
u = (r_ij_norm - 0.1) / (r_cutoff - 0.1)
d_switch_d_u = 3*u*u*(-6*u*u + 15*u - 10) + u*u*u*(-12*u +15)
d_switch_d_u[r_ij_norm > r_cutoff] = 0
d_u_d_r = 1 / (r_cutoff - 0.1)
r_ji_norm_2_inv = (1/r_ij_norm) * (1/r_ij_norm)
g = d_switch_d_u * d_u_d_r * r_ji_norm_2_inv - s_vec * r_ji_norm_2_inv
# d_S/d_x, d_S/d_y, d_S/d_z
d_generalized_coords_d_relative_coords[:, 0, :] = g.view(-1, 1) * relative_coords
h = d_switch_d_u * d_u_d_r * r_ji_norm_2_inv * (1/r_ij_norm) - 2 * s_vec * (1/r_ij_norm) * r_ji_norm_2_inv
d_generalized_coords_d_relative_coords[:, 1:, :] = \
torch.bmm(h.view(-1, 1, 1)*relative_coords[:].reshape(-1, 3, 1), relative_coords.reshape(-1, 1, 3))
s_vec_div_r_ij_norm = s_vec / r_ij_norm
d_generalized_coords_d_relative_coords[:, 1, 0] += s_vec_div_r_ij_norm
d_generalized_coords_d_relative_coords[:, 2, 1] += s_vec_div_r_ij_norm
d_generalized_coords_d_relative_coords[:, 3, 2] += s_vec_div_r_ij_norm
d_E_d_relative_coords = torch.bmm(d_E_d_generalized_coords.view(-1, 1, 4), d_generalized_coords_d_relative_coords)
d_E_d_relative_coords.squeeze_()
return d_E_d_relative_coords, None
確認
一応、custom backwardとpytorchの自動のbackwardで、forceが一致するか確認しますが、ほぼ一致しました。
速度の比較
最後に
カスタム関数を作ってからわかったのですが、カスタム関数を作った場合はtorch scriptに変換できません。
次回予告、全部C++, Libtorchで書き直します。