LoginSignup
1
1

More than 3 years have passed since last update.

Lorenz96をExtend Kalman Filterでデータ同化するJuliaのコード

Posted at

概要

Lorenz96をJuliaで解く その2
でLorenz96モデルを解くコードを書いたので,それ使って,Extend Kalman Filter(EKF)でデータ同化するコードを作りました。

グローバル変数でデータのやり取りしてたときは40変数のLorenz96で,1年分(7300ステップ)のデータ同化するのに30sかかってました。グローバル変数使わない,引数で全部データやり取りする下記のコードでは同じ計算が2sで終わります。

コード

メインの処理

using PyPlot
using LinearAlgebra
using HDF5
using BenchmarkTools

function main(alpha)
    ##Fの大きさで挙動が変わる
    F = 8
    N = 40
    # 6時間おきにデータ同化。ただし,計算はdt=0.01ですすめる。
    dt = 0.01
    # 1年分計算
    #tend = 73
    tend=73
    # 同化するタイミング。
    # 今回はdt=0.05おきにデータ同化
    interval = 5
    nstep = Int32(tend/dt)

    file_name = "L96_true_obs.h5"
    Xn_true, y = read_true(file_name)
    Xa, Xf, Pa, Pf, R, H = init(F, dt, N, nstep, Xn_true)
    Xa, Xf, Pa, Pf = step(dt, F, N, 2, nstep, interval, alpha, Xf, Xa, Pa, Pf, y, R, H)

    trPa = zeros(Float64, nstep)
    rms_f = zeros(Float64, nstep)
    rms_o = zeros(Float64, nstep)
    rms_a = zeros(Float64, nstep)
    for i=1:interval:nstep
        trPa[i] = sqrt(tr(Pa[:,:,i])/N)
        rms_f[i] = norm(Xf[:,i]-Xn_true[:,Int32((i-1)/interval)+1])/sqrt(N)
        rms_o[i] = norm(y[:,Int32((i-1)/interval)+1]-Xn_true[:,Int32((i-1)/interval)+1])/sqrt(N)
        rms_a[i] = norm(Xa[:,i]-Xn_true[:,Int32((i-1)/interval)+1])/sqrt(N)
    end  

    plot(dt, alpha, nstep, interval, trPa, rms_o, rms_a)

end

関数定義


# dx/dt=f(x) 
# Lorentz96の方程式
function f(x, F, N)
    g = fill(0.0, N)
    for i=3:N-1
        g[i] = (x[i+1]-x[i-2])x[i-1] - x[i] + F
    end

    # 周期境界
    g[1] = (x[2]-x[N-1])x[N] - x[1] + F
    g[2] = (x[3]-x[N])x[1] - x[2] + F
    g[N] = (x[1]-x[N-2])x[N-1] - x[N] + F

    return g
end

# L96をRK4で解く
function Model(xold, dt, F, N)
    k1 = f(xold, F, N)
    k2 = f(xold + k1*dt/2., F, N)
    k3 = f(xold + k2*dt/2., F, N)
    k4 = f(xold + k3*dt, F, N)

    xnew = xold + dt/6.0*(k1 + 2.0k2 + 2.0k3 + k4)
end

# 初期化
function read_true(file_name)
    # true value
    # 解析における真の値(Xn_true)と観測値(y)を保存したファイルをリード。
    # 観測値は真の値に対して,分散1の正規分布乱数で誤差を与えて,別途作る。
    file = h5open(file_name, "r") 
    Xn_true = read(file, "Xn_true")
    y = read(file, "Xn_obs")

    close(file)

    return Xn_true, y
end

function init(F, dt, N, nstep, Xn_true)
    R = Matrix{Float64}(I, N, N)
    H = Matrix{Float64}(I, N, N) # 観測値yもNx1なので,HをNxNの単位行列

    Xa = zeros(Float64, N, nstep)
    Xf = zeros(Float64, N, nstep)
    Pa = zeros(Float64, N, N, nstep)
    Pf = zeros(Float64, N, N, nstep)

    # 解析も予測も真の値はわからない。真の値に分散0.001の擾乱を与えて,1年分時間発展させて,
    #アトラクタに行くようにしてその値を初期値とする。
    X = copy(Xn_true[:,1] + 0.001randn(N))

    for i=1:7300
        X = Model(X, dt, F, N)
    end
    Xa[:,1] = copy(X)

    # Lorentz96の平均誤差の定常値5を基準にして,Paの初期値設定
    Pa[:,:,1] = Matrix{Float64}(25I, N, N)

    Xf[:,1] = copy(Xa[:,1])

    return Xa, Xf, Pa, Pf, R, H 
end

# モデル行列Mの生成。ヤコビアンを計算
function makeM(X, F, dt, N)
    delta = 1e-2
    E = Matrix{Float64}(I, N, N)
    M = zeros(Float64, N, N)

    for j=1:N
        M[:, j] = (Model(X + delta*E[:, j], dt, F, N) - Model(X, dt, F, N))/delta
    end

    return M
end

# データ同化の実行
# alpha: infrationのファクタ
function step(dt, F, N, nstart, nend, interval, alpha, Xf, Xa, Pa, Pf, y, R, H)
    for i=nstart:nend
        # forecast
        Xf[:,i] = Model(Xa[:,i-1], dt, F, N)
        M = makeM(Xf[:,i], F, dt, N)
        #Pf[:,:,i] = 1.7M*Pa[:,:,i-interval]*M'
        Pf[:,:,i] = M*Pa[:,:,i-1]*M'*alpha

        # no data assimilation step
        Xa[:,i] = Xf[:,i]
        Pa[:,:,i] = Pf[:,:,i]

        # data assimilation step
        # 5ステップおきにデータ同化を行う。
        if (i-1)%interval == 0
            K = Pf[:,:,i]*H'*inv(H*Pf[:,:,i]*H' + R)
            Xa[:,i] = Xf[:,i] + K*(y[:,Int32((i-1)/interval)+1]-H*Xf[:,i])
            Pa[:,:,i] = (I - K*H)*Pf[:,:,i]
        end

    end
    return Xa, Xf, Pa, Pf
end

グラフ描画

function plot(dt, alpha, nstep, interval, trPa, rms_o, rms_a)
    fig = plt.figure(figsize=(10, 5.))
    ax1 = fig.add_subplot(121)
    nstart = 1
    nend = nstep

    time = ((nstart:interval:nend).-1)*dt/0.2


    ax1.set_title("alpha=" * string(alpha))
    ax1.set_xlabel("time (day)")
    ax1.set_ylabel("RMSE")
    ax1.plot(time, trPa[nstart:interval:nend], label="trace(Pa)")
    ax1.plot(time, rms_o[nstart:interval:nend], label="RMSE observation")
    ax1.plot(time, rms_a[nstart:interval:nend], label="RMSE analysis")
    ax1.set_ylim(0, 6)
    ax1.legend()

    ax2 = fig.add_subplot(122)
    nstart = 6
    nend = nstep

    time = ((nstart:interval:nend).-1)*dt/0.2

    ax2.set_title("alpha=" * string(alpha))
    ax2.set_xlabel("time (day)")
    ax2.set_ylabel("RMSE")
    ax2.plot(time, trPa[nstart:interval:nend], label="trace(Pa)")
    ax2.plot(time, rms_o[nstart:interval:nend], label="RMSE observation")
    ax2.plot(time, rms_a[nstart:interval:nend], label="RMSE analysis")
    ax2.set_ylim(0, 1.5)
    ax2.legend()

    plt.savefig("EKF_RMSE_alpha_"*string(alpha)*".png")
end

結果

実行してみると

alpha=1.02
@time main(alpha)

2.560520 seconds (12.07 M allocations: 5.418 GiB, 34.35% gc time)

誤差のRMSはこんな感じ。観測の誤差が大体1(分散1の誤差を真値に与えてるので当然)で,データ同化したほうが0.2程度まで下がっています。

EKF_RMSE_alpha_1.02.png

1
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
1