More than 5 years have passed since last update.


Posted at 2018-03-12 (last-updated date not shown on the extracted page)


import keras
from keras import backend as K
from keras.layers import Input
from keras.models import Model

def KalmanFilter_keras(T, x_obs, mu0, V0, A, b, Q, R):
    """Run a linear-Gaussian Kalman filter implemented as a Keras RNN layer.

    State model:       x_t = A x_{t-1} + b + noise with covariance Q
    Observation model: y_t = x_t + noise with covariance R
    (the observation matrix is the identity: the innovation covariance in
    ``step`` is ``V_ + R``).

    Args:
        T: number of time steps to return (index 0 holds the initial state).
        x_obs: observations, reshaped to rows of length ``dim``; only the first
            ``T`` rows are used, and row 0 is not used as an observation.
        mu0: initial state mean (broadcastable to ``(dim,)``).
        V0: initial state covariance with ``dim * dim`` elements.
        A: square ``(dim, dim)`` transition matrix; it fixes ``dim``.
        b: state offset added after applying ``A``.
        Q: process-noise covariance.
        R: observation-noise covariance.

    Returns:
        ``(x_predict, V_predict, x_filter, V_filter)`` with shapes
        ``(T, dim)``, ``(T, dim, dim)``, ``(T, dim)``, ``(T, dim, dim)``.
        Index 0 of every array holds ``mu0`` / ``V0``.
    """
    import time

    import numpy as np

    # The state dimension is fixed by the square transition matrix, so the
    # function is not limited to 2-D states (dim == 2 behaves as before).
    dim = np.shape(A)[0]

    x_predict = np.zeros((T, dim))
    x_filter = np.zeros((T, dim))
    V_predict = np.zeros((T, dim, dim))
    V_filter = np.zeros((T, dim, dim))

    x_predict[0] = mu0
    x_filter[0] = mu0
    V_predict[0] = V0
    V_filter[0] = V0

    # With a single time step there is nothing to filter, and K.rnn cannot
    # iterate over an empty input sequence, so return the initial values.
    if T <= 1:
        return x_predict, V_predict, x_filter, V_filter

    class KalmanLayer(keras.layers.Layer):
        """Kalman filter recursion over a sequence, built on ``K.rnn``.

        Input ``(batch, steps, dim)``: row 0 of each sequence is the initial
        state mean; rows 1.. are the observations.
        Output ``(batch, steps - 1, 2*dim + 2*dim*dim)``: per step the
        flattened [predicted mean, predicted cov, filtered mean, filtered cov].
        """

        def __init__(self, dim, A, b, Q, R, V0, **kwargs):
            super(KalmanLayer, self).__init__(**kwargs)
            self.dim = dim
            self.A = K.constant(A)
            self.b = K.constant(b)
            self.Q = K.constant(Q)
            self.R = K.constant(R)
            # Add a leading batch axis so V0 matches the (batch, d, d) state.
            self.V0 = K.constant(np.reshape(V0, (1, dim, dim)))

        def call(self, x):
            # The first row seeds the state mean; the remaining rows are fed
            # one step at a time as observations.
            _, outputs, _ = K.rnn(
                self.step, x[:, 1:], initial_states=[x[:, 0], self.V0]
            )
            return outputs

        def compute_output_shape(self, input_shape):
            # One output per observation step (row 0 is consumed as state),
            # each packing two means and two covariances.
            steps = None if input_shape[1] is None else input_shape[1] - 1
            return (input_shape[0], steps, 2 * self.dim + 2 * self.dim * self.dim)

        def step(self, inputs, states):
            obs = inputs
            mu, V = states

            # Predict: mu_ = A mu + b (computed in row-vector form), V_ = A V A^T + Q.
            mu_ = K.dot(mu, K.transpose(self.A)) + self.b  # shape: (batch, d)
            VAt = K.permute_dimensions(
                K.dot(V, K.transpose(self.A)), (1, 2, 0)
            )  # shape: (d, d, batch)
            V_ = K.permute_dimensions(K.dot(self.A, VAt), (2, 0, 1)) + self.Q  # (batch, d, d)

            # Update with an identity observation matrix: S = V_ + R, gain = V_ S^-1.
            S = V_ + self.R
            # NOTE(review): `tf` is not imported in the visible file; this
            # assumes the TensorFlow module is bound to `tf` elsewhere.
            # tf.matrix_inverse is the TF1 name (tf.linalg.inv in newer TF).
            gain = K.batch_dot(V_, tf.matrix_inverse(S))

            mu = mu_ + K.batch_dot(gain, obs - mu_)
            V = V_ - K.batch_dot(gain, V_)
            # K.rnn cannot emit a list of outputs per step, so flatten each
            # quantity and concatenate them into a single vector.
            ret = K.concatenate([K.batch_flatten(f) for f in [mu_, V_, mu, V]])
            return ret, [mu, V]

    model_input = Input(shape=(None, dim))
    model_output = KalmanLayer(dim, A, b, Q, R, V0)(model_input)
    model = Model(inputs=[model_input], outputs=[model_output])

    # Use only the first T observations. Row 0 is replaced by mu0 so the
    # recursion starts from the same state that is reported as x_filter[0]
    # (previously x_obs[0] was used here and mu0 was silently ignored).
    seq = np.array(x_obs, dtype=float).reshape((-1, dim))[:T]
    seq[0] = mu0

    start = time.time()

    ret = model.predict(seq[np.newaxis])  # computed with batch size 1
    d2 = dim * dim
    x_predict[1:] = ret[0, :, :dim]
    V_predict[1:] = ret[0, :, dim:dim + d2].reshape((-1, dim, dim))
    x_filter[1:] = ret[0, :, dim + d2:2 * dim + d2]
    V_filter[1:] = ret[0, :, 2 * dim + d2:].reshape((-1, dim, dim))

    elapsed_time = time.time() - start
    print("keras:      ", elapsed_time)
    return x_predict, V_predict, x_filter, V_filter

実行時間（秒）です（Ubuntu、CPU 実行時）:

numpy:       0.001127481460571289
tensorflow:  0.008805990219116211
keras:       0.018843889236450195

