24
22

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

いろんなカルマンフィルターでデータ同化を実装する (L-96)

Last updated at Posted at 2020-04-09

データ同化の概要

モデルから求まる予報と観測を混ぜていい感じの解析値を得ることで観測に含まれる誤差を減らすことです。その混ぜる割合によって、誤差がどのくらいまで軽減されるのかが変わります。
また、データ同化にはもう一つあって、観測データを予報空間に取り込む方法としても考えることができます。この考え方を同化と呼ぶこともできます。

具体的には、制御工学的なフィルターの役割を果たすこともありますし、それとは別に気象的、カオスな現象のシミュレーションと現実世界を結びつける同化する目的でも使われます。

データ同化には大きく分けて予報ステップと解析ステップの二つがあります。

  1. 一つ前の時刻の解析解からモデルを使って今の時刻の予報を求める
  2. その予報と観測から解析解を求める
    これらの実装に様々な方法が生まれます。

まず最初に、三次元変分法では、一定割合でデータを混ぜることでノイズを減らします。性能は低いです。

カルマンフィルタは、ノイズが正規分布に従う仮定の上で、その時における一番良い比率で混ぜます。

アンサンブルカルマンフィルタは、カルマンフィルタの改良でアンサンブルメンバーの分散からカルマンフィルタの近似を行います。様々な方面から研究が進められています。PO法とLETKFを紹介します。

Lorenz 1996

フィルタを作る前に、モデルを考えます。モデルでは、一時刻前から次の時刻の予報を求めるために使います。
今回考えたいものとして気象に関することで、これにはカオス性があります。大規模なモデルでもカオス性は確認できますが、今回はシンプルなモデルでまずはフィルタの効果を確かめたいと思います。
image.png
モデルはこんな感じで、uは変数、大きさは40くらいで環状のバッファにします。
$u_i=F$のとき、定常状態になり、そこから少しそれると、Fの値によってはカオス性があります。
image.png
F=8のとき、中央付近を1.001倍にしたときどのように発展するかを示しています。
ルンゲクッタ4次で解きました。
image.png

カオスとスケーリング

カオスであることには、予測が困難であり、初期値の小さなノイズが時間積分で大きくなるとき、そのモデルのカオス性が認められます。
そのため、解析値の誤差は予報を求めるモデルに誤差がなくても拡大します。今回はモデル誤差は考えません。
スピンアップに1年分のシミュレーションを行い、そこから少量のノイズを加えたときと加えないときの誤差の発達を確認します。サンプルを多くとり、平均してグラフにします。気象の世界では約2日で二倍になっていると思いますが、これがerror doubling timeになります。
image.png
今回dt=0.05でやっていて、8ステップの0.2で2倍になっているので0.2で2日でよいでしょう。

同化のテストの仕方

まず、スピンアップで1年程度空の計算をし、その後さらに1年間計算してそれを正解とします。
正解に分散1、平均0の標準正規分布に従うノイズを加え、それを観測データとします。
観測データからどの程度誤差がなくなったのかを評価することにします。

3D-VAR

3D-VARでは、一番シンプルにいつも同じ割合でデータを混ぜます。
比率をKとしましょう。また、ここでは観測空間と予報空間のグリッドは異なり、写像Hで結ばれているものとします。
予報$x^f_i$に、$K(y_i^o-H(x_i^f))$を足し合わせたものが解析解になります。ゲインが大きければ観測から多くとり、小さければ解析解に頼るようになります。
今回は、このKはカルマンゲインの最後のほうから統計的な平均をとって使うことにしました。

カルマンフィルタ

カルマンフィルタでは、ノイズが平均0の正規分布に従っている仮定の下で混ぜるとき、よい性能を発揮します。
混ぜる割合であるカルマンゲインKを求めることを目標にします。

予報ステップ

予報ステップでは、予報解とその誤差共分散行列$P^f$を求めます。
$P^f=MP^a_{i-1}M^T$で計算します。この時の$M^T$は、アジョイント行列と呼ばれますが、モデルの接線、ヤコビアンから求めることができます。
モデルの接線ですが、今回は単に時間を少し進めて求めることにしました。

解析ステップ

カルマンゲインのKには予報で求めた$P^a$のトレース和を最小にするような、カルマンゲインが最適です。
細かい計算は他所を参考にしてもらうことにして、以下のようにすると求めることができます。
$$K=P^f_iH^T(HP^f_iH^T+R)^{-1}$$
ここで、$P_i^a=(I-KH)P_i^f$も求めることができます。
なぜ求めるかというと、この解析解の共分散行列のトレース和とRMSEの関係を調べたいのでここではあえて計算しておきます。

結果の評価

結果を測るときにはRMSEを使うことにします。Root Mean Scale Errorで、エラーのユークリッド距離に相当するものになります。
また、$Tr(P^a)$とスケールをそろえておきます。
rmse-NG.png
この結果を見てみましょう。xobsは観測の評価で、分散が1付近なのでそのところを横ばいになっています。
最初のほうはxaも下がっていますが、途中から$Tr(P^a)$との差がどんどんずれていって誤差がかなり大きくなっていることがわかります。これはカルマンフィルタの線形モデルの誤差が原因で、共分散が過少に見積もられているためです。そのため、共分散膨張を使い、$P^f$を少し大きくしてあげることが必要です。
$$P^f*=(1+\delta)$$
として、δ=0.1程度で次のような結果になると思います。
rmse.png
これでうまくRMSEが0.2付近で落ち着きました。また、緑と青の線が近くに来ていることがわかると思います。

3D-VAR再び

カルマンフィルタの時の$P^f$をもとに、データの最後のほうの$P^f$の平均からゲインを計算し、Kを同じ値を用いて同化してみます。
全体を定数倍することで、少しRMSEを下げることができます。が、あまり下がりません。するとどう頑張ってもRMSEは0.4くらいまで上昇してしまいます。

アンサンブルカルマンフィルタ

KFはある意味最強なのですが、弱点もあります。まず、格子点すべてにデータが与えられることは考えられません。そのデータの与えられ方によってはすぐに解析解が発散してしまうことが知られています。
rmse-final-dens-zoomout.png
これは格子点の最後のデータをデータ同化から隠したとき、どのようにRMSEが変わるのかを試したものです。カルマンフィルタだけ急に立ち上がっていることがわかります。

それだけではなく、モデルのアジョイントをどうするのかなど様々な問題があります。

アンサンブルカルマンフィルタでは、Pfの計算に、アンサンブルから直接計算することで接線モデルの計算を不要にします。

アンサンブルカルマンフィルタの考え方は以下の通りです。
アンサンブルメンバーをそれぞれ別々に予報を行い、どれだけデータがばらけたのかで予報の分散を求めます。それをもとにデータ同化を行うというものです。

PO法

アンサンブルメンバーの生成などによってアンサンブルカルマンフィルタの実装には流派があります。この中の一つがPO法で、これは観測に観測の分散と同じだけの分散のノイズをわざと加えることで同じようなアンサンブルメンバーの同化を行えないのかという考え方です。
image.png
PO法をそのまま実装すると、サンプリングエラーによりうまく同化できないことがあります。サンプル数が少なくてうまく共分散行列が推定できなくなるということです。下の図を見てください。カルマンフィルタが理想的な誤差共分散行列です。
image.png
64メンバーだけだと対角から離れた位置にかなり強いノイズが残ってしまっています。そこで局所化を導入します。

局所化

image.png
このような関数で対角からの距離でマスクします。
image.png
σ=5のとき、
image.png
σ=3の時のマスクです。これによって、濃い青のところは0がかけられて濃い赤は1(そのまま)になります。
これは観測の局所化に相当するもので、近くのデータ点以外の相関を0にすることで、サンプリングエラーを抑えることができます。
局所化をすることでメンバー数が少なくてもうまく同化できると思います。
image.png
ちなみに16以上メンバーを増やしてもあまりRMSEは改善しなさそうです。

LETKF

image.png
LETKFでは、アンサンブルメンバーの摂動を作り出すのに(ばらつきを与え、分散を評価する)平方根を用いります。
このメリットは、乱数によらずに安定してデータを同化することができます。
また、コストの高い平方根演算や逆行列の計算を一つの対角化をもとに行うため、計算負荷も軽いことが特徴です。

LETKFの局所化は少し面倒で、Rの逆行列を用いて局所化します。
Rは観測誤差行列で、この誤差が大きくなるとその値はあまり有効にとらえられることはありません。
そのため、それぞれの格子点でデータ同化を行い(局所化の当て布)、そのなかでRの逆行列をその都度計算することで局所化を行えます。
image.png
Rの逆行列はこんな感じになります。1に近いほどデータが有効に活用され、0に近いとそのデータは無視されます。
(Rは観測誤差のままで考えて、格子点ごとにウェイトを用意して計算においても同様に局所化ができます。その時でも、必要に応じて計算を切り替えられるようにできるといいでしょう。)

LETKFは少ないメンバーでも低いRMSEを出しますが、計算時間がかなり長く、並列化の恩恵は受けやすいですが逆にそれぞれの格子点すべてでデータ同化をそれぞれで行う必要があります。

結果のグラフs

RMSEで測ります。共分散膨張のδを変化させて一番小さい値をプロットに使います。

データの最後のほうを隠したとき

rmse-final-dens.png
rmse-final-dens-zoomout.png

データを等間隔で間引いたとき

rmse-final-turn-zoom.png
rmse-final-turn.png

以上、走り書きでした。

結果と考察、今後の検討事項など

データの次元って増えれば情報量が増えますが、それに伴って正しく処理する必要性が増してくるなと思います。
観測の局所化は、「離れたところの観測には目をつぶりましょう」ということです。データの数をわざと減らすことで得られるのでその分改善する余地がなくなります。
じゃあ全く局所化はいらないかというと、そうもならなくて実際に局所化しないとあちらこちらのエラーでサンプリングノイズがかなり乗ってしまいます。疑似相関なんかといいますが、サンプル数が少ないときによく起きます。処理能力が有り余っている今なら、もっと良いアルゴリズムやより良いデータ同化手法が見つかるかもしれませんね。

また、個人的には同化に与える解析値を複数持つアンサンブルカルマンフィルタの考え方はいいと思いますが、モデルの誤差などを考えれば複数のモデルを深層学習などで得られた結果と既存のモデルを組み合わせるなどしても良い結果が得られるかもなんて勝手に考えています。
大規模コンピューティングが大好きなので、そこら辺も少しずつ研究できればなと思います。

直近でやりたいこと

共分散膨張の推定,手動チューニングはもう嫌じゃ。
少し規模を上げたモデルへの適応も考えたいところですね。そのためにはもう少し高速化もしたいところ。

#追記

共分散膨張

共分散膨張は、分散を少し大きくしてスプレッドと誤差を近くさせる操作です。少なくともスプレッドが大きい分には観測を多くとるので誤差が大きくなるで済みますが、小さくなるとより多くの解析誤差をとってしまい、フィルタの発散につながります。
あともう一つ大事なことがあって、次のグラフによく表れているので見てみてください。LETKF、観測10個の1/4まで間引いたときの1年間のデータ同化結果のRMSEと共分散膨張のグラフです。

image.png

ここでいうLowestでは、確かにRMSEで測る誤差は最小ですが、実用的ではありません。ここでは1年くらいの時間積分ですが、これは長くとれば発散すると考えられます。

理由はこの共分散の次に高いピークがあって、これは部分的に同化に失敗していることを示しています。共分散膨張を大きくすればするほどより多くの観測を取り込むようになり、フィルターが安定して動作するようになります。
ただ、共分散膨張の係数を上げると誤差が大きくなる傾向にあります。観測を多くとりこみますが、これには大きな誤差が含まれているためです。
共分散膨張の動的推定は一つ重要なファクターなのでそれについても追記したいなと思います。

動的な膨張の実装

そもそもアンサンブルカルマンフィルタにおける共分散膨張というものはどんなものでどのように導入されるのでしょうか? 共分散膨張は、アンサンブルカルマンフィルタでは、スプレッド(粒子の広がり)と解析誤差を一致させるように大きくする操作で、スプレッドが小さく解析誤差が過少に評価されることでフィルターが発散してしまうことを防ぐことが目的です。
image.png

そこで次の式を見てみます。
image.png
この式の$d$は、観測と予報の差を示しています。
これが予報誤差と観測誤差の和と一致していれば、予報の広がりが正しく評価されていると判断できます。
通常はスプレッドは過少評価されがちなので、$P^b$に$\delta$をかけてこの等式が成り立つようにします。
この$\delta$にあたる部分が共分散膨張の重要なパラメータで今、動的に決定したいものになります。

$\delta$について解けば式は求められますが、このままではノイズがすごく使い物にならないので、さらにフィルタリングすることでパラメータを推定します。
image.png
フィルターをかけなかったとき。オレンジの線が途切れていることがわかる。
image.png
フィルターをかけて少しおとなしくなった例。定数をチューニングしたときに近いRMSEを出すことができます。

参考文献

https://qiita.com/litharge3141/items/41b8dd3104413529407f
https://qiita.com/litharge3141/items/7c1c879240d6c9d46166
https://www2.atmos.umd.edu/~ekalnay/pubs/harlim_hunt05.pdf
その他、講義資料も。

追記の動的な共分散膨張の推定の参考資料
https://journals.ametsoc.org/doi/10.1175/2010MWR3570.1
https://rmets.onlinelibrary.wiley.com/doi/epdf/10.1002/qj.371
その他もろもろ。

ソースコード

ソースコード全部公開します。苦情が出たらたぶん消します。

import time
import numpy as np
import numpy.matlib
import matplotlib.pyplot as plt
import pandas as pd
import sys
import json
import numba as nb
import Bmat
import os
import scipy


F=8
J=40
ADAY=0.2
AYEAR=365
YEARS_TO_SIMULATE=2
dt=0.05
t_spinup=ADAY*AYEAR*1
tend=AYEAR*ADAY*YEARS_TO_SIMULATE
t=np.arange(0,tend+dt,dt)
ft=np.arange(ADAY*AYEAR,tend+dt,dt)

x=np.zeros(J)
v=np.zeros(J)
RMSE=[]
xgtlist=[]
xobzlist=[]
STEP=2
FIRSTN=40
delta=0.1
H=np.identity(J)[:FIRSTN:STEP]
R=np.identity(int(FIRSTN/STEP))

printB=False

#use 5 points formula
# @nb.jit
# def differential(x,h,hr,dt):
#    return (1.0/hr)*(
#        -(1./12.)*runge4th(x+2*h,dt)
#        +(2./3.)*runge4th(x+h,dt)
#        -(2./3.)*runge4th(x-h,dt)
#        +(1./12.)*runge4th(x-2*h,dt)
#         )

@nb.jit
def differential(x,h,hr,dt):
   return (runge4th(x+h,dt)-runge4th(x,dt))/hr

@nb.jit
def invnumbajit(A):
    return np.linalg.inv(A)
@nb.jit
def solveacc(x):
    a=np.empty(J)
    for j in range(J):
        a[j]=(x[(j+1)%J]-x[(j-2+J)%J])*x[(j-1+J)%J]-x[j]+F
    return a


@nb.jit
def euler(x,dt):
    a=solveacc(x)
    return x+a*dt

@nb.jit
def runge4th(x,dt):
    a=solveacc(x)
    k=np.empty((5,J))
    k[1]=a*dt
    a=solveacc(x+k[1]*0.5)
    
    k[2]=a*dt
    a=solveacc(x+k[2]*0.5)

    k[3]=a*dt
    a=solveacc(x+k[3])
    
    k[4] = a*dt
    k[0]=(k[1]+2*k[2]+2*k[3]+k[4])/6.
    return x+k[0]


def initialize():
    global x
    x=np.ones(J)*F
    x[int(J/2)+1]*=1.001

def rmse(x):
    return np.sqrt(np.average(
        np.square(x)
    ))

def evaluate_error():
    rmsevlist=np.zeros(100)
    for j in range(10000):
        x=xgtlist[1000]+np.random.randn(J)/10000
        for i,xnow in enumerate(xgtlist[1001:1100]):
            x=runge4th(x,dt)
            diff=x-xnow
            rmsev=rmse(diff)
            rmsevlist[i]+=(rmsev)
    
    plt.title("RMSE")
    plt.yscale('log')
    plt.plot(t[:99],rmsevlist[:99]/rmsevlist[0])
    plt.show()

def runsimulation(spinup):
    global x
    initialize()
    for i in range(spinup):
        x=runge4th(x,dt) #Update x

    for i,tnow in enumerate(t):
        x=runge4th(x,dt) #Update x
        xgtlist.append(np.copy(x))

def loaddata(path):
    global F,J,ADAY,AYEAR,YEARS_TO_SIMULATE,dt,t_spinup,tend,xobzlist,xgtlist
    observation_path=path+"_obz.npy"
    groundtruth_path=path+"_gt.npy"
# load parameter
    with open(path+"cfg.json") as f:
        df = json.load(f)
        F,J,ADAY,AYEAR=df["F"],df["J"],df["ADAY"],df["AYEAR"]
        YEARS_TO_SIMULATE,dt,t_spinup,tend=df["YEARS_TO_SIMULATE"],df["dt"],df["t_spinup"],df["tend"]
        observation_path,groundtruth_path=df["observation_filepath"],df["groundtruth_filepath"]

# load experiment data
    xobzlist=np.load(observation_path)
    xgtlist=np.load(groundtruth_path)

def storedata(path):
    observation_path=path+"_obz.npy"
    groundtruth_path=path+"_gt.npy"
    param={"F":F,
           "J":J,
           "ADAY":ADAY,
           "AYEAR":AYEAR,
           "YEARS_TO_SIMULATE":YEARS_TO_SIMULATE,
           "dt":dt,
           "t_spinup":t_spinup,
           "tend":tend,
           "observation_filepath":observation_path,
           "groundtruth_filepath":groundtruth_path,
    }

    with open(path+"cfg.json","w") as f:
        json.dump(param,f)

    np.save(groundtruth_path,np.array(xgtlist))
    for i in xgtlist:
        xobzlist.append(i+np.random.randn(len(i)))
    np.save(observation_path,np.array(xobzlist))

def showmat(m):
   plt.imshow(m,interpolation='nearest',cmap='jet')
   plt.colorbar()
   plt.show()

def localizescale(sigma,i):
    r=abs(i)
    r=min(r,abs(J-r))
    return np.exp(-r*r/(2*sigma*sigma))
    

def localizematrix(sigma,H):
    m=[]
    for i in H:
        idx=(np.where(i!=0)[0][0])
        l=[]
        for j in range(J):
            l.append(localizescale(sigma,j-idx))
        m.append(np.array(l))
    return np.array(m)

def calcinvR(sigma,idx):
    nobs=H.shape[0]
    mat=np.identity(nobs)
    for idx2,j in enumerate(mat):
        idx2=(np.where(H[idx2]!=0)[0][0])
        j*=(localizescale(sigma,idx-idx2))
    return mat
    
def printmat(B):
    for i in B:
        s="["
        for j in i:
            s+=("%e, "%j)
        print(s+"],")


def observation_densitymatrix_turn(n):
    delnum = J-n
    dellist=np.linspace(0,J,delnum+2) #delete grid point list. delete except first and last point.
    matt = list(np.identity(J))
    for idx,i in enumerate(dellist[1:-1]):
        matt.pop(int(i-idx))
    return np.array(matt)

def observation_densitymatrix_dense(n):
    return np.identity(J)[:n]

plotlist=[[],[],[]]
def threedvar(scale,odm):
    B=np.array(Bmat.B)*scale
    xa=xobzlist[0]
    xf=xgtlist
    IMAT=invnumbajit(np.dot(np.dot(H,B),H.T)+R)
    K=np.dot(np.dot(B,H.T),IMAT)
    for idx,i in enumerate(ft):
        xf=runge4th(xa,dt)
        
        xa = xf + np.dot(K,np.dot(odm,xobzlist[idx])-np.dot(H,xf))
        plotlist[0].append(rmse(xa-xgtlist[idx]))

    return np.mean(plotlist[0][500:])


def KalmanFilter(delta,odm):
    Pflist=[]
    Pa = np.identity(J)*5
#    print(xgtlist[np.random.randint(int(ADAY*400/dt),int(tend/dt))])
    xa=xobzlist[0]
    xf=xgtlist

    JM=np.empty((J,J))
    for idx,i in enumerate(ft):
        xf=runge4th(xa,dt)

        for idx2,j in enumerate(np.identity(J)):
            JM[idx2]=differential(xa,j*0.001,0.001,dt)
        Pf=(1+delta)*np.dot(np.dot(JM.T,Pa),JM)
        
        K=np.dot(np.dot(Pf,H.T),invnumbajit(np.dot(np.dot(H,Pf),H.T)+R))
        xa = xf + np.dot(K,np.dot(odm,xobzlist[idx])-np.dot(H,xf))
        Pa = np.dot((np.identity(J)-np.dot(K,H)),Pf)
        Pflist.append(np.copy(Pf))

        plotlist[0].append(rmse(xa-xgtlist[idx]))
        plotlist[1].append(rmse(xgtlist[idx]-xobzlist[idx]))
        plotlist[2].append(np.sqrt(np.trace(Pa)/J))

    average=Pflist[0]
    for i in Pflist[-200:]:
        average+=i
    average /= 200

    if printB==True:
        printmat(average)

    return np.mean(plotlist[0][500:])


def LETKF(delta,odm,nmem=12):
    mask=localizematrix(3,H)
    Pa = np.identity(J)*5
#    print(xgtlist[np.random.randint(int(ADAY*400/dt),int(tend/dt))])
    xa=np.array([xobzlist[0] for i in range(nmem)])+np.random.normal(0,1,(nmem,J))
    xanext=np.empty(nmem)
    plotlist=[[],[],[]]
    Palst=[]
    JM=np.empty((J,J))
    sdelta=np.sqrt(1+delta)
    for idx,i in enumerate(ft):
        xf=np.array([runge4th(i,dt) for i in xa])
        xfa=np.mean(xf,axis=0)
        dxf=np.array([i-xfa for i in xf])*sdelta
        yfa=np.mean(np.dot(H,xf.T),axis=1)
        dyf=np.array([np.dot(H,i)-yfa for i in xf])*sdelta
        xanext=[]
        for j in range(J):
            invR=calcinvR(3,j).T
            C=np.dot(dyf,invR)
            w,v=np.linalg.eig(np.identity(nmem)*(nmem-1)+np.dot(C,dyf.T))
            w=np.real(w)
            v=np.real(v)
            p_invsq=numpy.diag(1/np.sqrt(w))
            p_inv=numpy.diag(1/w)
            Wa=v @ p_invsq @ v.T
            Was=v @ p_inv @ v.T
            
            xanext.append((np.matlib.repmat(xfa,nmem,1)+np.dot(dxf.T,
                     np.linalg.multi_dot([Was,C,
                                          np.linalg.multi_dot([H,np.matlib.repmat(xobzlist[idx]-xfa,nmem,1).T])
                     ])+np.sqrt(nmem-1)*Wa).T).T[j])
        xa=np.array(xanext).T
        rmsev=rmse(np.mean(xa,axis=0)-xgtlist[idx])
        if idx&0x1F==0:
            sys.stdout.write(".")
            sys.stdout.flush()
        plotlist[0].append(rmsev)
        Pa=np.dot(dxf.T,dxf)/(nmem-1)
        Palst.append(Pa)
        plotlist[1].append(np.sqrt(np.trace(Pa)/J))
    
    average=Palst[0]
    for i in Palst[-200:]:
        average+=i
    average /= 200

#    showmat(mask*average)
    
    # plt.plot(plotlist[0],label="rmse")
    # plt.plot(plotlist[1],label="trace pa")
    # plt.legend()
    # plt.show()
    return np.mean(plotlist[0][500:])
    
def PO(delta,odm,nmem=16):
    mask=localizematrix(3,H)
    Pa = np.identity(J)*5
#    print(xgtlist[np.random.randint(int(ADAY*400/dt),int(tend/dt))])
    xa=np.array([xobzlist[0] for i in range(nmem)])+np.random.normal(0,1,(nmem,J))
    plotlist=[[],[],[]]
    Palst=[]
    JM=np.empty((J,J))
    sdelta=1#delta
    for idx,i in enumerate(ft):
        xf=np.array([runge4th(i,dt) for i in xa])
        xfa=np.mean(xf,axis=0)
        dxf=np.array([i-xfa for i in xf])*sdelta
        yfa=np.mean(np.dot(H,xf.T),axis=1)
        dyf=np.array([np.dot(H,i)-yfa for i in xf])*sdelta
        K=mask.T*np.linalg.multi_dot([dxf.T,invnumbajit(np.identity(nmem)+np.linalg.multi_dot([dyf,R,dyf.T])),dyf,R])/np.sqrt(nmem-1)
        t=np.dot(H,np.array(np.matlib.repmat(xobzlist[idx],nmem,1)+np.random.normal(0,1,(nmem,J))*(1+delta)-xf).T)
        xa=xf+np.dot(K,t).T
        plotlist[0].append(rmse(np.mean(xa,axis=0)-xgtlist[idx]))
        Pa=np.dot(dxf.T,dxf)/(nmem-1)
        Palst.append(Pa)
        plotlist[1].append(np.sqrt(np.trace(Pa)/J))
    
    average=Palst[0]
    for i in Palst[-200:]:
        average+=i
    average /= 200

    # showmat(mask*average)
    
    # plt.plot(plotlist[0],label="rmse")
    # plt.plot(plotlist[1],label="trace pa")
    # plt.legend()
    # plt.show()
    return np.mean(plotlist[0][500:])

def main():
    global x,H,R,printB
    plt.grid(which = "major", axis = "x", color = "blue", alpha = 0.8, linestyle = "--", linewidth = 1)
    plt.grid(which = "major", axis = "y", color = "green", alpha = 0.8, linestyle = "--", linewidth = 1)
    plt.xlabel('observation data count')
    plt.ylabel('RMSE')

# if you want simulation data, uncomment and run simulation.
    # print("running simulation"+str(time.time()-begintime))
    # runsimulation(int(AYEAR*ADAY/dt))
    # print("storing data"+str(time.time()-begintime))
    # storedata("data/result")
    loaddata("data/result")

    rmselst=[]
    nlist=np.arange(10,40+1,1)
    deltaset=np.arange(0.04,0.4,0.01)
    deltaset2=np.arange(0.,0.2,0.01)
    scaleset=np.arange(0.1,2,0.2)
    obzspace=(observation_densitymatrix_turn,observation_densitymatrix_dense)
    filters=(LETKF,PO,KalmanFilter,threedvar)
    for densmat in obzspace:
        print("dens:%s"%densmat.__name__)
        for fil in filters:
            print("filters:%s"%fil.__name__)
            if fil == threedvar:
                params=scaleset
            elif fil == LETKF:
                params=deltaset2
            else:
                params=deltaset
            
            rmseminlst=[]
            for n in nlist:
                rmselst=[]
                print("obzdata :%d"%n)
                matt=densmat(n)
                R=np.identity(n)
                H=matt
                #try finding minimum rmse for each deltaset
                for i in params:
                    if True: # any error to be caught and fail if you set False(debug only).
                        try:
                            rmse=fil(i,matt)
                        except:
                            # 999 shows error
                            rmse=999
                    else:
                        rmse=fil(i,matt)
                    pair=(i,rmse)
                    rmselst.append(rmse)
                    print("param:%3f,rmse:%3f"%pair)
                    plotlist[0],plotlist[1],plotlist[2]=[],[],[]

                print("finished in "+str(time.time()-begintime))
                rmseminlst.append(min(rmselst))

            plt.plot(nlist,rmseminlst,label="%s"%fil.__name__)

        plt.grid(which = "major", axis = "x", color = "blue", alpha = 0.8, linestyle = "--", linewidth = 1)
        plt.grid(which = "major", axis = "y", color = "green", alpha = 0.8, linestyle = "--", linewidth = 1)
        plt.xlabel('observation data count')
        plt.ylabel('RMSE')
        plt.title("minimum RMSE")
        plt.ylim(0,2)
        plt.legend()
        plt.show()


if __name__ == '__main__':
    main()
24
22
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
24
22

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?