計算法というよりはプログラムの構成に関してです.
コードはすべてv1.9.2のソースからの引用(現時点ではLatest)です.適宜省略しつつ引用します.
実装されている機能
solve_ivp
には,
- 時間ステップ幅の調整
- 精度の保証
- 解の値の判定,指定した値の追跡
などの機能があります.上記の処理が入っているためソースコードはそれなりに煩雑です.ここでは処理の詳細には触れずに,プログラム全体としてどのような構成になっているのか,解法のメインである時間ループの計算がどこで・どのような順で実行されるのか,という点について調べたことを書きます.RK23を例として見てみます.
Signature
def solve_ivp(fun, t_span, y0, method='RK45', t_eval=None, dense_output=False,
events=None, vectorized=False, args=None, **options):
solve_ivpがやっていること
solve_ivp
の直下の記述はそれほど多くありません.実際の処理を行う各種クラスを呼び出して命令を与えるだけです.
各種引数のチェック,初期値の設定など
ユーザーから渡される時間配列など引数のチェックを行い,計算結果を格納する配列などが用意されます.初期時刻,初期値を設定します.値の追跡を行う場合は,それを格納するため配列も用意されます.パラメーターを渡してユーザー定義関数を呼び出します.
ソルバの設定
solver
は解法別のクラスです.指定した解法のクラスに初期設定を渡してインスタンス化します.例えばRK23
は,OdeSolver
-> RungeKutta
-> RK23
という継承関係になっています.METHOD
はモジュールのグローバルで定義され,値として各ソルバクラスを持っています.
# def solve_ivp() 内
if method in METHODS:
method = METHODS[method]
solver = method(fun, t0, y0, tf, vectorized=vectorized, **options)
# モジュール"ivp.py"
...
from .rk import RK23, RK45, DOP853
...
METHODS = {'RK23': RK23,
'RK45': RK45,
'DOP853': DOP853,
...
時間ループ
計算のメインとなるループです.solve_ivp
では解の値の判定と追跡・時間幅の調整が行われるので,単にステップを進めるだけではなく,面倒な処理が入ります.
処理1:時間ステップの実行
.step()
により時間ステップを進めます.step()
は以下で説明するように,基底クラスで定義されています.この中で何が起きるのかは後述します.
status = None
while status is None:
message = solver.step()
処理2:指定された値の検出
引数としてevents
を設定した場合に行われます.例えば指定した値が$x$だったとします.ひとつ前の時刻t_old
と現在時刻t
間において,
- 解が$x$を横切ったかどうか判定(
find_active_events
) - $y(t)-x=0$という方程式を$t$について解く(
handle_events
)(scipy.optimize.brentqという求根アルゴリズムが使われています).
というプロセスが入ります.解の値が$x$になったら時間ループから抜ける,みたいな使い方ができます.また,その値を正の方向から横切るか,負の方向から横切るかといった区別も可能です.
処理3:戻り値の設定
時間ステップ幅は誤差の大きさを元に内部的に調整されつつ進んでいくので,上記のstep()
ではユーザーが指定した時刻での値が計算されているとは限りません.なのでユーザーが指定した時刻t_eval
での値を計算し直して戻り値に設定します.これに関わっているのはDenseOutput
というクラスです.これについては一番最後に書きます.
OdeSolver
クラス
基底となっているクラスです.すべての解法で基底として使われるので,全ての解法に共通な部分を担います.主には,
- 共通する引数のチェック
- 時間ステップを実行し,現在時刻
t
,ひとつ前の時刻t_old
,ステップ長t - t_old
を保持する, - ステップ実行の成功・失敗についての
status
を保持する,
機能が実装されています.
solve_ivp()
内の時間ループからsolver.step()
が呼び出されると,その中で_step_impl()
が呼び出されます._step_impl()
は例えばRungeKutta
クラスによってオーバーライドされており,結果としてrk_step()
関数を呼び出します.
class OdeSolver:
...
def step(self):
"""Perform one integration step"""
...
t = self.t
success, message = self._step_impl()
...
def _step_impl(self):
raise NotImplementedError
RK23
クラス
RK23
を見てみます.これはRungeKutta
クラスの子クラスです.RK23
内部では各テーブルが定義されているだけです.ルンゲクッタ法の時間ステップ実行はOdeSolver
(を継承したRungeKutta
から呼び出されるrk_step()
)の方にあります.
# モジュール "rk.py"
class RK23(RungeKutta):
order = 3
error_estimator_order = 2
n_stages = 3
C = np.array([0, 1/2, 3/4])
A = np.array([
[0, 0, 0],
[1/2, 0, 0],
[0, 3/4, 0]
])
B = np.array([2/9, 1/3, 4/9])
E = np.array([5/72, -1/12, -1/9, 1/8])
P = np.array([[1, -4 / 3, 5 / 9],
[0, 1, -2/3],
[0, 4/3, -8/9],
[0, -1, 1]])
RungeKutta
クラス
陽的ルンゲクッタ法の実装です.主に,
-
rk_step()
を呼び出して実際に計算を行い, - 誤差の大きさを判定し,時間ステップ幅を決めます.
# モジュール "rk.py"
class RungeKutta(OdeSolver):
"""Base class for explicit Runge-Kutta methods."""
C: np.ndarray = NotImplemented
A: np.ndarray = NotImplemented
B: np.ndarray = NotImplemented
E: np.ndarray = NotImplemented
P: np.ndarray = NotImplemented
order: int = NotImplemented
error_estimator_order: int = NotImplemented
n_stages: int = NotImplemented
...
解法別のクラス(RK23
など)にオーバーライドされるための値が用意されていますね.長いので載せませんがdef __init__(self)
以下では各引数の値をチェックする処理がされます.
ステップ実行,ステップ幅制御
これは長いですが丸ごと引用します.ルンゲクッタステップをひとつ進めると同時に時間ステップ幅の制御をしています.
- 現在のステップ実行で生じた誤差の大きさを
error_norm
の値で判定します. -
error_norm
が1以上ならstep_rejected = True
となり,ステップ幅は小さい方に調整されます.そして,もう1度現在時刻でステップが実行されます. -
error_norm
が1より小さいならstep_accepted = True
となり,次の時刻でのステップに進みます.この時ステップ幅をfactor
倍します.
# クラス"RungeKutta"内
def _step_impl(self):
t = self.t
y = self.y
max_step = self.max_step
rtol = self.rtol
atol = self.atol
min_step = 10 * np.abs(np.nextafter(t, self.direction * np.inf) - t)
if self.h_abs > max_step:
h_abs = max_step
elif self.h_abs < min_step:
h_abs = min_step
else:
h_abs = self.h_abs
step_accepted = False
step_rejected = False
while not step_accepted:
if h_abs < min_step:
return False, self.TOO_SMALL_STEP
h = h_abs * self.direction
t_new = t + h
if self.direction * (t_new - self.t_bound) > 0:
t_new = self.t_bound
h = t_new - t
h_abs = np.abs(h)
y_new, f_new = rk_step(self.fun, t, y, self.f, h, self.A,
self.B, self.C, self.K)
scale = atol + np.maximum(np.abs(y), np.abs(y_new)) * rtol
error_norm = self._estimate_error_norm(self.K, h, scale)
if error_norm < 1:
if error_norm == 0:
factor = MAX_FACTOR
else:
factor = min(MAX_FACTOR,
SAFETY * error_norm ** self.error_exponent)
if step_rejected:
factor = min(1, factor)
h_abs *= factor
step_accepted = True
else:
h_abs *= max(MIN_FACTOR,
SAFETY * error_norm ** self.error_exponent)
step_rejected = True
self.h_previous = h
self.y_old = y
self.t = t_new
self.y = y_new
self.h_abs = h_abs
self.f = f_new
return True, None
ルンゲクッタ法アルゴリズム
ルンゲクッタアルゴリズムの心臓部はrk_step()
にあります.これはモジュールのグローバルで定義されています(なぜでしょうか?).コーディングはとてもシンプルです.
# モジュール"rk.py"のグローバル
def rk_step(fun, t, y, f, h, A, B, C, K):
"""Perform a single Runge-Kutta step."""
K[0] = f
for s, (a, c) in enumerate(zip(A[1:], C[1:]), start=1):
dy = np.dot(K[:s].T, a[:s]) * h
K[s] = fun(t + c * h, y + dy)
y_new = y + h * np.dot(K[:-1].T, B)
f_new = fun(t + h, y_new)
K[-1] = f_new
return y_new, f_new
DenseOutput
,RKDenseOutput
上で書いたように,ステップ幅は自動で選択されます.しかし,ユーザーとしては適当な時刻での値を返されても困るので,ユーザー指定の時刻t_eval
での値を戻り値に設定する必要があります.そのために,現在時刻t
とひとつ前の時刻t_old
,及びひとつ前の解y_old
などのすでに計算された値を用いて,新しい時刻t
での値を再計算します.
これを行うためにOdeSolver
にはdense_output
というメソッドがあり,また,DenseOutput
というクラスがあります.DenseOutput
はcallableになっていて,時刻t
を引数としてy(t)
を返すようになっています.以下のような順序で実行されます.
-
solve_ivp
の時間ループ内でOdeSolver.dense_output
が呼ばれる. -
OdeSolver._dense_output_impl
が呼ばれる.これはRungeKutta._dense_output_impl
によりオーバーライドされている. -
RKDenseOutput
インスタンスが生成される. -
solve_ivp
の時間ループ内でDenseOutput(t)
が呼ばれる. -
DenseOutput.__call__
から_call_impl
が呼ばれる.これはRKDenseOutput._call_impl
によりオーバーライドされている.結果としてここでy(t)
の値が計算される.
これはちょっと驚きではないでしょうか.いわれてみれば当然かもしれませんが,こんな回りくどいことをしないといけないなんて!
まとめ
思っていたよりもいろんなことをやっていました.そりゃ遅いわけです.数値計算用のプログラムを作成する際など,書き方の参考になればよいなと思います.特に,再利用を意識してオブジェクト指向な書き方を取り入れたいヨという場合,自己流でやったら大変なことになった(自戒)のでソースコードを眺めて勉強しています.この記事ではおおざっぱな構成だけまとめました.誤差推定の方法や補間の方法,陰解法の実装などについては,そのうち実際に使いつつもう少し掘ってみたいと思います.
参考
-
こちらのページでも似た話題があります:https://www.kosh.dev/article/4/
関連
-
numpy.searchsorted
時間ループ内で時刻を決めるために使われます. -
scipy.optimize.brentq
event
を設定した際に使われる方程式の球根アルゴリズムです. -
E. Hairer, S. P. Norsett G. Wanner, "Solving Ordinary Differential Equations I: Nonstiff Problems"
標準的な教科書です.色んなところで引用されています.