3
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

変化点抽出問題

Posted at

1.はじめに

何回か変化点抽出の問題を扱いましたが、(「Pythonで体験するベイズ推論」 キャメロン・デビッドソン=ピロン(著)/玉木徹(訳) 森北出版 )の中でも、PyMC3を用いた変化点抽出の例題がでてきます。

2.変化点抽出の例題

output.png
メールの着信数に変化があったかどうかを検出する問題です
メール受信数にポアソン分布を仮定します。
ポアソン分布の確率パラメータ:ラムダを仮定します。

$$ C_i \sim \text{Poisson}(\lambda) $$

時点タウで確率の大きさがラムダ1からラムダ2に切り替わると仮定します。

$$
\lambda =
\begin{cases}
\lambda_1 & \text{if } t \lt \tau \cr
\lambda_2 & \text{if } t \ge \tau
\end{cases}
$$

ラムダは正であるので、それぞれ指数分布に従うと仮定します。
指数分布にはパラメータアルファを指定します。

$$
\begin{align}
&\lambda_1 \sim \text{Exp}( \alpha ) \
&\lambda_2 \sim \text{Exp}( \alpha )
\end{align}
$$

PyMC3を使ってベイズ推論します。

output.png

45日目でタウの確率があがり、メール受信数は平均約18から約23に変化したことが推測されました。

3.STANによる実装

ここではPyMC3で実装していましたが、Rを使ってstanでも同様に実装してみます。
STANの場合、上記のような離散値をパラメータに使えないので、工夫が必要です。
ここでは、stanのユーザーガイドを参考にしました。

count_mail = data.frame(mail = c(24,  8, 24,  7, 35, 14, 11, 15, 11, 22, 22, 11, 57, 11, 19, 29,  6, 19,12, 22, 12, 18, 72, 32,  9,  7, 13, 19, 23, 27, 20,  6, 17, 13, 10, 14,6, 16, 15,  7,  2, 15, 15, 19, 70, 49,  7, 53, 22, 21, 31, 19, 11, 18,20, 12, 35, 17, 23, 17,  4,  2, 31, 30, 13, 27,  0, 39, 37,  5, 14, 13,22),day=1:73) 
ggplot(count_mail,aes(x=day,y=mail)) + geom_bar(stat="identity") 
    + ylab("count of mail received") + theme_bw()

image.png

変化点の日をcpとし、離散値になるのでパラメータの周辺化消去します。

swich_model1.stan として保存します。

 data{ 
    int N; 
    int Y[N]; 
  } 
parameters { 
    real <lower=0> lamda_1; 
    real <lower=0> lamda_2; 
    real <lower=0> alpha; 
  } 
   
transformed parameters{ 
  vector[N] lp; 
  lp = rep_vector(0,N);   // 0をN個複製してベクトル化
   
  for(cp in 1:N) 
    for(t in 1:N) 
      lp[cp] =  lp[cp] + poisson_lpmf(Y[t] | (t < cp ? lamda_1:lamda_2));  
                                        //  if_else(t < cp,lamda_1,lamda_2)と同じ
}      

model{ 
    lamda_1 ~ exponential(alpha); 
    lamda_2 ~ exponential(alpha); 
    target += log_sum_exp(lp); 
 }

ポイントは、時点のlpを2重ループしているところで、
次のように、
時点1 LRRRR
時点2 LLRRR
時点3 LLLRR
時点4 LLLLR
・・・

時点が一つ増えるごとに、lamdaが変わっていくベクトルができあがります。
変化点について、すべての時点におけるパターンを作ることで、変化点cpを周辺化消去しています。

ただ、上記のコードは2重ループで計算の負荷がかかっているため、より計算を効率的する隠れマルコフモデルをもとにした前進後退アルゴリズムを使ってみます。

swich_model2.stan として保存します。

data{
    int N;
    int Y[N];
  }

parameters {
    real <lower=0> lamda_1;
    real <lower=0> lamda_2;
    real <lower=0> alpha;
  }
  
transformed parameters{
  vector[N] lp;
  vector[N+1] lp_1;
  vector[N+1] lp_2;
  lp_1[1] = 0;
  lp_2[1] = 0;
  
  for(cp in 1:N){
    lp_1[cp+1] =  lp_1[cp] + poisson_lpmf(Y[cp] | lamda_1);
    lp_2[cp+1] =  lp_2[cp] + poisson_lpmf(Y[cp] | lamda_2);
  }   
    lp = rep_vector(lp_1[N+1],N) - head(lp_1,N) + head(lp_2,N);
}        
   
model{
  
    lamda_1 ~ exponential(alpha);
    lamda_2 ~ exponential(alpha);
    target += log_sum_exp(lp);
  }

先にlp_1で作成したベクトルに、lp_2で時点毎に置き換えていくイメージです。
swich_model1と同じことをしていますが、直接、ベクトルで計算しています。

計算は20倍以上高速化します。

3.Rで実行

次のコードでstanをキックします。

N = nrow(count_mail) 
Y = count_mail$mail 
data = list(N=N,Y=Y) 
fit1 = stan(file="swich_model1.stan" ,data=data,seed=12) 
fit2 = stan(file="swich_model1.stan" ,data=data,seed=12)

計算時間ですが、
fit1で 31.947 seconds (Total)
fit2 で 1.058 seconds (Total)
となり,fit2が30倍ぐらい早くなりました。
ちなみにpyMC3では約1分30秒かかりました。stanは高速なのがわかります。

それでは結果です。

library(bayesplot) 
mcmc_dens(fit1 ,pars = c("lamda_1","lamda_2")) 
mcmc_fit = rstan::extract(fit1) 
q = data.frame(p=exp(apply(mcmc_fit$lp ,2, median))) 
q = q/sum(q) 
q$day = 1:nrow(q) 
ggplot(q,aes(x=day,y=p)) + geom_bar(stat = "identity") + theme_bw() 
which.max(q$p)

image.png

image.png

[1] 45


PyMC3との結果とほぼ同じになりました。

4.シミュレーションモデルによる試行

シミュレーションデータで同じことを試してみます。
100日間で変化点は65時点
mailの受信数は20から15に減少します。

cp = 65 
N = 100 
data_l = rpois(cp,20) 
data_r = rpois(100-cp,16) 
data_all = data.frame(mail=c(data_l,data_r),day=1:100) 
ggplot(data_all,aes(x=day,y=mail)) + geom_bar(stat = "identity") + theme_bw()

image.png
image.png
image.png

変化点64、ラムダの推定値もやや悪いです。これぐらいの変化の差だとあまり推定がよくないのかも。

今度は、変化点がない場合、どうなるかを見てみます。
mailの受信数は20から変化していないです。

cp = 65 
N = 100 
data_l = rpois(cp,20) 
data_r=rpois(100-cp,20) 
data_all = data.frame(mail=c(data_l,data_r),day=1:100) 
ggplot(data_all,aes(x=day,y=mail)) + geom_bar(stat = "identity") + theme_bw()

image.png

その結果です。
image.png
image.png

変化点1が最も高くなり、lamda_2がだいたい20ぐらいで推定できました。

最後に、変化点が3つの場合を試してみます。

cp1 = 30 
cp2 = 65 
N = 100 
data_l = rpois(cp1,30) 
data_m = rpois(cp2-cp1,25) 
data_r = rpois(100-cp2,20) 
data_all = data.frame(mail=c(data_l,data_m,data_r),day=1:100) 
ggplot(data_all,aes(x=day,y=mail)) + geom_bar(stat = "identity") + theme_bw()

image.png
シミュレーションデータの変化点は30,65の2か所です。
mailの受信数は40から30、そして20に減少します。

2地点の変化点抽出のためのstanのコードは次のとおりです。

swich_model3.stan

data{ 
    int N; 
    int Y[N]; 
  } 

parameters { 
    real <lower=0> lamda_1; 
    real <lower=0> lamda_2; 
    real <lower=0> lamda_3; 
    real <lower=0> alpha; 
  } 
   
transformed parameters{ 
  matrix[N,N] lp; 
  lp = rep_matrix(0,N,N); 
   
  for(cp1 in 1:N) 
   for(cp2 in 1:N) 
    for(t in 1:N) 
      lp[cp1,cp2] =  lp[cp1,cp2] +  
      poisson_lpmf(Y[t] | t < cp1 ? lamda_1:(t < cp2 ? lamda_2:lamda_3)); 
}      

model{    
    lamda_1 ~ exponential(alpha); 
    lamda_2 ~ exponential(alpha); 
    lamda_3 ~ exponential(alpha); 
    target += log_sum_exp(lp); 
  }

しかし、このコードの実行は恐ろしく時間がかかりました(約2時間)
計算量が100の3乗になっています。
λの推定です。
image.png
lamda_1がおおむね40,lamda_2がおおむね30、lamda_3がおおむね20で推定できています。
次に変化点の推定です。
image.png
image.png
最初の変化点の確率が高すぎて、2番目の変化点は小さな確率の変化になりました。
おおむね30と65で推定ができているとみてよいかな?

5.参考

Stan User’s Guide
累積和を使って計算の無駄を省く(変化点検出の例)

Enjoy❕

3
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
3
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?