Python
Kalman filter
PyTorch

Kalman filter in PyTorch

I tried implementing a Kalman filter in PyTorch.

When I previously wrote "Kalman filter in TensorFlow", it drew an interesting reaction.

But the PyTorch implementation there seemed to be doing only tensor arithmetic rather than building a computation graph, so I tried writing a computation-graph version in PyTorch. This is the first PyTorch I have ever written, though, so I am not sure whether it actually ends up as a computation graph... (running autograd on it would tell; a small check along those lines is sketched after the bullets below)

  • Implement one filter iteration as an nn class and call it via forward()
  • Keeping the internal variables between calls is easy with self
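
As a rough check of the graph question, one could keep everything as Variables, mark one of the matrices with requires_grad=True, and see whether a gradient flows back through several filter steps. The following is a minimal self-contained sketch along those lines, not part of the implementation below; the dummy 2x2 system and the three random observations are placeholders, and differentiating through torch.inverse assumes the installed PyTorch version supports it.

import numpy as np
import torch
from torch.autograd import Variable

# Dummy system matrices; A is the one we ask autograd to differentiate w.r.t.
A = Variable(torch.eye(2).double(), requires_grad=True)
b = Variable(torch.zeros(2).double())
Q = Variable(torch.eye(2).double())
R = Variable(torch.eye(2).double())
mu = Variable(torch.zeros(2).double())
V = Variable(torch.eye(2).double())

# Run the same predict/update operations as the filter below,
# keeping everything as Variables the whole time.
for x_np in np.random.randn(3, 2):
    x = Variable(torch.from_numpy(x_np).double())
    mu = A @ mu + b                   # predict
    V = A @ V @ A.t() + Q
    K = V @ torch.inverse(V + R)      # update (observation matrix = identity)
    mu = mu + K @ (x - mu)
    V = V - K @ V

mu.sum().backward()   # backpropagate through all three time steps
print(A.grad)         # a non-None gradient means a graph really was built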

NumPy version vs. PyTorch version.
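
For reference, both versions implement the same recursion, with the observation matrix implicitly taken to be the identity (each observation is just the state plus noise), so the innovation covariance is simply V + R:

$$
\begin{aligned}
\text{predict:}\quad & \mu_t^- = A\,\mu_{t-1} + b, \qquad V_t^- = A\,V_{t-1}A^{\top} + Q,\\
\text{update:}\quad & S_t = V_t^- + R, \qquad K_t = V_t^-\,S_t^{-1},\\
& \mu_t = \mu_t^- + K_t\,(x_t - \mu_t^-), \qquad V_t = V_t^- - K_t\,V_t^-.
\end{aligned}
$$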

Code

pytorch
import numpy as np
import matplotlib.pyplot as plt
import time
import torch
import torch.nn as nn
from torch.autograd import Variable


class Kalman(nn.Module):

    def __init__(self, mu0, V0, A, b, Q, R):
        super(Kalman, self).__init__()
        # Everything is wrapped as a Variable with requires_grad=False,
        # so autograd does not record a backward graph for these operations.
        self.A = Variable(torch.from_numpy(A).double(), requires_grad=False)
        self.b = Variable(torch.from_numpy(b).double(), requires_grad=False)
        self.Q = Variable(torch.from_numpy(Q).double(), requires_grad=False)
        self.R = Variable(torch.from_numpy(R).double(), requires_grad=False)

        self.mu = Variable(torch.from_numpy(mu0).double(), requires_grad=False)
        self.V = Variable(torch.from_numpy(V0).double(), requires_grad=False)

    def forward(self, x):

        x = Variable(torch.from_numpy(x).double(), requires_grad=False)

        # prediction step
        self.mu = self.A @ self.mu + self.b
        self.V = self.A @ self.V @ self.A.t() + self.Q

        x_predict = self.mu.data.numpy()
        V_predict = self.V.data.numpy()

        # update step (the observation matrix is the identity, so S = V + R)
        S = self.V + self.R
        K = self.V @ torch.inverse(S)

        self.mu = self.mu + K @ (x - self.mu)
        self.V = self.V - K @ self.V

        x_filter = self.mu.data.numpy()
        V_filter = self.V.data.numpy()

        return x_predict, V_predict, x_filter, V_filter


def KalmanFilter_pt(T, x_obs, mu0, V0, A, b, Q, R):

    x_predict = np.zeros((T, 2))
    x_filter = np.zeros((T, 2))
    V_predict = np.zeros((T, 2, 2))
    V_filter = np.zeros((T, 2, 2))

    x_predict[0] = mu0
    x_filter[0] = mu0
    V_predict[0] = V0
    V_filter[0] = V0

    start = time.time()

    kalman = Kalman(mu0, V0, A, b, Q, R)

    for t in range(1, T):

        x_predict[t], V_predict[t], x_filter[t], V_filter[t] = kalman(x_obs[t])

    elapsed_time = time.time() - start
    print("torch: ", elapsed_time)

    return x_predict, V_predict, x_filter, V_filter


def KalmanFilter(T, x_obs, mu0, V0, A, b, Q, R):
    """
    こちらは普通のカルマンフィルタ.np実装.
    """

    mu = mu0
    V = V0

    x_predict = np.zeros((T, 2))
    x_filter = np.zeros((T, 2))
    V_predict = np.zeros((T, 2, 2))
    V_filter = np.zeros((T, 2, 2))

    x_predict[0] = mu
    x_filter[0] = mu
    V_predict[0] = V
    V_filter[0] = V

    start = time.time()

    for t in range(1, T):

        # prediction step
        mu_ = A @ mu + b
        V_ = A @ V @ A.transpose() + Q

        x_predict[t] = mu_
        V_predict[t] = V_

        # update step
        S = V_ + R
        K = V_ @ np.linalg.inv(S)

        mu = mu_ + K @ (x_obs[t] - mu_)
        V = V_ - K @ V_

        x_filter[t] = mu
        V_filter[t] = V

    elapsed_time = time.time() - start
    print("numpy:      ", elapsed_time)

    return x_predict, V_predict, x_filter, V_filter


def main():

    # generate synthetic data
    T = 20
    mu0 = np.array([100, 100])
    V0 = np.array([[10, 0],
                   [0, 10]])

    A = np.array([[1.001, 0.001],
                  [0, 0.99]])
    b = np.array([5, 10])
    Q = np.array([[20, 0],
                  [0, 20]])
    R = np.array([[20, 0],
                  [0, 20]])

    rvq = np.random.multivariate_normal(np.zeros(2), Q, T)
    rvr = np.random.multivariate_normal(np.zeros(2), R, T)
    obs = np.zeros((T, 2))
    obs[0] = mu0
    for i in range(1, T):
        obs[i] = A @ obs[i-1] + b + rvq[i] + rvr[i]
    # end of data generation

    x_predict, V_predict, x_filter, V_filter = \
        KalmanFilter(T, obs, mu0, V0, A, b, Q, R)
    print(x_filter)

    x_predict_pt, V_predict_pt, x_filter_pt, V_filter_pt = \
        KalmanFilter_pt(T, obs, mu0, V0, A, b, Q, R)
    print(x_filter_pt)

    fig = plt.figure(figsize=(16, 9))
    ax = fig.gca()

    ax.scatter(obs[:, 0], obs[:, 1],
               s=10, alpha=1, marker="o", color='w', edgecolor='k')
    ax.plot(obs[:, 0], obs[:, 1],
            alpha=0.5, lw=1, color='k')

    ax.scatter(x_filter[:, 0], x_filter[:, 1],
               s=10, alpha=1, marker="o", color='r')
    ax.plot(x_filter[:, 0], x_filter[:, 1],
            alpha=0.5, lw=1, color='r')

    ax.scatter(x_filter_pt[:, 0], x_filter_pt[:, 1],
               s=10, alpha=1, marker="o", color='m')
    ax.plot(x_filter_pt[:, 0], x_filter_pt[:, 1],
            alpha=0.5, lw=1, color='m')

    plt.show()


if __name__ == "__main__":
    main()

Results

Unsurprisingly, it is not all that fast. (on macOS, CPU, PyTorch 0.3.1.post2, Anaconda Python 3.6)

numpy:       0.000392913818359375
torch:  0.002187013626098633

Still, it is considerably easier to write than the TensorFlow version was.