理工系の学生さんは、数値計算によるシミュレーションを回すケースがよくあるかと思います。そこで、Pythonで大掛かりな数値計算をする時に、比較的簡単にスピードを上げることができるjit化と並列化についてパパっと紹介してみようと思います。もちろん、jit化についてはCython
、並列化についてはconcurrent.futures.ProcessPoolExecutor
をはじめ他のオプションもありますが、とりあえず一番導入が簡単(と思われる)、numbaとRayを導入してみます。
TL;DR
numbaでjit化、Rayで並列化しましょう。コードはこちらにあります:https://github.com/Yuricst/qiitaEx/blob/main/speedup_ode.py
諸々インポートと初期設定
まず使用するモジュールをインポートします。インポート後、並列化用のRayモジュールを初期起動するため、Ray.init()
をしています。
import numpy as np
from scipy.integrate import solve_ivp
import time
import ray
from numba import jit
# start ray
ray.init()
ここでscipy.integrate
で実装されているインテグレータodeint
とsolve_ivp
について一言。solve_ivp
の方が新しく、event handlingにも対応している為、今回はこちらを使用しますが、(2021年1月現在)筆者の経験上odeint
の方が10倍ほど早いのでおススメです。なお、この二つの関数は(なぜか)渡す微分方程式関数のインプットの順番がデフォルトでば違うという謎仕様ですが、odeint()
の因数にtfirst=True
とすることで、solve_ivp
と同じ形の微分方程式関数を使用することが出来ます。
https://docs.scipy.org/doc/scipy/reference/generated/scipy.integrate.odeint.html
https://docs.scipy.org/doc/scipy/reference/generated/scipy.integrate.solve_ivp.html
扱う例題
まあこれはなんでもいいんですけど… 自分のよく知ってる分野から問題を作ってみます。ということで、二体問題内において、基準となる軌道の軌道決定精度の誤差を加味し、一周期後の状態の分布を知りたいとします。そこで、基準となる軌道の初期状態と、この初期状態に一定の誤差を加算した初期状態を50,000個を解きたいとします。
力学系がこちら
$$
\ddot{x}=-\dfrac{\mu}{r^3}x ,,, , \quad \ddot{y}=-\dfrac{\mu}{r^3}y ,,, , \quad \ddot{z}=-\dfrac{\mu}{r^3}z
,,,, \mathrm{where} ,,,, r = \sqrt{x^2 + y^2 + z^2}
$$
初期状態には $\mathbf{x}_0 = \left[7000, 200, -4000, 0, 7.8, 0\right]$(位置ベクトルが $\mathbf{r}=\left[ 7000, 200, -4000 \right]$ $\mathrm{km}$、速度ベクトルが $\mathbf{v}=\left[ 0, 7.8, 0 \right]$ $\mathrm{km/s}$)を使用します。これに位置と速度を各方向に正規分布に沿って変化させた初期状態を50,000個生成、リスト化します。上の式の$\mu$は重量定数と天体の質量を掛け合わせたもので、下のコードの gm=398600.44
にあたります。
# define initial conditions
gm = 398600.44
state0 = np.array([7000.0, 200.0, -4000.0, 0.0, 7.8, 0.0])
sigma_r = 3.0 # [km]
sigma_v = 0.1 # [km/sec]
period = 3600.0 # 適当に一時間 [sec]
ics = []
N_ics = 50000
for idx in range(N_ics):
statep = state0 + np.concatenate((np.random.normal(size=(3))*sigma_r, np.random.normal(size=(3))*sigma_v), axis=0)
ics.append( statep )
numbaによるjit化
まず、何回も呼ぶことになる関数はなるべくjit化しましょう。numbaは全てのPythonモジュールに対応しているわけではなく、主にNumPy関係の関数を使用したものの場合のみ、関数をjit化できます。最も、微分方程式の数値計算において、式の右辺関数はシンプルな関数が多いため、jit化できるケースが多いでしょう。今回は速度の比較も兼ねて、jit化しない場合とする場合で二つの関数を用意します。
見てわかる通り、唯一の違いは関数名の前の行に入る@jit
デコレータです。これで、twobody()
が最初に使用される時、この関数はコンパイル(just-in-time compiled)されます。なお、jit化することで得られる速度の差は関数の長さによります。今回使用している力学系はさほど複雑なものではないため、大幅な実行時間の削減は期待できませんが、大きいデータの生成やハンドルをする関数の場合、違いが大きく出てきます。
# jit化しない場合
def twobody_nojit(t, state, gm):
x,y,z = state[0], state[1], state[2]
vx,vy,vz = state[3], state[4], state[5]
r = np.sqrt(x**2 + y**2 + z**2)
ax = -(gm/r**3)*x
ay = -(gm/r**3)*y
az = -(gm/r**3)*z
return np.array([ vx, vy, vz, ax, ay, az ])
# jit化する場合
@jit(nopython=True)
def twobody(t, state, gm):
x,y,z = state[0], state[1], state[2]
vx,vy,vz = state[3], state[4], state[5]
r = np.sqrt(x**2 + y**2 + z**2)
ax = -(gm/r**3)*x
ay = -(gm/r**3)*y
az = -(gm/r**3)*z
return np.array([ vx, vy, vz, ax, ay, az ])
Rayによる並列化
次に、Pythonによる並列化を試みます。並列化に使用できるモジュールはイロイロあり、Rayが特に優れているとは限らないものの、相対的に記述が簡単です。並列化する為に、(1)『軌道の数値計算』 → (2)『一周期後の状態の保存』を、因数として渡されたすべての初期状態について行う関数を書きます。因みに、この関数はscipy
のsolve_ivp
を使用している為、numba
によるjit
化はできません。今回も、並列化する場合としない場合を比較するため、別々の関数を用意します。
まとめると、
- 単一プロセス、jitなし:
compute_trajectory_nojit
- 単一プロセス、jitあり:
compute_trajectory
- マルチプロセス(並列化)、jitあり:
compute_trajectory_parallel
となります。上記オプションの中で、下に行けば行くほど処理速度が短くなります。(『マルチプロセス、jitなし』は省略します。)なお、インテグレータの手法にはLSODAを使用しています。もちろん、インテグレータの選択によっても速度は変わってきますが、今回は話が長くなるので省略します。以下のスクリプトのデコレータ@ray.remote
が機能するためには、前記の通りray.init()
をインポート後に済ませておく必要があるのでお忘れなく。
# define function used with single process
def compute_trajectory_nojit(ics, tf, gm):
results_single = []
for ic in ics:
sol = solve_ivp(twobody_nojit, (0,tf), ic, args=(gm,), method="LSODA")
statef = np.array([sol.y[0][-1], sol.y[1][-1], sol.y[2][-1], sol.y[3][-1], sol.y[4][-1], sol.y[5][-1]])
results_single.append( statef )
return results_single
# define function used with single process
def compute_trajectory(ics, tf, gm):
results_single = []
for ic in ics:
sol = solve_ivp(twobody, (0,tf), ic, args=(gm,), method="LSODA")
statef = np.array([sol.y[0][-1], sol.y[1][-1], sol.y[2][-1], sol.y[3][-1], sol.y[4][-1], sol.y[5][-1]])
results_single.append( statef )
return results_single
# define function to be parallelised
@ray.remote
def compute_trajectory_parallel(ics, tf, gm):
results_mp = []
for ic in ics:
sol = solve_ivp(twobody, (0,tf), ic, args=(gm,), method="LSODA")
statef = np.array([sol.y[0][-1], sol.y[1][-1], sol.y[2][-1], sol.y[3][-1], sol.y[4][-1], sol.y[5][-1]])
results_mp.append( statef )
return results_mp
実行結果
上記の関数を使用し実行してみましょう。
# ========================================================================== #
# single process
print('---------------------------------\nStarting single process...')
tstart_single_nojit = time.time()
results_single = compute_trajectory_nojit(ics, period, gm)
tend_single_nojit = time.time()
dt_single_nojit = tend_single_nojit - tstart_single_nojit
print(f"Single process: {dt_single_nojit:2.4f} sec")
# ========================================================================== #
# single process with jit
print(f"---------------------------------\nStarting single process with jit...")
tstart_single = time.time()
results_single = compute_trajectory(ics, period, gm)
tend_single = time.time()
dt_single = tend_single - tstart_single
print(f"Single process with jit: {dt_single:2.4f} sec")
# ========================================================================== #
# multiple process with jit
print(f"---------------------------------\nStarting multiple process with jit...")
tstart_mp = time.time()
results_parallel = compute_trajectory_parallel.remote(ics, period, gm)
tend_mp = time.time()
dt_mp = tend_mp - tstart_mp
print(f"Multiple process with jit: {dt_mp:2.4f} sec")
実行結果は以下の通りです。上から、『単一プロセス、jitなし』、『単一プロセス、jitあり』、『マルチプロセス、jitあり』
実行環境は物理4コアです。問題のサイズによって時間の相対比はかかりますが、例えば50,000件で初期状態でこの違いが出るので、やはりjit化、並列化は念頭に入れておく価値があるかと思います。
---------------------------------
Starting single process...
Single process: 72.9388 sec
---------------------------------
Starting single process with jit...
Single process with jit: 43.0668 sec
---------------------------------
Starting multiple process with jit...
Multiple process with jit: 3.1591 sec
さいごに
もちろん、ただjit化、並列化さえすればいい、というほど上手くいかないこともあります。問題のサイズや、並列化するにしてもどこまでの作業を単一コアにやらせるかなど、細かくチューニングすべきところは沢山あります。ただ、例えば既にコードを書き始めている状態で時間短縮したい時場合、numbaもRayも共に余分な開発コストが少なく済むオプションではないかと感じております。
参考になれば幸いです!