概要
Lorenz96をJuliaで解く その2
でLorenz96モデルを解くコードを書いたので,それ使って,Extend Kalman Filter(EKF)でデータ同化するコードを作りました。
グローバル変数でデータのやり取りしてたときは40変数のLorenz96で,1年分(7300ステップ)のデータ同化するのに30sかかってました。グローバル変数使わない,引数で全部データやり取りする下記のコードでは同じ計算が2sで終わります。
コード
メインの処理
using PyPlot
using LinearAlgebra
using HDF5
using BenchmarkTools
function main(alpha)
##Fの大きさで挙動が変わる
F = 8
N = 40
# 6時間おきにデータ同化。ただし,計算はdt=0.01ですすめる。
dt = 0.01
# 1年分計算
#tend = 73
tend=73
# 同化するタイミング。
# 今回はdt=0.05おきにデータ同化
interval = 5
nstep = Int32(tend/dt)
file_name = "L96_true_obs.h5"
Xn_true, y = read_true(file_name)
Xa, Xf, Pa, Pf, R, H = init(F, dt, N, nstep, Xn_true)
Xa, Xf, Pa, Pf = step(dt, F, N, 2, nstep, interval, alpha, Xf, Xa, Pa, Pf, y, R, H)
trPa = zeros(Float64, nstep)
rms_f = zeros(Float64, nstep)
rms_o = zeros(Float64, nstep)
rms_a = zeros(Float64, nstep)
for i=1:interval:nstep
trPa[i] = sqrt(tr(Pa[:,:,i])/N)
rms_f[i] = norm(Xf[:,i]-Xn_true[:,Int32((i-1)/interval)+1])/sqrt(N)
rms_o[i] = norm(y[:,Int32((i-1)/interval)+1]-Xn_true[:,Int32((i-1)/interval)+1])/sqrt(N)
rms_a[i] = norm(Xa[:,i]-Xn_true[:,Int32((i-1)/interval)+1])/sqrt(N)
end
plot(dt, alpha, nstep, interval, trPa, rms_o, rms_a)
end
関数定義
# dx/dt=f(x)
# Lorentz96の方程式
function f(x, F, N)
g = fill(0.0, N)
for i=3:N-1
g[i] = (x[i+1]-x[i-2])x[i-1] - x[i] + F
end
# 周期境界
g[1] = (x[2]-x[N-1])x[N] - x[1] + F
g[2] = (x[3]-x[N])x[1] - x[2] + F
g[N] = (x[1]-x[N-2])x[N-1] - x[N] + F
return g
end
# L96をRK4で解く
function Model(xold, dt, F, N)
k1 = f(xold, F, N)
k2 = f(xold + k1*dt/2., F, N)
k3 = f(xold + k2*dt/2., F, N)
k4 = f(xold + k3*dt, F, N)
xnew = xold + dt/6.0*(k1 + 2.0k2 + 2.0k3 + k4)
end
# 初期化
function read_true(file_name)
# true value
# 解析における真の値(Xn_true)と観測値(y)を保存したファイルをリード。
# 観測値は真の値に対して,分散1の正規分布乱数で誤差を与えて,別途作る。
file = h5open(file_name, "r")
Xn_true = read(file, "Xn_true")
y = read(file, "Xn_obs")
close(file)
return Xn_true, y
end
function init(F, dt, N, nstep, Xn_true)
R = Matrix{Float64}(I, N, N)
H = Matrix{Float64}(I, N, N) # 観測値yもNx1なので,HをNxNの単位行列
Xa = zeros(Float64, N, nstep)
Xf = zeros(Float64, N, nstep)
Pa = zeros(Float64, N, N, nstep)
Pf = zeros(Float64, N, N, nstep)
# 解析も予測も真の値はわからない。真の値に分散0.001の擾乱を与えて,1年分時間発展させて,
#アトラクタに行くようにしてその値を初期値とする。
X = copy(Xn_true[:,1] + 0.001randn(N))
for i=1:7300
X = Model(X, dt, F, N)
end
Xa[:,1] = copy(X)
# Lorentz96の平均誤差の定常値5を基準にして,Paの初期値設定
Pa[:,:,1] = Matrix{Float64}(25I, N, N)
Xf[:,1] = copy(Xa[:,1])
return Xa, Xf, Pa, Pf, R, H
end
# モデル行列Mの生成。ヤコビアンを計算
function makeM(X, F, dt, N)
delta = 1e-2
E = Matrix{Float64}(I, N, N)
M = zeros(Float64, N, N)
for j=1:N
M[:, j] = (Model(X + delta*E[:, j], dt, F, N) - Model(X, dt, F, N))/delta
end
return M
end
# データ同化の実行
# alpha: infrationのファクタ
function step(dt, F, N, nstart, nend, interval, alpha, Xf, Xa, Pa, Pf, y, R, H)
for i=nstart:nend
# forecast
Xf[:,i] = Model(Xa[:,i-1], dt, F, N)
M = makeM(Xf[:,i], F, dt, N)
#Pf[:,:,i] = 1.7M*Pa[:,:,i-interval]*M'
Pf[:,:,i] = M*Pa[:,:,i-1]*M'*alpha
# no data assimilation step
Xa[:,i] = Xf[:,i]
Pa[:,:,i] = Pf[:,:,i]
# data assimilation step
# 5ステップおきにデータ同化を行う。
if (i-1)%interval == 0
K = Pf[:,:,i]*H'*inv(H*Pf[:,:,i]*H' + R)
Xa[:,i] = Xf[:,i] + K*(y[:,Int32((i-1)/interval)+1]-H*Xf[:,i])
Pa[:,:,i] = (I - K*H)*Pf[:,:,i]
end
end
return Xa, Xf, Pa, Pf
end
グラフ描画
function plot(dt, alpha, nstep, interval, trPa, rms_o, rms_a)
fig = plt.figure(figsize=(10, 5.))
ax1 = fig.add_subplot(121)
nstart = 1
nend = nstep
time = ((nstart:interval:nend).-1)*dt/0.2
ax1.set_title("alpha=" * string(alpha))
ax1.set_xlabel("time (day)")
ax1.set_ylabel("RMSE")
ax1.plot(time, trPa[nstart:interval:nend], label="trace(Pa)")
ax1.plot(time, rms_o[nstart:interval:nend], label="RMSE observation")
ax1.plot(time, rms_a[nstart:interval:nend], label="RMSE analysis")
ax1.set_ylim(0, 6)
ax1.legend()
ax2 = fig.add_subplot(122)
nstart = 6
nend = nstep
time = ((nstart:interval:nend).-1)*dt/0.2
ax2.set_title("alpha=" * string(alpha))
ax2.set_xlabel("time (day)")
ax2.set_ylabel("RMSE")
ax2.plot(time, trPa[nstart:interval:nend], label="trace(Pa)")
ax2.plot(time, rms_o[nstart:interval:nend], label="RMSE observation")
ax2.plot(time, rms_a[nstart:interval:nend], label="RMSE analysis")
ax2.set_ylim(0, 1.5)
ax2.legend()
plt.savefig("EKF_RMSE_alpha_"*string(alpha)*".png")
end
結果
実行してみると
alpha=1.02
@time main(alpha)
2.560520 seconds (12.07 M allocations: 5.418 GiB, 34.35% gc time)
誤差のRMSはこんな感じ。観測の誤差が大体1(分散1の誤差を真値に与えてるので当然)で,データ同化したほうが0.2程度まで下がっています。