LoginSignup
14
10

More than 5 years have passed since last update.

pytorchでカルマンフィルタ

Posted at

PyTorchでカルマンフィルタを実装してみた.

前に
tensorflowでカルマンフィルタ
を書いたら,

という反応(PyTorchによる実装)があって面白かった.

ただ,上記のPyTorch実装は計算グラフを構築せず,単にtensor演算をしているだけのように見えたので,
計算グラフを使うPyTorch版を書いてみた.ただしPyTorchを書くのは初めてなので,本当に計算グラフになっているかは未確認...(autogradで勾配を取ってみれば分かるはず)

  • nn.Moduleのサブクラスでフィルタの1ステップ(予測+更新)を実装し,forward()で呼び出す
  • フィルタの内部状態(平均muと共分散V)はselfの属性に保持するだけなので簡単

NumPy版とPyTorch版を比較する.

コード

pytorch
import numpy as np
import matplotlib.pyplot as plt
import time
import torch
import torch.nn as nn
from torch.autograd import Variable


class Kalman(nn.Module):
    """One predict+update step of a Kalman filter per ``forward`` call.

    The update equations correspond to the linear model (identity
    observation matrix)::

        x_t = A x_{t-1} + b + w_t,   Cov(w_t) = Q
        y_t = x_t + v_t,             Cov(v_t) = R

    The running filter state (mean ``mu`` and covariance ``V``) is kept on the
    instance and advanced on every call. All tensors are float64 and never
    require gradients. They are plain attributes, not registered
    parameters or buffers.
    """

    def __init__(self, mu0, V0, A, b, Q, R):
        """Store the model matrices and the initial state.

        Args:
            mu0: initial state mean, array-like (presumably shape (d,)).
            V0: initial state covariance, array-like (presumably (d, d)).
            A: state transition matrix.
            b: constant drift added at every prediction.
            Q: process-noise covariance.
            R: observation-noise covariance.
        """
        super(Kalman, self).__init__()
        self.A = self._to_variable(A)
        self.b = self._to_variable(b)
        self.Q = self._to_variable(Q)
        self.R = self._to_variable(R)

        self.mu = self._to_variable(mu0)
        self.V = self._to_variable(V0)

    @staticmethod
    def _to_variable(arr):
        """Copy an array-like into a float64 Variable that does not require grad."""
        # np.array(..., dtype=float64) always copies, so the filter never
        # aliases (or is aliased by) the caller's arrays, and lists work too.
        return Variable(torch.from_numpy(np.array(arr, dtype=np.float64)),
                        requires_grad=False)

    def forward(self, x):
        """Advance the filter by one step using observation ``x``.

        Returns:
            (x_predict, V_predict, x_filter, V_filter) as independent NumPy
            arrays: the predicted mean/covariance before seeing ``x``, and
            the filtered mean/covariance after the update.
        """
        x = self._to_variable(x)

        # Predict.
        self.mu = self.A @ self.mu + self.b
        self.V = self.A @ self.V @ self.A.t() + self.Q

        # .numpy() shares memory with the tensor; copy so callers cannot
        # mutate the filter's internal state through the returned arrays.
        x_predict = self.mu.data.numpy().copy()
        V_predict = self.V.data.numpy().copy()

        # Update (identity observation matrix): K = V (V + R)^{-1}.
        S = self.V + self.R
        K = self.V @ torch.inverse(S)

        self.mu = self.mu + K @ (x - self.mu)
        self.V = self.V - K @ self.V

        x_filter = self.mu.data.numpy().copy()
        V_filter = self.V.data.numpy().copy()

        return x_predict, V_predict, x_filter, V_filter


def KalmanFilter_pt(T, x_obs, mu0, V0, A, b, Q, R):
    """Run the PyTorch ``Kalman`` module over ``x_obs`` and print its run time.

    Args:
        T: number of time steps. Observations ``x_obs[1:T]`` are consumed;
            row 0 of every output holds the prior (``mu0`` / ``V0``).
        x_obs: observations, indexable by time (presumably shape (T, d)).
        mu0, V0: prior mean (1-D, length d) and covariance (d, d).
        A, b, Q, R: model matrices forwarded to ``Kalman``.

    Returns:
        (x_predict, V_predict, x_filter, V_filter) with shapes
        (T, d), (T, d, d), (T, d), (T, d, d).
    """
    d = np.asarray(mu0).shape[0]  # state dimension; assumes a 1-D mu0

    x_predict = np.zeros((T, d))
    x_filter = np.zeros((T, d))
    V_predict = np.zeros((T, d, d))
    V_filter = np.zeros((T, d, d))

    # Row 0 is the prior; filtering starts at t=1. Guarded so T == 0 yields
    # empty outputs instead of an IndexError.
    if T > 0:
        x_predict[0] = mu0
        x_filter[0] = mu0
        V_predict[0] = V0
        V_filter[0] = V0

    # perf_counter is monotonic and intended for measuring short intervals.
    start = time.perf_counter()

    kalman = Kalman(mu0, V0, A, b, Q, R)

    for t in range(1, T):
        x_predict[t], V_predict[t], x_filter[t], V_filter[t] = kalman(x_obs[t])

    elapsed_time = time.perf_counter() - start
    print("torch: ", elapsed_time)

    return x_predict, V_predict, x_filter, V_filter


def KalmanFilter(T, x_obs, mu0, V0, A, b, Q, R):
    """Plain NumPy Kalman filter (reference implementation); prints its run time.

    The update equations correspond to the linear model (identity
    observation matrix)::

        x_t = A x_{t-1} + b + w_t,   Cov(w_t) = Q
        y_t = x_t + v_t,             Cov(v_t) = R

    Args:
        T: number of time steps. Observations ``x_obs[1:T]`` are consumed;
            row 0 of every output holds the prior (``mu0`` / ``V0``).
        x_obs: observations, indexable by time (presumably shape (T, d)).
        mu0, V0: prior mean (1-D, length d) and covariance (d, d).
        A, b, Q, R: transition matrix, drift, process- and observation-noise
            covariances.

    Returns:
        (x_predict, V_predict, x_filter, V_filter) with shapes
        (T, d), (T, d, d), (T, d), (T, d, d).
    """
    A = np.asarray(A)
    mu = np.asarray(mu0)
    V = np.asarray(V0)
    d = mu.shape[0]  # state dimension; assumes a 1-D mu0

    x_predict = np.zeros((T, d))
    x_filter = np.zeros((T, d))
    V_predict = np.zeros((T, d, d))
    V_filter = np.zeros((T, d, d))

    # Row 0 is the prior; filtering starts at t=1. Guarded so T == 0 yields
    # empty outputs instead of an IndexError.
    if T > 0:
        x_predict[0] = mu
        x_filter[0] = mu
        V_predict[0] = V
        V_filter[0] = V

    # perf_counter is monotonic and intended for measuring short intervals.
    start = time.perf_counter()

    for t in range(1, T):
        # Predict.
        mu_ = A @ mu + b
        V_ = A @ V @ A.T + Q

        x_predict[t] = mu_
        V_predict[t] = V_

        # Update (identity observation matrix): K = V_ (V_ + R)^{-1}.
        S = V_ + R
        K = V_ @ np.linalg.inv(S)

        mu = mu_ + K @ (x_obs[t] - mu_)
        V = V_ - K @ V_

        x_filter[t] = mu
        V_filter[t] = V

    elapsed_time = time.perf_counter() - start
    print("numpy:      ", elapsed_time)

    return x_predict, V_predict, x_filter, V_filter


def main():
    """Simulate a 2-D linear system, filter it with both implementations, and plot."""

    # --- Data generation ---
    T = 20
    mu0 = np.array([100, 100])
    V0 = np.array([[10, 0],
                   [0, 10]])

    A = np.array([[1.001, 0.001],
                  [0, 0.99]])
    b = np.array([5, 10])
    Q = np.array([[20, 0],
                  [0, 20]])
    R = np.array([[20, 0],
                  [0, 20]])

    rvq = np.random.multivariate_normal(np.zeros(2), Q, T)  # process noise
    rvr = np.random.multivariate_normal(np.zeros(2), R, T)  # observation noise

    # Simulate the latent state separately from the observations so the data
    # follow the model the filters assume. Observation noise must not feed
    # back into the next state, which happens if obs[i-1] is propagated.
    state = np.zeros((T, 2))
    state[0] = mu0
    for i in range(1, T):
        state[i] = A @ state[i - 1] + b + rvq[i]
    obs = state + rvr
    # --- End of data generation ---

    x_predict, V_predict, x_filter, V_filter = \
        KalmanFilter(T, obs, mu0, V0, A, b, Q, R)
    print(x_filter)

    x_predict_pt, V_predict_pt, x_filter_pt, V_filter_pt = \
        KalmanFilter_pt(T, obs, mu0, V0, A, b, Q, R)
    print(x_filter_pt)

    fig, ax = plt.subplots(figsize=(16, 9))

    # Observations: hollow black markers.
    ax.scatter(obs[:, 0], obs[:, 1],
               s=10, alpha=1, marker="o", color='w', edgecolor='k')
    ax.plot(obs[:, 0], obs[:, 1],
            alpha=0.5, lw=1, color='k')

    # True latent state: dashed blue.
    ax.plot(state[:, 0], state[:, 1],
            alpha=0.5, lw=1, ls='--', color='b')

    # NumPy filter estimates: red.
    ax.scatter(x_filter[:, 0], x_filter[:, 1],
               s=10, alpha=1, marker="o", color='r')
    ax.plot(x_filter[:, 0], x_filter[:, 1],
            alpha=0.5, lw=1, color='r')

    # PyTorch filter estimates: magenta (should overlap the red curve).
    ax.scatter(x_filter_pt[:, 0], x_filter_pt[:, 1],
               s=10, alpha=1, marker="o", color='m')
    ax.plot(x_filter_pt[:, 0], x_filter_pt[:, 1],
            alpha=0.5, lw=1, color='m')

    plt.show()


if __name__ == "__main__":
    # Entry point when executed as a script; importing the module only
    # defines the functions above and runs nothing.
    main()

結果

当然ながら速くはなく,下記の通りNumPy版の約5倍の時間がかかった.(環境: macOS, CPU, PyTorch 0.3.1.post2, Anaconda Python 3.6)

numpy:       0.000392913818359375
torch:  0.002187013626098633

それでも,書きやすさはTensorFlowよりずっと良い.

14
10
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
14
10