2
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

Scipy.solve_ivpの中身について

Posted at

計算法というよりはプログラムの構成に関してです.
コードはすべてv1.9.2のソースからの引用(現時点ではLatest)です.適宜省略しつつ引用します.

実装されている機能

solve_ivpには,

  • 時間ステップ幅の調整
  • 精度の保証
  • 解の値の判定,指定した値の追跡

などの機能があります.上記の処理が入っているためソースコードはそれなりに煩雑です.ここでは処理の詳細には触れずに,プログラム全体としてどのような構成になっているのか,解法のメインである時間ループの計算がどこで・どのような順で実行されるのか,という点について調べたことを書きます.RK23を例として見てみます.

Signature

def solve_ivp(fun, t_span, y0, method='RK45', t_eval=None, dense_output=False,
              events=None, vectorized=False, args=None, **options):

solve_ivpがやっていること

solve_ivpの直下の記述はそれほど多くありません.実際の処理を行う各種クラスを呼び出して命令を与えるだけです.

各種引数のチェック,初期値の設定など

ユーザーから渡される時間配列など引数のチェックを行い,計算結果を格納する配列などが用意されます.初期時刻,初期値を設定します.値の追跡を行う場合は,それを格納するため配列も用意されます.パラメーターを渡してユーザー定義関数を呼び出します.

ソルバの設定

solverは解法別のクラスです.指定した解法のクラスに初期設定を渡してインスタンス化します.例えばRK23は,OdeSolver -> RungeKutta -> RK23 という継承関係になっています.METHODはモジュールのグローバルで定義され,値として各ソルバクラスを持っています.

# def solve_ivp() 内

if method in METHODS:    
    method = METHODS[method]

solver = method(fun, t0, y0, tf, vectorized=vectorized, **options)
# モジュール"ivp.py"
...
from .rk import RK23, RK45, DOP853
...
METHODS = {'RK23': RK23,
           'RK45': RK45,
           'DOP853': DOP853,
...

時間ループ

計算のメインとなるループです.solve_ivpでは解の値の判定と追跡・時間幅の調整が行われるので,単にステップを進めるだけではなく,面倒な処理が入ります.

処理1:時間ステップの実行

.step()により時間ステップを進めます.step()は以下で説明するように,基底クラスで定義されています.この中で何が起きるのかは後述します.

status = None
while status is None:
    message = solver.step()

処理2:指定された値の検出

引数としてeventsを設定した場合に行われます.例えば指定した値が$x$だったとします.ひとつ前の時刻t_oldと現在時刻t間において,

  1. 解が$x$を横切ったかどうか判定(find_active_events
  2. $y(t)-x=0$という方程式を$t$について解く(handle_events)(scipy.optimize.brentqという求根アルゴリズムが使われています).

というプロセスが入ります.解の値が$x$になったら時間ループから抜ける,みたいな使い方ができます.また,その値を正の方向から横切るか,負の方向から横切るかといった区別も可能です.

処理3:戻り値の設定

時間ステップ幅は誤差の大きさを元に内部的に調整されつつ進んでいくので,上記のstep()ではユーザーが指定した時刻での値が計算されているとは限りません.なのでユーザーが指定した時刻t_evalでの値を計算し直して戻り値に設定します.これに関わっているのはDenseOutputというクラスです.これについては一番最後に書きます.

OdeSolverクラス

基底となっているクラスです.すべての解法で基底として使われるので,全ての解法に共通な部分を担います.主には,

  • 共通する引数のチェック
  • 時間ステップを実行し,現在時刻t,ひとつ前の時刻t_old,ステップ長t - t_oldを保持する,
  • ステップ実行の成功・失敗についてのstatusを保持する,

機能が実装されています.
solve_ivp()内の時間ループからsolver.step()が呼び出されると,その中で_step_impl()が呼び出されます._step_impl()は例えばRungeKuttaクラスによってオーバーライドされており,結果としてrk_step()関数を呼び出します.

class OdeSolver:

...

    def step(self):
        """Perform one integration step"""
        ...
        t = self.t
        success, message = self._step_impl()

        ...

    def _step_impl(self):
        raise NotImplementedError

RK23クラス

RK23を見てみます.これはRungeKuttaクラスの子クラスです.RK23内部では各テーブルが定義されているだけです.ルンゲクッタ法の時間ステップ実行はOdeSolver(を継承したRungeKuttaから呼び出されるrk_step())の方にあります.

# モジュール "rk.py"

class RK23(RungeKutta):
    order = 3
    error_estimator_order = 2
    n_stages = 3
    C = np.array([0, 1/2, 3/4])
    A = np.array([
        [0, 0, 0],
        [1/2, 0, 0],
        [0, 3/4, 0]
    ])
    B = np.array([2/9, 1/3, 4/9])
    E = np.array([5/72, -1/12, -1/9, 1/8])
    P = np.array([[1, -4 / 3, 5 / 9],
                  [0, 1, -2/3],
                  [0, 4/3, -8/9],
                  [0, -1, 1]])

RungeKuttaクラス

陽的ルンゲクッタ法の実装です.主に,

  • rk_step()を呼び出して実際に計算を行い,
  • 誤差の大きさを判定し,時間ステップ幅を決めます.
# モジュール "rk.py"

class RungeKutta(OdeSolver):
    """Base class for explicit Runge-Kutta methods."""
    C: np.ndarray = NotImplemented
    A: np.ndarray = NotImplemented
    B: np.ndarray = NotImplemented
    E: np.ndarray = NotImplemented
    P: np.ndarray = NotImplemented
    order: int = NotImplemented
    error_estimator_order: int = NotImplemented
    n_stages: int = NotImplemented
    ...

解法別のクラス(RK23など)にオーバーライドされるための値が用意されていますね.長いので載せませんがdef __init__(self)以下では各引数の値をチェックする処理がされます.

ステップ実行,ステップ幅制御

これは長いですが丸ごと引用します.ルンゲクッタステップをひとつ進めると同時に時間ステップ幅の制御をしています.

  1. 現在のステップ実行で生じた誤差の大きさをerror_normの値で判定します.
  2. error_normが1以上ならstep_rejected = Trueとなり,ステップ幅は小さい方に調整されます.そして,もう1度現在時刻でステップが実行されます.
  3. error_normが1より小さいならstep_accepted = Trueとなり,次の時刻でのステップに進みます.この時ステップ幅をfactor倍します.
# クラス"RungeKutta"内

def _step_impl(self):
    t = self.t
    y = self.y

    max_step = self.max_step
    rtol = self.rtol
    atol = self.atol

    min_step = 10 * np.abs(np.nextafter(t, self.direction * np.inf) - t)

    if self.h_abs > max_step:
        h_abs = max_step
    elif self.h_abs < min_step:
        h_abs = min_step
    else:
        h_abs = self.h_abs

    step_accepted = False
    step_rejected = False

    while not step_accepted:
        if h_abs < min_step:
            return False, self.TOO_SMALL_STEP

        h = h_abs * self.direction
        t_new = t + h

        if self.direction * (t_new - self.t_bound) > 0:
            t_new = self.t_bound

        h = t_new - t
        h_abs = np.abs(h)

        y_new, f_new = rk_step(self.fun, t, y, self.f, h, self.A,
                               self.B, self.C, self.K)
        scale = atol + np.maximum(np.abs(y), np.abs(y_new)) * rtol
        error_norm = self._estimate_error_norm(self.K, h, scale)

        if error_norm < 1:
            if error_norm == 0:
                factor = MAX_FACTOR
            else:
                factor = min(MAX_FACTOR,
                             SAFETY * error_norm ** self.error_exponent)

            if step_rejected:
                factor = min(1, factor)

            h_abs *= factor

            step_accepted = True
        else:
            h_abs *= max(MIN_FACTOR,
                         SAFETY * error_norm ** self.error_exponent)
            step_rejected = True

    self.h_previous = h
    self.y_old = y

    self.t = t_new
    self.y = y_new

    self.h_abs = h_abs
    self.f = f_new

    return True, None

ルンゲクッタ法アルゴリズム

ルンゲクッタアルゴリズムの心臓部はrk_step()にあります.これはモジュールのグローバルで定義されています(なぜでしょうか?).コーディングはとてもシンプルです.

# モジュール"rk.py"のグローバル

def rk_step(fun, t, y, f, h, A, B, C, K):
    """Perform a single Runge-Kutta step."""
    K[0] = f
    for s, (a, c) in enumerate(zip(A[1:], C[1:]), start=1):
        dy = np.dot(K[:s].T, a[:s]) * h
        K[s] = fun(t + c * h, y + dy)

    y_new = y + h * np.dot(K[:-1].T, B)
    f_new = fun(t + h, y_new)

    K[-1] = f_new

    return y_new, f_new

DenseOutputRKDenseOutput

上で書いたように,ステップ幅は自動で選択されます.しかし,ユーザーとしては適当な時刻での値を返されても困るので,ユーザー指定の時刻t_evalでの値を戻り値に設定する必要があります.そのために,現在時刻tとひとつ前の時刻t_old,及びひとつ前の解y_oldなどのすでに計算された値を用いて,新しい時刻tでの値を再計算します.

これを行うためにOdeSolverにはdense_outputというメソッドがあり,また,DenseOutputというクラスがあります.DenseOutputはcallableになっていて,時刻tを引数としてy(t)を返すようになっています.以下のような順序で実行されます.

  1. solve_ivpの時間ループ内でOdeSolver.dense_outputが呼ばれる.
  2. OdeSolver._dense_output_implが呼ばれる.これはRungeKutta._dense_output_implによりオーバーライドされている.
  3. RKDenseOutputインスタンスが生成される.
  4. solve_ivpの時間ループ内でDenseOutput(t)が呼ばれる.
  5. DenseOutput.__call__から_call_implが呼ばれる.これはRKDenseOutput._call_implによりオーバーライドされている.結果としてここでy(t)の値が計算される.

これはちょっと驚きではないでしょうか.いわれてみれば当然かもしれませんが,こんな回りくどいことをしないといけないなんて!

まとめ

思っていたよりもいろんなことをやっていました.そりゃ遅いわけです.数値計算用のプログラムを作成する際など,書き方の参考になればよいなと思います.特に,再利用を意識してオブジェクト指向な書き方を取り入れたいヨという場合,自己流でやったら大変なことになった(自戒)のでソースコードを眺めて勉強しています.この記事ではおおざっぱな構成だけまとめました.誤差推定の方法や補間の方法,陰解法の実装などについては,そのうち実際に使いつつもう少し掘ってみたいと思います.

参考

関連

  • numpy.searchsorted
    時間ループ内で時刻を決めるために使われます.

  • scipy.optimize.brentq
    eventを設定した際に使われる方程式の球根アルゴリズムです.

  • E. Hairer, S. P. Norsett G. Wanner, "Solving Ordinary Differential Equations I: Nonstiff Problems"
    標準的な教科書です.色んなところで引用されています.

2
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?