TensorFlowでカルマンフィルタを実装してみた.
- 計算グラフで1ステップ分(予測+更新)の処理を実装する
- 各時刻で session.run を呼び出し,フィルタを1ステップ進める
- フィルタの状態(平均と共分散)は TensorFlow の変数として内部に保持する
- Tips:
  - tf.assign で内部変数に値を保存する
  - tf.assign が返す op を sess.run で実行すると,実際に代入が行われる
NumPy版 vs. TensorFlow版 の速度比較も行った.
コード
kalman.py
import numpy as np
import matplotlib.pyplot as plt
import time
import tensorflow as tf
def KalmanFilter_tf(T, x_obs, mu0, V0, A, b, Q, R):
    """
    Kalman filter implemented as a TensorFlow graph (TF1-style API).

    One predict/update step is expressed as a computation graph. The filter
    state (mean ``mu`` and covariance ``V``) lives in TensorFlow variables.
    Each ``sess.run`` on the ``assign`` ops advances the filter by one time
    step. The placeholders are shaped for a 2-dimensional state, so only
    2-D inputs are supported. The observation matrix is the identity: the
    innovation is ``obs - mu_`` and its covariance is ``V_ + R``.

    Args:
        T: number of time steps; steps 1..T-1 are filtered.
        x_obs: observations; ``x_obs[t]`` must be reshapeable to (2, 1).
        mu0: initial state mean, reshapeable to (2, 1).
        V0: initial state covariance, shape (2, 2).
        A: state transition matrix, shape (2, 2).
        b: constant offset added in the prediction, reshapeable to (2, 1).
        Q: process noise covariance, shape (2, 2).
        R: observation noise covariance, shape (2, 2).

    Returns:
        Tuple ``(x_predict, V_predict, x_filter, V_filter)`` with shapes
        (T, 2), (T, 2, 2), (T, 2), (T, 2, 2). Index 0 holds mu0 / V0.
    """
    # The graph/session API moved to tf.compat.v1 (TF >= 1.13, incl. TF 2.x).
    # Older TF 1.x releases expose the same names at top level.
    tf1 = tf.compat.v1 if hasattr(getattr(tf, "compat", None), "v1") else tf

    x_predict = np.zeros((T, 2))
    x_filter = np.zeros((T, 2))
    V_predict = np.zeros((T, 2, 2))
    V_filter = np.zeros((T, 2, 2))
    if T > 0:
        x_predict[0] = np.ravel(mu0)
        x_filter[0] = np.ravel(mu0)
        V_predict[0] = V0
        V_filter[0] = V0

    g = tf.Graph()
    with g.as_default():  # build the computation graph
        obs_ = tf1.placeholder(tf.float32, name="obs", shape=(2, 1))
        A_ = tf1.placeholder(tf.float32, name="A", shape=(2, 2))
        b_ = tf1.placeholder(tf.float32, name="b", shape=(2, 1))
        Q_ = tf1.placeholder(tf.float32, name="Q", shape=(2, 2))
        R_ = tf1.placeholder(tf.float32, name="R", shape=(2, 2))
        mu0_ = tf1.placeholder(tf.float32, name="mu0", shape=(2, 1))
        V0_ = tf1.placeholder(tf.float32, name="V0", shape=(2, 2))

        # Filter state carried across sess.run calls.
        mu = tf1.Variable(tf.zeros((2, 1)), dtype=tf.float32, name="mu")
        V = tf1.Variable(tf.zeros((2, 2)), dtype=tf.float32, name="V")
        mu0_init = tf1.assign(mu, mu0_)
        V0_init = tf1.assign(V, V0_)

        # Predict step.
        mu_ = A_ @ mu + b_
        V_ = A_ @ V @ tf.transpose(A_) + Q_
        # Update step (identity observation matrix).
        S = V_ + R_
        K = V_ @ tf1.matrix_inverse(S)
        # Write the filtered estimate back into the state variables.
        mu_op = tf1.assign(mu, mu_ + K @ (obs_ - mu_))
        V_op = tf1.assign(V, V_ - K @ V_)

    with tf1.Session(graph=g) as sess:
        # Initialise the state variables with mu0 / V0.
        sess.run([mu0_init, V0_init],
                 feed_dict={mu0_: np.reshape(mu0, (2, 1)), V0_: V0})
        # Model parameters are constant over time, so build their feeds once.
        base_feed = {A_: A, b_: np.reshape(b, (2, 1)), Q_: Q, R_: R}
        start = time.time()
        for t in range(1, T):
            # One sess.run per time step. The predicted values are fetched
            # together with the assign ops; they are computed before the
            # assignments because the assigns depend on them.
            m_pred, v_pred, m, v = sess.run(
                [mu_, V_, mu_op, V_op],
                feed_dict={**base_feed, obs_: np.reshape(x_obs[t], (2, 1))})
            x_predict[t], V_predict[t] = np.ravel(m_pred), v_pred
            x_filter[t], V_filter[t] = np.ravel(m), v
        elapsed_time = time.time() - start
    print("tensorflow: ", elapsed_time)
    return x_predict, V_predict, x_filter, V_filter
def KalmanFilter(T, x_obs, mu0, V0, A, b, Q, R):
    """
    Plain NumPy Kalman filter with an identity observation matrix.

    The state dimension ``d`` is taken from ``mu0``, so any ``d`` works.
    The original 2-D use is a special case.

    Args:
        T: number of time steps; steps 1..T-1 are filtered. T == 0 returns
            empty arrays.
        x_obs: observations indexed as ``x_obs[t]``, each of shape (d,).
        mu0: initial state mean, shape (d,).
        V0: initial state covariance, shape (d, d).
        A: state transition matrix, shape (d, d).
        b: constant offset added in the prediction, shape (d,).
        Q: process noise covariance, shape (d, d).
        R: observation noise covariance, shape (d, d).

    Returns:
        Tuple ``(x_predict, V_predict, x_filter, V_filter)`` with shapes
        (T, d), (T, d, d), (T, d), (T, d, d). Index 0 holds mu0 / V0.
    """
    mu = np.asarray(mu0)
    V = np.asarray(V0)
    d = mu.shape[0]
    x_predict = np.zeros((T, d))
    x_filter = np.zeros((T, d))
    V_predict = np.zeros((T, d, d))
    V_filter = np.zeros((T, d, d))
    if T > 0:
        x_predict[0] = mu
        x_filter[0] = mu
        V_predict[0] = V
        V_filter[0] = V
    start = time.time()
    for t in range(1, T):
        # Predict: propagate mean and covariance through the linear dynamics.
        mu_ = A @ mu + b
        V_ = A @ V @ A.T + Q
        x_predict[t] = mu_
        V_predict[t] = V_
        # Update: with an identity observation matrix the innovation
        # covariance is V_ + R.
        S = V_ + R
        K = V_ @ np.linalg.inv(S)
        mu = mu_ + K @ (x_obs[t] - mu_)
        V = V_ - K @ V_
        x_filter[t] = mu
        V_filter[t] = V
    elapsed_time = time.time() - start
    print("numpy: ", elapsed_time)
    return x_predict, V_predict, x_filter, V_filter
def main():
    """
    Demo: run both Kalman filters on a simulated 2-D system and plot them.

    Simulates a linear-Gaussian system x_t = A x_{t-1} + b + q_t with noisy
    observations y_t = x_t + r_t. It runs the NumPy and TensorFlow filters on
    the observations, prints the filtered means, and plots the observations,
    the true state and both filtered trajectories.
    """
    # --- Generate data ---
    T = 20
    mu0 = np.array([100, 100])
    V0 = np.array([[10, 0],
                   [0, 10]])
    A = np.array([[1.001, 0.001],
                  [0, 0.99]])
    b = np.array([5, 10])
    Q = np.array([[20, 0],
                  [0, 20]])
    R = np.array([[20, 0],
                  [0, 20]])
    rvq = np.random.multivariate_normal(np.zeros(2), Q, T)  # process noise
    rvr = np.random.multivariate_normal(np.zeros(2), R, T)  # observation noise
    # Keep the hidden state separate from the observations so that the
    # observation noise does not leak into the dynamics. This matches the
    # model assumed by the filters.
    x_true = np.zeros((T, 2))
    obs = np.zeros((T, 2))
    x_true[0] = mu0
    obs[0] = mu0
    for i in range(1, T):
        x_true[i] = A @ x_true[i - 1] + b + rvq[i]
        obs[i] = x_true[i] + rvr[i]
    # --- End of data generation ---

    _, _, x_filter, _ = KalmanFilter(T, obs, mu0, V0, A, b, Q, R)
    print(x_filter)
    _, _, x_filter_tf, _ = KalmanFilter_tf(T, obs, mu0, V0, A, b, Q, R)
    print(x_filter_tf)

    fig = plt.figure(figsize=(16, 9))
    ax = fig.gca()
    # Observations: white markers with black edges.
    ax.scatter(obs[:, 0], obs[:, 1],
               s=10, alpha=1, marker="o", color='w', edgecolor='k')
    ax.plot(obs[:, 0], obs[:, 1],
            alpha=0.5, lw=1, color='k', label="observation")
    # True (hidden) state.
    ax.plot(x_true[:, 0], x_true[:, 1],
            alpha=0.5, lw=1, ls='--', color='b', label="true state")
    # NumPy filter result.
    ax.scatter(x_filter[:, 0], x_filter[:, 1],
               s=10, alpha=1, marker="o", color='r')
    ax.plot(x_filter[:, 0], x_filter[:, 1],
            alpha=0.5, lw=1, color='r', label="numpy filter")
    # TensorFlow filter result.
    ax.scatter(x_filter_tf[:, 0], x_filter_tf[:, 1],
               s=10, alpha=1, marker="o", color='m')
    ax.plot(x_filter_tf[:, 0], x_filter_tf[:, 1],
            alpha=0.5, lw=1, color='m', label="tensorflow filter")
    ax.legend()
    plt.show()
# Run the demo only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
結果
on Mac
当然ながら,NumPy版より遅い.(macOS, CPU版 TensorFlow 1.4.1, Anaconda Python 3.6) 2×2行列程度の小さな計算では,各時刻で sess.run を呼び出すオーバーヘッドが計算そのものより大きくなるためである.
numpy: 0.0004937648773193359
tensorflow: 0.00774383544921875
on Ubuntu
GPU版ではさらに遅い.(Ubuntu 16.04, GPU版 TensorFlow 1.6.0, GeForce GTX 1080 Ti, Anaconda Python 3.6) 各ステップで CPU と GPU の間のデータ転送やカーネル起動のオーバーヘッドが加わるためと考えられる.
numpy: 0.002942800521850586
(...tensorflow messages...)
tensorflow: 0.6563735008239746
同じマシンで CPU版を使った場合.(Ubuntu 16.04, CPU版 TensorFlow 1.6.0, Anaconda Python 3.6)
numpy: 0.002834796905517578
tensorflow: 0.027823209762573242