7
5

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

kerasでカルマンフィルタ

Last updated at Posted at 2018-03-12

下記の記事が面白かったので、参考にしつつKeras版を実装してみました。

  1. https://qiita.com/tttamaki/items/3bd1fdab5ef5bbf0d4be
  2. http://s0sem0y.hatenablog.com/entry/2018/03/08/131445

K.rnnを使って書きましたが、逆行列はtensorflowの関数を使いました。
K.rnnはバッチ次元を前提としているため少しややこしいですが、個人的にはtensorflowよりは読み書きしやすいと思います。今後ニューラルネットと組み合わせることも意識できますのでこれで良しとします。
tensorflowであればtf.nn.dynamic_rnnを使えば同じ要領で書けるはずです。

表示・評価部分は参照記事1と同じですので、関数部分のみコード記載します:


import keras
from keras import backend as K
from keras.layers import Input
from keras.models import Model

def KalmanFilter_keras(T, x_obs, mu0, V0, A, b, Q, R):
    """
    kerasで実装したカルマンフィルタ.
    """
    x_predict = np.zeros((T, 2))
    x_filter = np.zeros((T, 2))
    V_predict = np.zeros((T, 2, 2))
    V_filter = np.zeros((T, 2, 2))

    x_predict[0] = mu0
    x_filter[0] = mu0
    V_predict[0] = V0
    V_filter[0] = V0

    class KalmanLayer(keras.layers.Layer):
        def __init__(self, dim, A, b, Q, R, V0, **kwargs):
            super(KalmanLayer, self).__init__(**kwargs)
            self.dim = dim
            self.A = K.constant(A)
            self.b = K.constant(b)
            self.Q = K.constant(Q)
            self.R = K.constant(R)
            self.V0 = K.constant(V0.reshape((1,dim,dim))) # バッチ次元を加えておく

        def call(self, x):
            last_output, outputs, states = K.rnn(self.step, x[:,1:], initial_states=[x[:,0],self.V0]) # はじめの観測値を初期値、それ以降の観測値を入力値とする
            return outputs

        def compute_output_shape(self, input_shape):
            return (input_shape[0],self.dim)

        def step(self, inputs, states):
            x_obs = inputs
            mu, V = states

            mu_ = K.dot(mu,K.transpose(self.A)) + self.b # 転置のまま計算 shape: (batch,d)
            VAt = K.permute_dimensions(K.dot(V,K.transpose(self.A)),(1,2,0)) # shape: (d,d,batch)
            V_ = K.permute_dimensions(K.dot(self.A,VAt),(2,0,1)) + self.Q # shape: (batch,d,d)

            S = V_ + self.R
            K_ = K.batch_dot( V_, tf.matrix_inverse(S) )

            mu = mu_ + K.batch_dot(K_, x_obs-mu_)
            V = V_ - K.batch_dot(K_,V_)
            
            ret = K.concatenate( [ K.batch_flatten(f) for f in [mu_,V_,mu,V] ] ) # K.rnnはリスト出力ができないため、flatten -> concat して出力する
            new_states = [ mu, V ]

            return ret, new_states

    dim = 2
    x = Input(shape=(None,dim))
    y = KalmanLayer(dim,A,b,Q,R,V0)(x)
    model = Model(inputs=[x],outputs=[y])

    start = time.time()

    ret = model.predict( x_obs.reshape((1,-1,dim)) ) # バッチサイズ1として計算
    x_predict[1:] = ret[0,:,0:2]
    V_predict[1:] = ret[0,:,2:6].reshape((-1,2,2))
    x_filter[1:]  = ret[0,:,6:8]
    V_filter[1:]  = ret[0,:,8:12].reshape((-1,2,2))

    elapsed_time = time.time() - start
    print("keras:      ", elapsed_time)
    
    return x_predict, V_predict, x_filter, V_filter

実行時間です(Ubuntu CPU実行時)

numpy:       0.001127481460571289
(省略)
tensorflow:  0.008805990219116211
keras:       0.018843889236450195
7
5
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
7
5

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?