23

More than 5 years have passed since last update.

# tensorflowでカルマンフィルタ

Last updated at Posted at 2018-03-07

Tensorflowでカルマンフィルタを実装してみた．

• 計算グラフで1ループを実装
• session.runで呼び出し
• tensorflowの内部で変数を覚えておく
• Tips:
  • assignで内部変数に保存する
  • assignの返り値のoperatorをsess.runで呼び出す

numpyバージョン vs. tensorflowバージョン．

# コード

kalman.py
```python
import numpy as np
import matplotlib.pyplot as plt
import time
import tensorflow as tf

def KalmanFilter_tf(T, x_obs, mu0, V0, A, b, Q, R):
    """
    Kalman filter implemented with TensorFlow (TF 1.x graph/session API).

    A single filter step is built once as a computation graph.  The current
    state mean and covariance live in TensorFlow variables; each step is run
    with ``sess.run`` on the ops returned by ``tf.assign``, which both store
    the updated value in the variable and hand it back to Python.

    The observation matrix is the identity (the innovation is
    ``obs - predicted mean``).  The state dimension is taken from ``mu0``.
    All graph tensors are ``float32``; results are copied into NumPy arrays.

    Args:
        T: Number of time steps (step 0 holds the initial state).
        x_obs: Observations, indexed as ``x_obs[t]``; ``x_obs[0]`` is not used.
        mu0: Initial state mean, 1-D array of length ``n``.
        V0: Initial state covariance, ``(n, n)``.
        A: State transition matrix, ``(n, n)``.
        b: Constant offset of the state transition, length ``n``.
        Q: Process noise covariance, ``(n, n)``.
        R: Observation noise covariance, ``(n, n)``.

    Returns:
        ``(None, None, x_filter, V_filter)``.  Predicted values are not
        collected by this implementation, so the first two entries are
        ``None``; ``x_filter`` is ``(T, n)`` and ``V_filter`` is ``(T, n, n)``.
    """
    n = np.shape(mu0)[0]

    x_filter = np.zeros((T, n))
    V_filter = np.zeros((T, n, n))
    x_filter[0] = mu0
    V_filter[0] = V0

    g = tf.Graph()

    with g.as_default():  # build the computation graph for one filter step

        obs_ = tf.placeholder(tf.float32, name="obs", shape=(n, 1))

        A_ = tf.placeholder(tf.float32, name="A", shape=(n, n))
        b_ = tf.placeholder(tf.float32, name="b", shape=(n, 1))
        Q_ = tf.placeholder(tf.float32, name="Q", shape=(n, n))
        R_ = tf.placeholder(tf.float32, name="R", shape=(n, n))

        mu0_ = tf.placeholder(tf.float32, name="mu0", shape=(n, 1))
        V0_ = tf.placeholder(tf.float32, name="V0", shape=(n, n))

        # Filter state kept inside TensorFlow between sess.run calls.
        mu = tf.Variable(tf.zeros((n, 1)), dtype=tf.float32, name="mu")
        V = tf.Variable(tf.zeros((n, n)), dtype=tf.float32, name="V")

        # Running these ops loads the initial state into the variables.
        mu0_init = tf.assign(mu, mu0_)
        V0_init = tf.assign(V, V0_)

        # Prediction step.
        mu_ = A_ @ mu + b_
        V_ = A_ @ V @ tf.transpose(A_) + Q_

        # Kalman gain.
        S = V_ + R_
        K = V_ @ tf.matrix_inverse(S)

        # Update step: store the filtered values back into the variables.
        mu_op = tf.assign(mu, mu_ + K @ (obs_ - mu_))
        V_op = tf.assign(V, V_ - K @ V_)

    with tf.Session(graph=g) as sess:

        # Initialise the state variables from mu0 / V0.
        sess.run([mu0_init, V0_init],
                 feed_dict={mu0_: mu0.reshape((n, 1)), V0_: V0})

        # Model parameters do not change between steps; only the
        # observation entry is replaced inside the loop.
        feed = {A_: A, b_: b.reshape((n, 1)), Q_: Q, R_: R}

        start = time.perf_counter()

        for t in range(1, T):

            # One sess.run per time step.
            feed[obs_] = x_obs[t].reshape((n, 1))
            m, v = sess.run([mu_op, V_op], feed_dict=feed)

            x_filter[t], V_filter[t] = m.transpose(), v

        elapsed_time = time.perf_counter() - start
        print("tensorflow: ", elapsed_time)

    return None, None, x_filter, V_filter

def KalmanFilter(T, x_obs, mu0, V0, A, b, Q, R):
    """
    Plain NumPy Kalman filter; also prints the wall time of the filter loop.

    The observation matrix is the identity (the innovation is
    ``x_obs[t] - predicted mean``).  The state dimension is taken from
    ``mu0`` rather than being fixed at 2.

    Args:
        T: Number of time steps (step 0 holds the initial state).
        x_obs: Observations, indexed as ``x_obs[t]``; ``x_obs[0]`` is not used.
        mu0: Initial state mean, 1-D array of length ``n``.
        V0: Initial state covariance, ``(n, n)``.
        A: State transition matrix, ``(n, n)``.
        b: Constant offset of the state transition, length ``n``.
        Q: Process noise covariance, ``(n, n)``.
        R: Observation noise covariance, ``(n, n)``.

    Returns:
        ``(x_predict, V_predict, x_filter, V_filter)`` with shapes
        ``(T, n)``, ``(T, n, n)``, ``(T, n)`` and ``(T, n, n)``.  Row 0 of
        every array holds the initial mean / covariance.
    """
    n = np.shape(mu0)[0]

    mu = mu0
    V = V0

    x_predict = np.zeros((T, n))
    x_filter = np.zeros((T, n))
    V_predict = np.zeros((T, n, n))
    V_filter = np.zeros((T, n, n))

    x_predict[0] = mu
    x_filter[0] = mu
    V_predict[0] = V
    V_filter[0] = V

    # perf_counter is the clock intended for measuring short intervals.
    start = time.perf_counter()

    for t in range(1, T):

        # Prediction step.
        mu_ = A @ mu + b
        V_ = A @ V @ A.transpose() + Q

        x_predict[t] = mu_
        V_predict[t] = V_

        # Kalman gain.
        S = V_ + R
        K = V_ @ np.linalg.inv(S)

        # Update step with the observation at time t.
        mu = mu_ + K @ (x_obs[t] - mu_)
        V = V_ - K @ V_

        x_filter[t] = mu
        V_filter[t] = V

    elapsed_time = time.perf_counter() - start
    print("numpy:      ", elapsed_time)

    return x_predict, V_predict, x_filter, V_filter

def main():
    """
    Simulate a noisy 2-D linear system, run both Kalman filter
    implementations on it, print their filtered means and plot them
    together with the observations.
    """

    # --- data generation ---
    T = 20
    mu0 = np.array([100, 100])
    V0 = np.array([[10, 0],
                   [0, 10]])

    A = np.array([[1.001, 0.001],
                  [0, 0.99]])
    b = np.array([5, 10])
    Q = np.array([[20, 0],
                  [0, 20]])
    R = np.array([[20, 0],
                  [0, 20]])

    rvq = np.random.multivariate_normal(np.zeros(2), Q, T)
    rvr = np.random.multivariate_normal(np.zeros(2), R, T)

    # The latent state evolves with process noise only; observation noise is
    # added on top of it and must not be fed back into the dynamics.
    state = np.zeros((T, 2))
    obs = np.zeros((T, 2))
    state[0] = mu0
    obs[0] = mu0
    for i in range(1, T):
        state[i] = A @ state[i - 1] + b + rvq[i]
        obs[i] = state[i] + rvr[i]
    # --- end of data generation ---

    _, _, x_filter, _ = \
        KalmanFilter(T, obs, mu0, V0, A, b, Q, R)
    print(x_filter)

    _, _, x_filter_tf, _ = \
        KalmanFilter_tf(T, obs, mu0, V0, A, b, Q, R)
    print(x_filter_tf)

    fig = plt.figure(figsize=(16, 9))
    ax = fig.gca()

    # Observations: white markers with black edge, joined by a black line.
    ax.scatter(obs[:, 0], obs[:, 1],
               s=10, alpha=1, marker="o", color='w', edgecolor='k')
    ax.plot(obs[:, 0], obs[:, 1],
            alpha=0.5, lw=1, color='k')

    # NumPy filter result in red.
    ax.scatter(x_filter[:, 0], x_filter[:, 1],
               s=10, alpha=1, marker="o", color='r')
    ax.plot(x_filter[:, 0], x_filter[:, 1],
            alpha=0.5, lw=1, color='r')

    # TensorFlow filter result in magenta.
    ax.scatter(x_filter_tf[:, 0], x_filter_tf[:, 1],
               s=10, alpha=1, marker="o", color='m')
    ax.plot(x_filter_tf[:, 0], x_filter_tf[:, 1],
            alpha=0.5, lw=1, color='m')

    plt.show()


if __name__ == "__main__":
    main()
``````

# 結果

## on Mac

```
numpy:       0.0004937648773193359
tensorflow:  0.00774383544921875
``````

## on Ubuntu

GPUならもっと遅い．(Ubuntu 16.04, GPU tf 1.6.0, GeForce 1080Ti, anaconda python 3.6)

```
numpy:       0.002942800521850586
(...tensorflow messages...)
tensorflow:  0.6563735008239746
``````

(Ubuntu 16.04, CPU tf 1.6.0, anaconda python 3.6)

```
numpy:       0.002834796905517578
tensorflow:  0.027823209762573242
``````

Register as a new user and use Qiita more conveniently

1. You get articles that match your needs
2. You can efficiently read back useful information
What you can do with signing up
23