常微分方程式

外部モジュールを使わずに素の Python だけで常微分方程式の数値解 (近似解) を求める (8 次 12 段のルンゲクッタ法)

外部モジュールは使わずに Python3 だけで常微分方程式の数値解 (近似解) を求めてみる。古典的ルンゲクッタ法による解法はどこにでもあるのでここでは E.B. Shanks による 8 次 12 段のルンゲクッタ型公式を用いた。

任意の常微分方程式に適用しやすくする手段として解法全体をクラス化した。計算手順および各係数については『パソコンで見る天体の動き』(長沢 工, 檜山 澄子, 地人書館, 1992/10/1) に依った。

例として次の連立微分方程式を解いてみる。
x'(t) = y(t)
y'(t) = t - x(t)

shanks8.py
# Python による常微分方程式の数値解法 / Shanks による 12 段 8 次の Runge-Kutta 法
class Shanks8:
    def __init__(self, funcs, t0, inits, h, numOfDiv=1):
        self.funcs    = funcs
        self.t0       = t0
        self.inits    = inits
        self.dim      = len(funcs)
        self.numOfDiv = numOfDiv
        self.h        = h / self.numOfDiv
        self.f        = [[None for i in range(self.dim)] for i in range(12)]
        self.temp     = [[None for i in range(self.dim)] for i in range(11)]
    def update(self):
        for i in range(self.numOfDiv):
            for j in range(self.dim):
                self.f[0][j] = self.funcs[j](self.t0 , *self.inits)
                self.temp[0][j] = self.inits[j] + (self.h/9) * (self.f[0][j])
            for j in range(self.dim):
                self.f[1][j] = self.funcs[j](self.t0 + self.h/9 , *self.temp[0])
                self.temp[1][j] = self.inits[j] + (self.h/24) * (self.f[0][j] + 3 * self.f[1][j])
            for j in range(self.dim):
                self.f[2][j] = self.funcs[j](self.t0 + self.h/6 , *self.temp[1])
                self.temp[2][j] = self.inits[j] + (self.h/16) * (self.f[0][j] + 3 * self.f[2][j])
            for j in range(self.dim):
                self.f[3][j] = self.funcs[j](self.t0 + self.h/4 , *self.temp[2])
                self.temp[3][j] = self.inits[j] + (self.h/500) * (29 * self.f[0][j] + 33 * self.f[2][j] - 12 * self.f[3][j])
            for j in range(self.dim):
                self.f[4][j] = self.funcs[j](self.t0 + self.h/10 , *self.temp[3])
                self.temp[4][j] = self.inits[j] + (self.h/972) * (33 * self.f[0][j] + 4 * self.f[3][j] + 125 * self.f[4][j])
            for j in range(self.dim):
                self.f[5][j] = self.funcs[j](self.t0 + self.h/6 , *self.temp[4])
                self.temp[5][j] = self.inits[j] + (self.h/36) * (-21 * self.f[0][j] + 76 * self.f[3][j] + 125 * self.f[4][j] - 162 * self.f[5][j])
            for j in range(self.dim):
                self.f[6][j] = self.funcs[j](self.t0 + self.h/2 , *self.temp[5])
                self.temp[6][j] = self.inits[j] + (self.h/243) * (-30 * self.f[0][j] - 32 * self.f[3][j] + 125 * self.f[4][j] + 99 * self.f[6][j])
            for j in range(self.dim):
                self.f[7][j] = self.funcs[j](self.t0 + self.h*2/3 , *self.temp[6])
                self.temp[7][j] = self.inits[j] + (self.h/324) * (1175 * self.f[0][j] - 3456 * self.f[3][j] - 6250 * self.f[4][j] + 8424 * self.f[5][j] + 242 * self.f[6][j] - 27 * self.f[7][j])
            for j in range(self.dim):
                self.f[8][j] = self.funcs[j](self.t0 + self.h/3 , *self.temp[7])
                self.temp[8][j] = self.inits[j] + (self.h/324) * (293 * self.f[0][j] - 852 * self.f[3][j] - 1375 * self.f[4][j] + 1836 * self.f[5][j] - 118 * self.f[6][j] + 162 * self.f[7][j] + 324 * self.f[8][j])
            for j in range(self.dim):
                self.f[9][j] = self.funcs[j](self.t0 + self.h*5/6 , *self.temp[8])
                self.temp[9][j] = self.inits[j] + (self.h/1620) * (1303 * self.f[0][j] - 4260 * self.f[3][j] - 6875 * self.f[4][j] + 9990 * self.f[5][j] + 1030 * self.f[6][j] + 162 * self.f[9][j])
            for j in range(self.dim):
                self.f[10][j] = self.funcs[j](self.t0 + self.h*5/6 , *self.temp[9])
                self.temp[10][j] = self.inits[j] + (self.h/4428) * (-8595 * self.f[0][j] + 30720 * self.f[3][j] + 48750 * self.f[4][j] - 66096 * self.f[5][j] + 378 * self.f[6][j] - 729 * self.f[7][j] - 1944 * self.f[8][j] - 1296 * self.f[9][j] + 3240 * self.f[10][j])
            for j in range(self.dim):
                self.f[11][j] = self.funcs[j](self.t0 + self.h , *self.temp[10])
            for j in range(self.dim):
                self.inits[j] += (self.h/840) * (41 * self.f[0][j] + 216 * self.f[5][j] + 272 * self.f[6][j] + 27 * self.f[7][j] + 27 * self.f[8][j] + 36 * self.f[9][j] + 180 * self.f[10][j] + 41 * self.f[11][j])
            self.t0 += self.h
        return self
    def print(self):
        print(self.t0, self.inits)
        return self


###########################################################################
# 解くべき聯立微分方程式を定義する。リストで括っておく。
def xDot(t, x, y):  # x'(t) = y(t)
    return y
def yDot(t, x, y):  # y'(t) = t - x(t)
    return t - x
funcs = [xDot, yDot]

# 独立変数の開始値と終了値とを指定する。
t0 = 0
tMax = 100

# 従属変数の初期値を指定する。リストで括っておく。
x0 = 0
y0 = 0
inits = [x0, y0]

# 刻み幅を指定する。2 の整数乗分の 1 にすることが望ましい。
h = 1/2**2

# 1 刻みの内部分割数を指定する場合は 2 の整数乗にすることが望ましい。
# numOfDiv = 2**4

# 1 刻みだけ計算する函数を実体化して、
sol = Shanks8(funcs, t0, inits, h)

# 初期値を更新しながら必要な回数だけ実行を繰り返す。
while sol.t0 < tMax:
    sol.update().print()

###########################################################################
from math import *
print("真値  " + str([tMax - sin(tMax), 1 - cos(tMax)]))

実行結果 (t [x(t), y(t)]):
image.png

関連: http://ti-nspire.hatenablog.com/