LoginSignup
1
3

More than 1 year has passed since last update.

R機械学習/カルマンフィルタによる状態空間モデル1

Last updated at Posted at 2022-05-09

Introduction

カルマンフィルタを使用した状態空間モデルをRで実装する方法についてのメモ.
今回は比較的単純なローカルレベルモデルを実装する.
参考図書:時系列分析と状態空間モデルの基礎: RとStanで学ぶ理論と実装

ライブラリインポート

library(KFAS)  # カルマンフィルタの実行
library(ggplot2)  # グラフ描画

データ準備

今回はRにデフォルトで用意されているNile(ナイル川の流量)データを使用.

  • データ概要:1871年~1970年のナイル川の年間流量
dataset = Nile
dataset

Out:
1.PNG

plot(dataset)

Out:
2.PNG

訓練データとテストデータに分割.

train = window(dataset, 1871, 1950)
test = window(dataset, 1951, 1970)

モデリング

ローカルレベルモデルは以下のように定式化される.

$$
\begin{eqnarray}
\mu_t &=& \mu_{t-1} + w_t, \ \ &w_t&~Normal(0,\sigma^2_w) \\
y_t &=& \mu_t + v_t, &v_t&~Normal(0,\sigma^2_v)
\end{eqnarray}
$$

  • $\mu_t$:状態
  • $y_t$:観測値
  • $w_t$:過程誤差
  • $v_t$:観測誤差

KFASパッケージでパラメタの推定からフィルタリングまで実行

# step1:モデルの構造を決める
build_kfas = SSModel(H=NA, train ~ SSMtrend(degree=1, Q=NA))
# step2:パラメタ推定
fit_kfas = fitSSM(build_kfas, inits=c(1,1))
# step3:フィルタリング
result_kfas = KFS(fit_kfas$model, filtering=c('state', 'mean'), smoothing=c('state', 'mean'))

モデル評価

訓練データにおけるフィルタ化推定量を図示.

# グラフ用データフレーム作成
df_kfas = data.frame(time=time(train), y=train, y_pred=result_kfas$a[-1])

ggplot(data=df_kfas, aes(x=time, y=y))+
  geom_point(alpha=0.5)+
  geom_line(aes(y=y_pred),size=1.2, col = 'red')

Out:黒点が観測値,赤線がフィルタ化推定量(フィルタリング後の状態)
3.PNG

フィルタ化推定量は現時点の観測値が得られた後に補正した値である.
一般的な機械学習のような過去の値のみから算出した予測値は以下の通り.

# グラフ用データフレーム作成
df_kfas2 = data.frame(time=time(train), y=train, y_pred=result_kfas$m[1:length(train)])
df_kfas2$y_pred[1] = NA

ggplot(data=df_kfas2, aes(x=time, y=y))+
  geom_point(alpha=0.5)+
  geom_line(aes(y=y_pred),size=1.2, col = 'red')

Out:黒点が観測値,赤線が予測値
6.png

未来予測

未来の値を20期先まで予測.
ローカルレベルモデルの場合状態は一定.

# 予測結果の算出
forecast_pred = predict(fit_kfas$model, interval='prediction', n.ahead=20)
# グラフ用データフレーム作成
df_kfas_pred = data.frame(time=time(dataset), y=dataset, y_pred=c(result_kfas$a[-1],forecast_pred[,1]))

ggplot(data=df_kfas_pred, aes(x=time, y=y))+
  geom_point(alpha=0.5)+
  geom_line(aes(y=y_pred),size=1.2, col = 'red') +
  geom_vline(xintercept= 1950.5, col = 'purple', linetype='dashed')

Out:紫垂直線の右がテストデータ
4.PNG

逐次更新

訓練データで推定したパラメータを使用し,テストデータにおいても状態の更新をしたい.
kfasパッケージにはそのような機能は無さそうなので,フィルタリングする関数を自作する.

カルマンフィルタ関数

kfLocallevel = function(y, mu_pre, p_pre, sigma_w, sigma_v) {
  " ローカルレベルモデルの予測とフィルタリングを実行
  parameters
  ----------
  y: numeric
    観測値
  mu_pre: numeric
    前期の状態
  p_pre: numeric
    前期の状態の予測誤差の分散(mu_preの分散)
  sigma_w: numeric 
    過程誤差の分散
  sigma_v: numeric
    観測誤差の分散
  Returns
  -------
  result: data.frame
  "
  mu_forecast = mu_pre  # 状態の予測値
  p_forecast = p_pre + sigma_w  # 状態の予測誤差の分散
  k_gain = p_forecast / (p_forecast+sigma_v)  # カルマンゲイン
  mu_filtered <- mu_forecast + k_gain * (y-mu_forecast)  # カルマンゲインで補正された状態
  p_filtered = (1-k_gain) * p_forecast  # フィルタ化後の状態の予測誤差の分散
  
  result = data.frame(mu_forecast, mu_filtered, p_filtered)
  
  return(result)
}

推定結果を格納する変数を準備

N = length(c(dataset))  # サンプルサイズ

mu_forecast = numeric(N+1) # 予測値
mu_filter = numeric(N+1)  # 状態の推定値を格納する変数
p_filter = numeric(N+1)  # 状態の予測誤差の分散を格納する変数

sigma_w = fit_kfas$model$Q  # 最尤法で推定した過程誤差の分散
sigma_v = fit_kfas$model$H  # 最尤法で推定した観測誤差の分散

mu_forecast[2] = dataset[1] # 1時点目のみ観測値を使用
mu_filter[2] = dataset[1] # 散漫初期化を用いており,1時点目の推定量は最初の観測値になる
p_filter[2] = sigma_v

フィルタリング

for(i in 2:N) {
  result = kfLocallevel(dataset[i], mu_filter[i], p_filter[i], sigma_w, sigma_v)
  mu_forecast[i+1] = result$mu_forecast
  mu_filter[i+1] = result$mu_filtered
  p_filter[i+1] = result$p_filtered
}
df_kfas_pred2 = data.frame(time=time(dataset), y=dataset, y_pred=mu_forecast[-1])

ggplot(data=df_kfas_pred2, aes(x=time, y=y))+
  geom_point(alpha=0.6)+
  geom_line(aes(y=y_pred),size=1.2, col = 'red') +
  geom_vline(xintercept= 1950.5, col = 'purple', linetype='dashed')

Out:紫垂直線の右がテストデータ
5.PNG

逐次更新2

kfasパッケージを使って逐次更新する方法について記す.
少し強引な方法ではあるがこちらの方が便利かも.

データ準備
テストデータのサイズ分訓練データにNAを代入

train2 = train
length(train2) = length(train) + length(test)
test2 = as.numeric(test)
train2

Out:
8.PNG

KFASパッケージでパラメタの推定まではこれまでと同様に行う.

# step1:モデルの構造を決める
build_kfas = SSModel(H=NA, train2 ~ SSMtrend(degree=1, Q=NA))
# step2:パラメタ推定
fit_kfas = fitSSM(build_kfas, inits=c(1,1))

パラメータ推定結果を格納したオブジェクトに訓練データとテストデータを結合したデータを格納

fit_kfas[["model"]][["y"]][1:length(train2)] = c(train,test2)
# step3,4:フィルタリング
result_kfas2 = KFS(fit_kfas$model, filtering=c('state', 'mean'), smoothing=c('state', 'mean'))

推定結果の図示

df_kfas3 = data.frame(time=time(dataset), y=dataset, y_pred=result_kfas2$m[1:length(dataset)])
df_kfas3$y_pred[1] = NA

ggplot(data=df_kfas3, aes(x=time, y=y))+
  geom_point(alpha=0.5)+
  geom_line(aes(y=y_pred),size=1.2, col = 'red') +
  geom_vline(xintercept= 1950.5, col = 'purple', linetype='dashed')

Out:紫垂直線の右がテストデータ
7.png

Conclusion

今回はローカルレベルモデルの構築,予測,逐次更新を実装した.
次回は説明変数を加えるなどの複雑なモデルを実装する.



code

# ライブラリインポート
library(KFAS)  # カルマンフィルタの実行
library(ggplot2)  # グラフ描画

# データ準備
dataset = Nile
dataset

plot(dataset)

train = window(dataset, 1871, 1950)
test = window(dataset, 1951, 1970)

# モデリング
build_kfas = SSModel(H=NA, train ~ SSMtrend(degree=1, Q=NA))

fit_kfas = fitSSM(build_kfas, inits=c(1,1))

result_kfas = KFS(fit_kfas$model, filtering=c('state', 'mean'), smoothing=c('state', 'mean'))


# モデル評価
df_kfas = data.frame(time=time(train), y=train, y_pred=result_kfas$a[-1])

ggplot(data=df_kfas, aes(x=time, y=y))+
  geom_point(alpha=0.5)+
  geom_line(aes(y=y_pred),size=1.2, col = 'red')


df_kfas2 = data.frame(time=time(train), y=train, y_pred=result_kfas$m[1:length(train)])
df_kfas2$y_pred[1] = NA

ggplot(data=df_kfas2, aes(x=time, y=y))+
  geom_point(alpha=0.5)+
  geom_line(aes(y=y_pred),size=1.2, col = 'red')

# 未来予測

forecast_pred = predict(fit_kfas$model, interval='prediction', n.ahead=20)

df_kfas_pred = data.frame(time=time(dataset), y=dataset, y_pred=c(result_kfas$a[-1],forecast_pred[,1]))

ggplot(data=df_kfas_pred, aes(x=time, y=y))+
  geom_point(alpha=0.5)+
  geom_line(aes(y=y_pred),size=1.2, col = 'red') +
  geom_vline(xintercept= 1950.5, col = 'purple', linetype='dashed')

# 逐次更新
kfLocallevel = function(y, mu_pre, p_pre, sigma_w, sigma_v) {
  " ローカルレベルモデルの予測とフィルタリングを実行
  parameters
  ----------
  y: numeric
    観測値
  mu_pre: numeric
    前期の状態
  p_pre: numeric
    前期の状態の予測誤差の分散(mu_preの分散)
  sigma_w: numeric 
    過程誤差の分散
  sigma_v: numeric
    観測誤差の分散
  Returns
  -------
  result: data.frame
  "
  mu_forecast = mu_pre  # 状態の予測値
  p_forecast = p_pre + sigma_w  # 状態の予測誤差の分散
  k_gain = p_forecast / (p_forecast+sigma_v)  # カルマンゲイン
  mu_filtered <- mu_forecast + k_gain * (y-mu_forecast)  # カルマンゲインを使って補正された状態
  p_filtered = (1-k_gain) * p_forecast  # フィルタ化後の状態の予測誤差の分散
  
  result = data.frame(mu_forecast, mu_filtered, p_filtered)
  
  return(result)
}

N = length(c(dataset))  # サンプルサイズ

mu_forecast = numeric(N+1) # 予測値
mu_filter = numeric(N+1)  # 状態の推定値を格納する変数
p_filter = numeric(N+1)  # 状態の予測誤差の分散を格納する変数

sigma_w = fit_kfas$model$Q  # 最尤法で推定した過程誤差の分散
sigma_v = fit_kfas$model$H  # 最尤法で推定した観測誤差の分散

mu_forecast[2] = dataset[1] # 1時点目のみ観測値を使用
mu_filter[2] = dataset[1] # 散漫初期化を用いており,1時点目の推定量は最初の観測値になる
p_filter[2] = sigma_v

for(i in 2:N) {
  result = kfLocallevel(dataset[i], mu_filter[i], p_filter[i], sigma_w, sigma_v)
  mu_forecast[i+1] = result$mu_forecast
  mu_filter[i+1] = result$mu_filtered
  p_filter[i+1] = result$p_filtered
}

df_kfas_pred2 = data.frame(time=time(dataset), y=dataset, y_pred=mu_forecast[-1])

ggplot(data=df_kfas_pred2, aes(x=time, y=y))+
  geom_point(alpha=0.6)+
  geom_line(aes(y=y_pred),size=1.2, col = 'red') +
  geom_vline(xintercept= 1950.5, col = 'purple', linetype='dashed')

# 逐次更新2
train2 = train
length(train2) = length(train) + length(test)
test2 = as.numeric(test)
train2

build_kfas = SSModel(H=NA, train2 ~ SSMtrend(degree=1, Q=NA))
fit_kfas = fitSSM(build_kfas, inits=c(1,1))

fit_kfas[["model"]][["y"]][1:length(train2)] = c(train,test2)
result_kfas2 = KFS(fit_kfas$model, filtering=c('state', 'mean'), smoothing=c('state', 'mean'))

df_kfas3 = data.frame(time=time(dataset), y=dataset, y_pred=result_kfas2$m[1:length(dataset)])
df_kfas3$y_pred[1] = NA

ggplot(data=df_kfas3, aes(x=time, y=y))+
  geom_point(alpha=0.5)+
  geom_line(aes(y=y_pred),size=1.2, col = 'red') +
  geom_vline(xintercept= 1950.5, col = 'purple', linetype='dashed')
1
3
1

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
3