22
23

Delete article

Deleted articles cannot be recovered.

The draft of this article will also be deleted.

Are you sure you want to delete this article?

More than 5 years have passed since the last update.

tensorflowでカルマンフィルタ

Posted at 2018-03-07

Tensorflowでカルマンフィルタを実装してみた.

  • 計算グラフでループ1回分(1時刻分の予測と更新)を実装
  • 各時刻でsession.runを呼び出す
  • 状態(平均と共分散)はtensorflowの内部変数(tf.Variable)として保持しておく
  • Tips:
      ◦ tf.assignで内部変数に保存する
      ◦ tf.assignの返り値のオペレーション(op)をsess.runで実行する

numpyバージョン vs. tensorflowバージョンの比較.

コード

kalman.py
import numpy as np
import matplotlib.pyplot as plt
import time
import tensorflow as tf


def KalmanFilter_tf(T, x_obs, mu0, V0, A, b, Q, R):
    """
    Kalman filter implemented as a TensorFlow 1.x computation graph.

    One filter step (predict + update) is built once as a graph.  The
    filtered mean and covariance live in ``tf.Variable`` objects and are
    overwritten via ``tf.assign`` on every ``sess.run`` call, so the state
    is carried inside TensorFlow between time steps.

    Model (identity observation matrix, same as ``KalmanFilter``):
        x[t] = A x[t-1] + b + w,  w ~ N(0, Q)
        y[t] = x[t] + v,          v ~ N(0, R)

    Args:
        T: number of time steps. Only x_obs[1..T-1] are read; step 0 is
            initialised from mu0 / V0.
        x_obs: observations; x_obs[t] must reshape to (d, 1).
        mu0: initial mean, shape (d,). d is taken from its length.
        V0: initial covariance, shape (d, d).
        A: state transition matrix, (d, d).
        b: constant state offset, (d,).
        Q: process-noise covariance, (d, d).
        R: observation-noise covariance, (d, d).

    Returns:
        (x_predict, V_predict, x_filter, V_filter): predicted and filtered
        means with shape (T, d), and covariances with shape (T, d, d).
        Index 0 holds mu0 / V0.  Values are computed in float32.
    """
    mu0 = np.asarray(mu0, dtype=float)
    V0 = np.asarray(V0, dtype=float)
    dim = mu0.shape[0]

    x_predict = np.zeros((T, dim))
    x_filter = np.zeros((T, dim))
    V_predict = np.zeros((T, dim, dim))
    V_filter = np.zeros((T, dim, dim))

    if T > 0:
        x_predict[0] = mu0
        x_filter[0] = mu0
        V_predict[0] = V0
        V_filter[0] = V0

    g = tf.Graph()

    with g.as_default():  # build the graph for a single filter step

        obs_ = tf.placeholder(tf.float32, name="obs", shape=(dim, 1))

        A_ = tf.placeholder(tf.float32, name="A", shape=(dim, dim))
        b_ = tf.placeholder(tf.float32, name="b", shape=(dim, 1))
        Q_ = tf.placeholder(tf.float32, name="Q", shape=(dim, dim))
        R_ = tf.placeholder(tf.float32, name="R", shape=(dim, dim))

        mu0_ = tf.placeholder(tf.float32, name="mu0", shape=(dim, 1))
        V0_ = tf.placeholder(tf.float32, name="V0", shape=(dim, dim))

        mu = tf.Variable(tf.zeros((dim, 1)), dtype=tf.float32, name="mu")
        V = tf.Variable(tf.zeros((dim, dim)), dtype=tf.float32, name="V")

        mu0_init = tf.assign(mu, mu0_)
        V0_init = tf.assign(V, V0_)

        # Predict step.
        mu_pred = A_ @ mu + b_
        V_pred = A_ @ V @ tf.transpose(A_) + Q_

        # Update step (observation matrix is the identity).
        S = V_pred + R_
        K = V_pred @ tf.matrix_inverse(S)

        # Write the filtered estimate back into the variables so the next
        # sess.run starts from it.
        mu_op = tf.assign(mu, mu_pred + K @ (obs_ - mu_pred))
        V_op = tf.assign(V, V_pred - K @ V_pred)

    with tf.Session(graph=g) as sess:

        # Initialise the state variables from the prior.
        sess.run([mu0_init, V0_init],
                 feed_dict={mu0_: mu0.reshape((dim, 1)), V0_: V0})

        # Model parameters are constant over time; only the observation
        # entry changes per step.
        feed = {A_: A, b_: np.asarray(b).reshape((dim, 1)), Q_: Q, R_: R}

        start = time.time()

        for t in range(1, T):

            feed[obs_] = np.asarray(x_obs[t]).reshape((dim, 1))
            # mu_pred / V_pred are inputs of the assign ops, so the fetched
            # values are the predictions computed from the pre-update state.
            m_pred, v_pred, m, v = sess.run(
                [mu_pred, V_pred, mu_op, V_op], feed_dict=feed)

            x_predict[t], V_predict[t] = m_pred.ravel(), v_pred
            x_filter[t], V_filter[t] = m.ravel(), v

        elapsed_time = time.time() - start
        print("tensorflow: ", elapsed_time)

    return x_predict, V_predict, x_filter, V_filter


def KalmanFilter(T, x_obs, mu0, V0, A, b, Q, R):
    """
    Kalman filter implemented with plain NumPy.

    Model (identity observation matrix):
        x[t] = A x[t-1] + b + w,  w ~ N(0, Q)
        y[t] = x[t] + v,          v ~ N(0, R)

    The state dimension d is taken from ``mu0``, so states of any
    dimension are supported.

    Args:
        T: number of time steps. Only x_obs[1..T-1] are read; step 0 is
            initialised from mu0 / V0.
        x_obs: observations; x_obs[t] has shape (d,).
        mu0: initial mean, shape (d,) (array, list or tuple).
        V0: initial covariance, shape (d, d).
        A: state transition matrix, (d, d).
        b: constant state offset, (d,).
        Q: process-noise covariance, (d, d).
        R: observation-noise covariance, (d, d).

    Returns:
        (x_predict, V_predict, x_filter, V_filter): predicted and filtered
        means with shape (T, d), and covariances with shape (T, d, d).
        Index 0 holds mu0 / V0.
    """

    mu = np.asarray(mu0, dtype=float)
    V = np.asarray(V0, dtype=float)
    dim = mu.shape[0]

    x_predict = np.zeros((T, dim))
    x_filter = np.zeros((T, dim))
    V_predict = np.zeros((T, dim, dim))
    V_filter = np.zeros((T, dim, dim))

    # Guard so that T == 0 returns empty arrays instead of raising.
    if T > 0:
        x_predict[0] = mu
        x_filter[0] = mu
        V_predict[0] = V
        V_filter[0] = V

    start = time.time()

    for t in range(1, T):

        # Predict step.
        mu_ = A @ mu + b
        V_ = A @ V @ A.transpose() + Q

        x_predict[t] = mu_
        V_predict[t] = V_

        # Update step (observation matrix is the identity).
        S = V_ + R
        K = V_ @ np.linalg.inv(S)

        mu = mu_ + K @ (x_obs[t] - mu_)
        V = V_ - K @ V_

        x_filter[t] = mu
        V_filter[t] = V

    elapsed_time = time.time() - start
    print("numpy:      ", elapsed_time)

    return x_predict, V_predict, x_filter, V_filter


def main():
    """
    Generate a synthetic 2-D linear-Gaussian trajectory, filter it with
    both the NumPy and the TensorFlow Kalman filter, and plot the result.
    """

    # --- Data generation ---
    # The filters assume
    #     x[t] = A x[t-1] + b + w[t],  w ~ N(0, Q)   (latent state)
    #     y[t] = x[t] + v[t],          v ~ N(0, R)   (observation)
    # Observation noise must therefore NOT be propagated through A, so the
    # latent state is tracked separately from the observations.
    T = 20
    mu0 = np.array([100, 100])
    V0 = np.array([[10, 0],
                   [0, 10]])

    A = np.array([[1.001, 0.001],
                  [0, 0.99]])
    b = np.array([5, 10])
    Q = np.array([[20, 0],
                  [0, 20]])
    R = np.array([[20, 0],
                  [0, 20]])

    rvq = np.random.multivariate_normal(np.zeros(2), Q, T)
    rvr = np.random.multivariate_normal(np.zeros(2), R, T)
    x_true = np.zeros((T, 2))
    obs = np.zeros((T, 2))
    # Start exactly at mu0; both filters start from mu0 / V0 and never
    # read obs[0].
    x_true[0] = mu0
    obs[0] = mu0
    for i in range(1, T):
        x_true[i] = A @ x_true[i - 1] + b + rvq[i]
        obs[i] = x_true[i] + rvr[i]
    # --- end of data generation ---

    _, _, x_filter, _ = KalmanFilter(T, obs, mu0, V0, A, b, Q, R)
    print(x_filter)

    _, _, x_filter_tf, _ = KalmanFilter_tf(T, obs, mu0, V0, A, b, Q, R)
    print(x_filter_tf)

    fig = plt.figure(figsize=(16, 9))
    ax = fig.gca()

    ax.plot(x_true[:, 0], x_true[:, 1],
            alpha=0.5, lw=1, color='b', linestyle='--', label="true state")

    ax.scatter(obs[:, 0], obs[:, 1],
               s=10, alpha=1, marker="o", color='w', edgecolor='k',
               label="observation")
    ax.plot(obs[:, 0], obs[:, 1],
            alpha=0.5, lw=1, color='k')

    ax.scatter(x_filter[:, 0], x_filter[:, 1],
               s=10, alpha=1, marker="o", color='r', label="numpy filter")
    ax.plot(x_filter[:, 0], x_filter[:, 1],
            alpha=0.5, lw=1, color='r')

    ax.scatter(x_filter_tf[:, 0], x_filter_tf[:, 1],
               s=10, alpha=1, marker="o", color='m',
               label="tensorflow filter")
    ax.plot(x_filter_tf[:, 0], x_filter_tf[:, 1],
            alpha=0.5, lw=1, color='m')

    ax.legend()
    plt.show()


if __name__ == "__main__":
    main()

結果

on Mac

当然,それほど速くない.2x2の小さな行列演算を1ステップごとにsess.runで呼び出すため,計算そのものよりも呼び出しのオーバーヘッドが支配的になると考えられる.(on macOS, CPU tf 1.4.1, Anaconda Python 3.6)

numpy:       0.0004937648773193359
tensorflow:  0.00774383544921875

on Ubuntu

GPUならもっと遅い.小さな行列を毎ステップGPUとやり取りするため,データ転送やカーネル起動のオーバーヘッドが支配的になると考えられる.(Ubuntu 16.04, GPU tf 1.6.0, GeForce GTX 1080 Ti, Anaconda Python 3.6)

numpy:       0.002942800521850586
(...tensorflow messages...)
tensorflow:  0.6563735008239746

(Ubuntu 16.04, CPU tf 1.6.0, anaconda python 3.6)

numpy:       0.002834796905517578
tensorflow:  0.027823209762573242
22
23
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do by signing up
22
23

Delete article

Deleted articles cannot be recovered.

The draft of this article will also be deleted.

Are you sure you want to delete this article?