Edited at

カルマンフィルタ・予測・平滑化で状態を逐次推定する


はじめに

状態空間モデルにおいて状態を逐次的に推定する有名な手法の1つにカルマンフィルタというものがあります。

カルマンフィルタなどを用いて出来る事・概要・手順・numpyを用いた行列実装をまとめてみました。


この記事の対象としている人


  • 確率分布の計算はなんとなくわかる

  • 状態空間の概念はなんとなくわかる

  • カルマンフィルタはよく知らない

  • numpyの行列実装の仕方を知りたい


カルマンフィルタで出来る事

カルマンフィルタを使うと、ノイズ混じりの観測データから、観測ノイズを取り除いた状態を逐次推定できます。

こんな感じです。

経済やマーケティングの文脈においては、時系列データの平滑化やトレンドの把握、少し先の予測などに使われます。

ある事象に対して、時点1から時点tまで時系列に沿って観測したデータ$y_{1:t}$があるとします。

例えば、DAU(Daily Active User, 1日あたりにログインしたユーザー数)を想定します。

DAUが以下のように分解出来るとします。

DAU = 状態(サービス内要因, サービス魅力度など) + 観測誤差(サービス外要因, 曜日変動やノイズなど) \\

観測誤差(サービス外要因) \sim N(0, \Sigma) \\

DAUの監視を通じて知りたいのは、時点tにおけるサービスの魅力度を表す状態$x_t$です。しかし、状態$x_t$そのものを直接観測する事は出来ず、観測出来るDAUはサービス外の要因が混じっています。なので、DAU観測データ$y_{1:t}$からサービス外の要因を取り除き、状態$x_t$を推定したいです。

そこで、フィルタリングや予測、平滑化を行います。


フィルタリング、予測、平滑化の違い

時点1から時点tまでの観測データ$y_{1:t}$を元に、ある時点の状態$x_{t'}$を推定します。つまり、$p(x_{t'} | y_{1:t})$を求めます。

より正確には、観測データが手元にあるときの時点tの状態$x_t$確率分布$p(x_{t'} | y_{1:t})$を求め、その平均や分散を求めます。

ここで、


  • t' < t, つまり時点tまでの観測値を元に過去の時点の状態$x_{t'}$を推定する場合、平滑化

  • t' = t, つまり時点tまでの観測値を元に現在の時点の状態$x_{t'}$を推定する場合、フィルタリング

  • t' > t, つまり時点tまでの観測値を元に未来の時点の状態$x_{t'}$を推定する場合、予測


という違いがあります。

例えば、過去のトレンドを知りたいときは平滑化を行ってもよいですし、現在のトレンドをリアルタイムに知りたければ、最新時点のデータが入ってくる度にフィルタリングを行えばよいでしょう。


カルマンフィルタについて

カルマンフィルタ、カルマン予測、カルマン平滑化の順に説明します。


前提

まず、以下のような状態空間モデルを想定します

状態方程式: x_t = G_t x_{t-1} + w_t \\

観測方程式: y_t = F_t x_t + v_t \\
状態ノイズ: w_t \sim N(0, W_t) \\
観測ノイズ: v_t \sim N(0, V_t)

$x_t, w_t$: p×1ベクトル

$G_t, W_t$: p×p行列

$F_t$: 1×pベクトル

$y_t, v_t, V_t$: スカラー

状態$x$は、時点tが増えるにつれ係数$G$がかかり、状態が遷移していきます。

状態が遷移する度に、正規分布に従うノイズ$w$が混ざります。

時点tの観測値$y_t$は、観測方程式に従って求められます。

観測値$y_t$には、正規分布に従う観測ノイズ$v$が混ざります。

$G_t, F_t, W_t, V_t$は既知であるとします。


カルマンフィルタの概要

カルマンフィルタは、時点tまでの観測値$y_{1:t}$から時点tの状態$x_t$を逐次的に効率よく求める手法です。

観測値から観測ノイズを取り除いた状態を推定するので、フィルタリングと呼ばれます。

時点tまでの観測値$y_{1:t}$の情報を元に時点tの状態$x_t$を推測するフィルタリング分布$p(x_t | y_{1:t})$を逐次的に求めたいです。

具体的には、時点毎のフィルタリング分布を仮定し、そのパラメータを逐次的に求めたいです。

そのために、以下の3つの分布を設定します。

フィルタリング分布: p(x_t | y_{1:t}) = N(m_t, C_t) \\

一期先予測分布: p(x_t | y_{1:1-t}) = N(a_t, R_t) \\
一期先予測尤度: p(y_t | y_{1:t-1}) = N(f_t, Q_t) \\

フィルタリング分布は時点tまでの観測値$y_{1:t}$を元に時点tの状態$x_t$を発生させる分布

一期先予測分布は時点t-1までの観測値$y_{1:t-1}$を元に時点tの状態$x_t$を予測する分布

一期先予測尤度は時点t-1までの観測値$y_{1:t-1}$を元に時点tの観測値$y_t$を予測する関数

一期先予測分布$p(x_t | y_{1:t-1})$は、観測値$y_t$を観測する前に予測したフィルタリング分布$p(x_t | y_{1:t})$のようなものです。

フィルタリング分布と一期先予測分布はどちらも$x_t$を発生させる確率分布になっています。

最後の関数が尤度と名付いているのは、恐らくこの関数が観測された$y_{1:t}$から尤もらしいパラメータ$m_t, C_t$を推測するために使われるからだと思われます。

確率"分布"(確率密度関数)は固定したパラメータからある値が発生する確率を求めるのに対し、"尤度"(尤度関数)は観測されたデータ(固定された値)を発生させたパラメータがある値だとどれくらい尤もらしいかを返す関数です。

時点tにおけるパラメータ$\Theta_t = { m_t, C_t, a_t, R_t, f_t, Q_t }$を、時点t-1のフィルタリング分布パラメータ$m_{t-1}, C_{t-1}$と観測値$y_t$を用いて求めます。


カルマンフィルタの手順

過去の観測値$y_{1:t-1}$から1期先予測分布(予測値に基づくフィルタリング分布)を計算して、その後に観測された$y_t$を用いて1期先予測分布のパラメータを修正していくイメージです。

1期先予測分布$p(x_t|y_{1:t-1})$を計算・予測



データ$Y_t=y_t$を観測



$y_t$の情報を元にフィルタリング分布$p(x_t|y_{1:t})$へ修正

この計算を観測したい時点まで繰り返します。

初期状態のフィルタリング分布のパラメータ$m_0, C_0$を設定します。

以下の計算をT回繰り返します。

1 時点tの一期先予測分布$p(x_t | y_{1:t-1})$のパラメータ

a_t = G_t m_{t-1} \\

R_t = G_t C_{t-1} G_t^T + W_t

2 時点tの一期先予測尤度$p(y_t | y_{1:t-1})$のパラメータ

f_t = F_t a_t \\

Q_t = F_t R_t F_t^T + V_t

3 時点tのカルマンゲイン

K_t = R_t F_t^T Q_t^{-1}

4 時点tの状態($x_t$を発生させるフィルタリング分布のパラメータ)の更新

m_t = a_t + K_t(y_t - f_t) \\

C_t = (I - K_t F_t) R_t

# コードにするとこんな感じ

def kalman_filter(m, C, y, G=G, F=F, W=W, V=V):
"""
Kalman Filter
m: 時点t-1のフィルタリング分布の平均
C: 時点t-1のフィルタリング分布の分散共分散行列
y: 時点tの観測値
"""

a = G @ m
R = G @ C @ G.T + W
f = F @ a
Q = F @ R @ F.T + V
# 逆行列と何かの積を取る場合は、invよりsolveを使った方がいいらしい
K = (np.linalg.solve(Q.T, F @ R.T)).T
# K = R @ F.T @ np.linalg.inv(Q)
m = a + K @ (y - f)
C = R - K @ F @ R
return m, C


カルマン予測の概要・手順

カルマン予測は、t+k時点予測分布$p(x_{t+k} | y_{1:t})$を求める手法です。

基本的な概念はカルマンフィルタと変わりません。

カルマンフィルタは

1. 一期先予測分布$p(x_t | y_{t-1})$を計算

2. データ$y_t$を元にフィルタリング分布$p(x_t|y_{1:t})$へ修正

という流れを辿っていましたが、未来のデータは観測出来ません。そこで


  1. 最新時点tのフィルタリング分布$p(x_t | y_{1:t})=p(x_{t+0} | y_{1:t})$を準備、時点tにおける0期先予測分布とする

  2. $p(x_t | y_{1:t})$を用いてt+1時点の一期先予測分布$p(x_{t+1} | y_{1:t})$を計算

  3. $p(x_{t+1} | y_{1:t})$を用いてt+2時点の一期先予測分布$p(x_{t+2} | y_{1:t})$を計算

    ...


というように、観測データによる修正をせずひたすら予測分布の計算を繰り返します。カルマンフィルタと同じ計算式です。

時点tでの0期先予測分布$a_t(0)=m_t, R_t(0)=C_t$を元に、時点t+k, k期先予測分布の平均$a_t(k)$, 分散$R_t(k)$をk=1~kまで逐次的に計算します

a_t(k) = G_{t+k} a_t(k-1) \\

R_t(k) = G_{t+k} R_t(k-1) G_{t+k}^T + W_{t+k}

# 一期先予測分布を連続して求めるだけ

def kalman_prediction(a, R, G=G, W=W):
"""
Kalman prediction
"""

a = G @ a
R = G @ R @ G.T + W
return a, R


カルマン平滑化の概要・手順

カルマン平滑化は、時点tの平滑化分布

p(x_t | y_{1:T}) = N(s_t, S_t)

を求める手法です。(t < T)

カルマンフィルタを時点Tまで計算したとします。

カルマン平滑化は、時点tのフィルタリング分布$p(x_t | y_{1:t})$を、時点t+1の平滑化分布$p(x_{t+1} | y_{1:T})$で補正します。

これにより、トレンドはより滑らかになります。

手順(RTSアルゴリズム)

0 時点t+1での平滑化分布: $s_{t+1}, S_{t+1}$を準備(t=Tのときは時点tのフィルタリング分布を使用)

1 時点tの平均化利得を計算

A_t = C_t G_{t+1}^T R_{t+1}^{-1}

2 時点tの平滑化分布のパラメータを計算

s_t = m_t + A_t(s_{t+1} - a_{t+1}) \\

S_t = C_t + A_t(S_{t+1} - R_{t+1}A_t^T)

# カルマン平滑化  

# 固定区間平滑化を行う
# 時点Tまでのカルマンフィルタリングが一旦完了しているものとする
# aとRの計算は、カルマンフィルタリングの計算時に格納したものを使った方が、計算効率は良さそう
def kalman_smoothing(s, S, m, C, G=G, W=W):
"""
Kalman smoothing
"""

# 1時点先予測分布のパラメータ計算
a = G @ m
R = G @ C @ G.T + W
# 平滑化利得の計算
A = np.linalg.solve(R, C @ G.T)
# A = C @ G.T @ np.linalg.inv(R)
# 状態の更新
s = m + A @ (s - a)
S = C + A @ (S - R) @ A.T
return s, S


実装

自作関数の定義だけ上に書いたものを使います。


ライブラリのインポート

import pandas as pd

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm_notebook as tqdm

np.random.seed(1234)
pd.set_option('display.max_columns', None)
sns.set_style('darkgrid')

%matplotlib inline


仮想データ

一番簡単な例として、ランダムウォークを想定します。


これは、ローカルモデルの一種です。

x_t = x_{t-1} + w_t \\

y_t = x_t + v_t \\
w_t \sim N(0, 1) \\
v_t \sim N(0, 10) \\

観測時点Tは100, 予測時点は5, 初期状態は100とします

G = np.array([[1]])

F = np.array([[1]])
W = np.array([[1]]) # 恣意的に与える必要がある
V = np.array([[10]]) # 上に同じ
T = 100
K = 5
x0 = 100

w = np.random.multivariate_normal(np.zeros(1), W, T+K)
v = np.random.multivariate_normal(np.zeros(1), V, T+K)
x = np.zeros(T+K)
y = np.zeros(T+K)

x[0] = x0 + w[0]
y[0] = x[0] + v[0]
for t in range(1, T+K):
x[t] = x[t-1] + w[t]
y[t] = x[t] + v[t]

fig, ax = plt.subplots(figsize=(16, 4))
sns.lineplot(np.arange(T+K), x, ax=ax, label=" true state")
sns.lineplot(np.arange(T+K), y, color="gray", ax=ax, label="observation")
ax.set_title("simulation data")
ax.legend()
plt.show()


パラメータ設定

# 初期状態のフィルタリング分布のパラメータ

m0 = np.array([[0]])
C0 = np.array([[1e7]])

# 結果を格納するarray
m = np.zeros((T, 1))
C = np.zeros((T, 1, 1))
a_pred = np.zeros((K, 1))
R_pred = np.zeros((K, 1, 1))
s = np.zeros((T, 1))
S = np.zeros((T, 1, 1))


実行

# カルマンフィルター

for t in range(T):
if t == 0:
m[t], C[t] = kalman_filter(m0, C0, y[t:t+1])
else:
m[t], C[t] = kalman_filter(m[t-1:t], C[t-1:t], y[t:t+1])

# カルマン予測
for t in range(K):
if t == 0:
a = G @ m[T-1:T]
R = G @ C[T-1:T] @ G.T + W
a_pred[t] = a
R_pred[t] = R
else:
a_pred[t], R_pred[t] = kalman_prediction(a_pred[t-1], R_pred[t-1])

# カルマン平滑化
for t in range(T):
t = T - t - 1
if t == T - 1:
s[t] = m[t]
S[t] = C[t]
else:
s[t], S[t] = kalman_smoothing(s[t+1], S[t+1], m[t], C[t])


結果

$Z_{0.05/2} \fallingdotseq 1.96$より、標準偏差に1.96を掛けたものを95%区間とします。


カルマンフィルタリング(t<=50)とカルマン予測(t>50)の結果をプロットします。

upper = 115

lower = 85
legend_loc = "lower left"

fig, axes = plt.subplots(nrows=3, figsize=(16, 12))
sns.lineplot(np.arange(T+K), x, ax=axes[0], label="true state")
sns.lineplot(np.arange(T+K), y, color="gray", ax=axes[0], label="observation")
sns.lineplot(np.arange(T), m.flatten(), color="red", ax=axes[0], label="kalman filter + prediction")
axes[0].plot(np.arange(T), (m - 1.96 * C[:,:,0]**(1/2)).flatten(), alpha=0.3, color='gray', label=".95 interval")
axes[0].plot(np.arange(T), (m + 1.96 * C[:,:,0]**(1/2)).flatten(), alpha=0.3, color='gray')
axes[0].plot(np.arange(T, T+K), a_pred.flatten(), color='red')
axes[0].plot(np.arange(T, T+K), (a_pred - 1.96 * R_pred[:,:,0]**(1/2)).flatten(), alpha=0.3, color='gray')
axes[0].plot(np.arange(T, T+K), (a_pred + 1.96 * R_pred[:,:,0]**(1/2)).flatten(), alpha=0.3, color='gray')
axes[0].axvline(100, color="black", linestyle="--", alpha=0.5, label="left: filtering, right: prediction")
axes[0].set_ylim(lower, upper)
axes[0].legend(loc=legend_loc)
axes[0].set_title("Kalman Filter + Prediction")

sns.lineplot(np.arange(T+K), x, ax=axes[1], label="true state")
sns.lineplot(np.arange(T+K), y, color="gray", ax=axes[1], label="observation")
sns.lineplot(np.arange(T), s.flatten(), color="green", ax=axes[1], label="kalman smoothing")
axes[1].plot(np.arange(T), (s - 1.96 * S[:,:,0]**(1/2)).flatten(), alpha=0.3, color='gray', label=".95 interval")
axes[1].plot(np.arange(T), (s + 1.96 * S[:,:,0]**(1/2)).flatten(), alpha=0.3, color='gray')
axes[1].axvline(100, color="black", linestyle="--", alpha=0.5, )
axes[1].set_ylim(lower, upper)
axes[1].legend(loc=legend_loc)
axes[1].set_title("Kalman Smoothing")

# sns.lineplot(np.arange(T+K), x, ax=axes[2], label="true state")
sns.lineplot(np.arange(T+K), y, color="gray", ax=axes[2], label="observation")
sns.lineplot(np.arange(T), m.flatten(), color="red", ax=axes[2], label="kalman filter")
sns.lineplot(np.arange(T), s.flatten(), color="green", ax=axes[2], label="kalman smoothing")
axes[2].axvline(100, color="black", linestyle="--", alpha=0.5)
axes[2].set_ylim(lower, upper)
axes[2].legend(loc=legend_loc)
axes[2].set_title("Kalman filter vs kalman smoothing")
plt.show()

print("カルマンフィルタリングの分散の平均: {:.3f}".format(C.mean()))

print("カルマン平滑化の分散の平均: {:.3f}".format(S.mean()))

青が真の状態、グレーが観測値、赤がカルマンフィルタ、緑がカルマン平滑化です。 カルマンフィルタ(赤)やカルマン平滑化(緑)と真の状態(青)がそれなりに近しい動き方をしていることがわかります。

観測誤差をある程度取り除いて状態を推定できました。

カルマンフィルタと比べ、カルマン平滑化の方が分散が小さくなっています。

カルマン予測は時点を追うごとに分散が大きくなっています。

過去と現在の情報から状態を推測するカルマンフィルタに比べ、カルマンフィルタの結果を1期先の平滑化分布で補正するカルマン平滑化の方が、よりなだらかです。

いかがでしたでしょうか


まとめ


  • カルマンフィルタは観測値$y_{1:t}$から状態$x_t$を求める手法

  • 観測誤差のノイズを取り除いて、状態を推定することができる

  • pythonで実装してみた

TO DO

もっと複雑なモデルに対する適用


参考文献

(萩原, 瓜生, 牧山, 2018) 基礎からわかる時系列分析