このノートの目的
目標は、分子動力学法(MD)で必要となる力の計算(=エネルギーの原子座標に対する偏微分より定まる)のコーディング。
原子座標に対するエネルギーの偏微分を手計算ですべて求める作業、またその式を正しくコードに落とし込む作業は、ポテンシャル関数が少し複雑になると(Tersoffなどの3体間ポテンシャルで既に)非常に煩雑になる。
そこで、逆誤差伝播法を活用した、比較的簡単な偏微分の求め方と、コーディングの際の注意点を学ぶことを目的として、簡単な関数からステップバイステップで進めていくこととする。
まずはpythonでの実装を試みる。
簡単な関数を用いたテスト
e=e(\boldsymbol{h})=|\boldsymbol{h}|
\boldsymbol{h}=\boldsymbol{h(x)}=\frac{\boldsymbol{x}}{|\boldsymbol{x}|^2}
のように、$\boldsymbol{x}$(ベクトル)から$e$(スカラー)が
\boldsymbol{x} \to \boldsymbol{h} \to e
のように(2段階で)計算される場合を考える。$\boldsymbol{h}$は$\boldsymbol{x}$と同じ次元を持つベクトルである。
この順序で関数の値を求めていくことをforward処理と呼び、逆向きに微分を求めていくことをbackward処理と呼ぶことにする。
※あくまでコーディングのテストなので、関数の形そのものに特に意味はない。
微分の式
各変数による$e$の偏微分を$^*$を付けて表すことにする。すなわち
h_l^* := \frac{\partial e}{\partial h_l}
x_i^* := \frac{\partial e}{\partial x_i} = \sum_l \frac{\partial e}{\partial h_l}\cdot\frac{\partial h_l}{\partial x_i} = \sum_l h_l^*\frac{\partial h_l}{\partial x_i}
なお下付添字はベクトルの成分を表し、例えば$x_i$は$\boldsymbol{x}$の$i$番目の成分である。
一応出来たコード
forward処理で$\boldsymbol{x} \to \boldsymbol{h} \to e$と求め、backward処理で$1\to\boldsymbol{h}^*\to\boldsymbol{x}^*$と微分を求めている。最後に検算として、数値微分との比較をしている。
サンプルコード
'''
Testing back-propagation for derivative calculation
02: sample function e = e(h), h = h(x)
x: vector, h: vector, e: scalar
h(x) := x / x^2
e(h) := |x|
'''
import numpy as np
class Bead(object):
def __init__(self, nins, nout):
self.nins = nins
self.nout = nout
self.ar_out = np.zeros(nout)
self.ar_gout = np.zeros(nins)
self.ar_in = np.zeros(nins)
self.ar_gin = np.zeros(nout)
def resetg(self):
self.ar_gout[:] = 0
class BeadH(Bead):
def forward(self, ar_in):
xnor = np.linalg.norm(ar_in)
for i in range(len(ar_in)):
self.ar_in[i] = ar_in[i]
self.ar_out[i] = ar_in[i] / xnor / xnor
return self.ar_out
def backward(self, ar_gin):
xnor = np.linalg.norm(self.ar_in)
xnor2 = xnor * xnor
xnor4 = xnor2 * xnor2
for i in range(len(self.ar_in)):
dhdx = np.zeros(len(self.ar_in))
for l in range(len(self.ar_in)):
dhdx[l] = - 2.0*self.ar_in[i]*self.ar_in[l] / xnor4
if l == i : dhdx[l] += 1.0 / xnor2
self.ar_gout[i] += np.dot(ar_gin, dhdx)
return self.ar_gout
class BeadE(Bead):
def forward(self, ar_in):
xnor = np.linalg.norm(ar_in)
self.ar_in = ar_in
self.ar_out = xnor
return self.ar_out
def backward(self, ar_gin):
xnor = np.linalg.norm(self.ar_in)
for i in range(len(self.ar_in)):
self.ar_gout[i] = ar_gin*self.ar_in[i] / xnor
return self.ar_gout
x = [1.0, 2.0, 1.5] # as an example
print('x = ', x)
h = BeadH(len(x),len(x))
e = BeadE(len(x),1)
h.forward(x)
e.forward(h.ar_out)
e.backward(1.0)
h.backward(e.ar_gout)
dedx = h.ar_gout.copy()
print('h.ar_out = ', h.ar_out)
print('e.ar_out = ', e.ar_out)
print('e.ar_gout = ', e.ar_gout)
print('dedx (h.ar_gout) = ', dedx)
# numerical derivative
x_org = np.zeros(len(x))
dedx_num = np.zeros(len(x))
x_org = x.copy()
dx = 0.001
for i in range(len(x)):
x[i] = x_org[i] + dx
h.forward(x)
e.forward(h.ar_out)
ep = e.ar_out
x[i] = x_org[i] - dx
h.forward(x)
e.forward(h.ar_out)
em = e.ar_out
x[i] = x_org[i]
dedx_num[i] = (ep-em)/dx/2.0
print('numerical derivative:', dedx_num)
コードの解説
親クラス。コンストラクタと、リセット機能を書いておく。
class Bead(object):
def __init__(self, nins, nout):
self.nins = nins
self.nout = nout
self.ar_out = np.zeros(nout)
self.ar_gout = np.zeros(nins)
self.ar_in = np.zeros(nins)
self.ar_gin = np.zeros(nout)
def resetg(self):
self.ar_gout[:] = 0
ninsが入力変数の要素数(ベクトルの次元)、noutが出力側のそれ。
ar_in, ar_outはforward処理の入力(変数)と出力(関数の値)。
ar_gin, ar_goutはbackward処理の入力と出力。
リセット機能はar_goutをゼロ化する。
関数の実装(1)
関数$\boldsymbol{h(x)}$のforward, backward処理を実装する。
Beadクラスを継承したBeadHクラスを作る。
class BeadH(Bead):
def forward(self, ar_in):
xnor = np.linalg.norm(ar_in)
for i in range(len(ar_in)):
self.ar_in[i] = ar_in[i]
self.ar_out[i] = ar_in[i] / xnor / xnor
return self.ar_out
def backward(self, ar_gin):
xnor = np.linalg.norm(self.ar_in)
xnor2 = xnor * xnor
xnor4 = xnor2 * xnor2
for i in range(len(self.ar_in)):
dhdx = np.zeros(len(self.ar_in))
for l in range(len(self.ar_in)):
dhdx[l] = - 2.0*self.ar_in[i]*self.ar_in[l] / xnor4
if l == i : dhdx[l] += 1.0 / xnor2
self.ar_gout[i] += np.dot(ar_gin, dhdx)
return self.ar_gout
forward関数を定義し、
h_i(\boldsymbol{x})=\frac{x_i}{|\boldsymbol{x}|^2}
を求めていることがわかる。
backward関数では、
\frac{\partial\boldsymbol{h(x)}}{\partial x_i} =
\frac{\partial\boldsymbol{x}}{\partial x_i}|\boldsymbol{x}|^{-2}
-\frac{\boldsymbol{x}}{|\boldsymbol{x}|^4}\frac{\partial|\boldsymbol{x}|^2}{\partial x_i}
を求めて、ar_ginとして入力された $h_l^*$ と、上式で得た$\frac{\partial h_l}{\partial x_i}$の内積($\sum_l h_l^* \frac{\partial h_l}{\partial x_i}$)をとることで$x_i^*$を計算し、これをar_goutとして出力する。
なお蛇足であるが、上式において
\frac{\partial\boldsymbol{x}}{\partial x_i}=\boldsymbol{e}_i
\frac{\partial|\boldsymbol{x}|^2}{\partial x_i}=\frac{\partial}{\partial x_i}\sum_j x_j^2=2x_i
である。
関数の実装(2)
関数$e(\boldsymbol{h})$のforward, backward処理を実装する。
Beadクラスを継承したBeadEクラスを作る。
class BeadE(Bead):
def forward(self, ar_in):
xnor = np.linalg.norm(ar_in)
self.ar_in = ar_in
self.ar_out = xnor
return self.ar_out
def backward(self, ar_gin):
xnor = np.linalg.norm(self.ar_in)
for i in range(len(self.ar_in)):
self.ar_gout[i] = ar_gin*self.ar_in[i] / xnor
return self.ar_gout
forward関数を定義し、
e(\boldsymbol{h})=|\boldsymbol{h}|
を求めていることがわかる。
backward関数では、
\frac{\partial e}{\partial h_i}=\frac{\partial}{\partial h_i}\left(\sum_j h_j^2\right)^{1/2}=\frac{h_i}{|\boldsymbol{h}|} \quad(=\boldsymbol{h}^*)
を求め、これをar_outとして出力する。なお、backward処理はここから始まるため、ar_inとしては1を与えている。
実装した関数を用いたforward, backward処理
x = [1.0, 2.0, 1.5] # as an example
print('x = ', x)
h = BeadH(len(x),len(x))
e = BeadE(len(x),1)
h.forward(x)
e.forward(h.ar_out)
e.backward(1.0)
h.backward(e.ar_gout)
dedx = h.ar_gout.copy()
print('h.ar_out = ', h.ar_out)
print('e.ar_out = ', e.ar_out)
print('e.ar_gout = ', e.ar_gout)
print('dedx (h.ar_gout) = ', dedx)
$\boldsymbol{x}$として適当なベクトルを与え、BeadH, BeadEをインスタンス化したのち、forward処理で$\boldsymbol{x}\to\boldsymbol{h}\to e$と求め、その後backward処理で$1\to\boldsymbol{h}^*\to\boldsymbol{x}^*$と求める。
数値微分で検算
# numerical derivative
x_org = np.zeros(len(x))
dedx_num = np.zeros(len(x))
x_org = x.copy()
dx = 0.001
for i in range(len(x)):
x[i] = x_org[i] + dx
h.forward(x)
e.forward(h.ar_out)
ep = e.ar_out
x[i] = x_org[i] - dx
h.forward(x)
e.forward(h.ar_out)
em = e.ar_out
x[i] = x_org[i]
dedx_num[i] = (ep-em)/dx/2.0
print('numerical derivative:', dedx_num)
最後に、正しく微分が求められているかどうかを確認するため、数値微分したものと比較する。
テスト結果
数値微分の結果と概ね一致していることが確認できる。