0
1

長期トレンド付きガウス過程で東京のコロナ感染者数を予測してみた

Posted at

はじめに

こんにちは、事業会社で働いているデータサイエンティストです。

詳細は言えませんが、現在会社で取り組んでいるコミュニケーション設計系のタスクでガウス過程を利用しています。

今回の記事では長期トレンド付きのガウス過程の概念を説明して、実際に東京都のコロナウイルス感染者数の予測で性能を確認します。

データかこちらからダウンロードできます。

本記事のデータは日本時間の2024年2月8日午後8時15分にダウンロードしました。

トレンド付きガウス過程の概念

トレンド付きのガウス過程はDew and Ansari(2018)から得た着想です。

簡潔に説明すると、普通のトレンドのないガウス過程も、確かに表現力が強く、予測性能も悪くないが、(時系列予測タスクでいうと)未来の予測になればなるほど、予測に使える情報がどんどん減り、最終的に予測値が平均、つまり多くの場合ゼロに戻ってしまう問題があります。

つまり、長期予測には向いてません。

このゼロに回帰(?)してしまう現象に対処する方法として、ガウス過程の平均をゼロと置くのではなく、自体もきちんとモデル化しましょうという発想が提案されました。

関係ない話ですが、ガウス過程と別のベイズ機械学習を組み合わせた新しい手法も以前の記事で提案したのでもし興味ありましたらぜひ読んでみてください:

モデル説明

ここでは、モデルの説明に入ります。

まず、パラメーターの事前分布は

$$ \alpha \sim Normal(0,1)$$

$$ \sigma \sim Inverse\space Gamma(5,5)$$

$$ \lambda_{1} \sim Normal(0,5)$$

$$ \lambda_{2} \sim Beta(10,10)$$

$$ \rho_{trend} \sim Gamma(0.01, 0.01)$$

$$ \alpha_{trend} \sim Normal(0,1)$$

全てのt(時間)について

$$ m_{t} = \lambda_{1} * (t - 1) ^ {\lambda_{2}} $$

$$ f_{trend} \sim Gaussian\space Process(mean=m,magnitude=\alpha_{trend},length\space scale = \rho_{trend})$$

ただし、Gaussian Processは指数2次カーネルを共分散関数とするガウス過程で、パラメーターは引数のところで示された通りです。

最後に、観測値は

$$ Covid_{東京} \sim Laplace(\alpha + f_{trend}, \sigma) $$

で生成されるとモデル化します。

ここでは、ガウス過程の平均、つまり長期トレンドを

$$ m_{t} = \lambda_{1} * (t - 1) ^ {\lambda_{2}} $$

で定式化して、$\lambda_{1}$と$\lambda_{2}$をデータから推定します。

ただ、ここで強調しておきたいのは、Dew and Ansariの論文では、$\lambda_{2}$は正の実数という制約しかかけていないですが、実際に様々な場面で利用すると、特にデータが足りないなどの際に、$\lambda_{2}$が1を超えることがあります。$\lambda_{2}$が1を超えると予測結果が指数関数的に成長してしまいます。

なので、ここでは$\lambda_{2}$の範囲を0から1までに制限し、正規分布の代わりにベータ分布を事前分布として設定しました。

もちろん、元々指数関数的に成長するデータはあると思いますが、少なくとも私がいる求人業界(日本の就職希望者数が指数関数的に成長したらちょっと怖いです、、、)とコロナウイルスの感染状況という二つのドメインにおいて、指数関数的な成長は妥当な仮定ではありません。

ご自身のデータの特性を理解して、$\lambda_{2}$の妥当な範囲、ないしはより妥当な長期トレンドの定式化をぜひ考えてみてください。

観測値が従う関数はあえて正規分布ではなく、ラプラス分布に設定しました。結果が変わるかは未検証ですが、ガウス過程はあくまでも時系列が従うトレンドを推定する手法だけで、実際のデータが従う分布(深層学習的にいうと出力ですかね?)は別に正規分布である必要は全くありません。

Stanのコードはこちらです。reduce_sum関数で並列化しています。

dp.stan
functions {
  real partial_sum_lpdf(
    array[] real covid,
    
    int start, int end,
    
    real alpha,
    real sigma,
    vector f_trend
  ){
    return double_exponential_lupdf(covid | alpha + f_trend[start:end], sigma);
  }
}
data {
  int<lower=1> N_train;
  int<lower=0> N_predict;
  
  array[N_train + N_predict] int<lower=1,upper=N_train+N_predict> N_seq;
  
  array[N_train] real covid;
}
parameters {
  real alpha;
  
  real<lower=0> sigma;
  
  real lambda_1;
  real<lower=0,upper=1> lambda_2;
  
  real<lower=0> rho_trend;
  real<lower=0> alpha_trend;
  vector[N_train + N_predict] eta_trend;
}
transformed parameters {
  vector[N_train + N_predict] m_full;
  vector[N_train + N_predict] f_trend;
  {
    matrix[N_train + N_predict, N_train + N_predict] K_trend = gp_exp_quad_cov(N_seq, alpha_trend, rho_trend);
    
    matrix[N_train + N_predict, N_train + N_predict] L_K_trend;
    
    for (n in 1:(N_train + N_predict)){
      m_full[n] = lambda_1 * (n - 1) ^ lambda_2;
      K_trend[n, n] = K_trend[n, n] + 0.001;
    }
    
    L_K_trend = cholesky_decompose(K_trend);
    
    f_trend = m_full + L_K_trend * eta_trend;
  }
}
model {
  alpha ~ std_normal();
  
  sigma ~ inv_gamma(5, 5);
  
  lambda_1 ~ normal(0, 5);
  lambda_2 ~ beta(10, 10);
  
  rho_trend ~ gamma(0.01, 0.01);
  alpha_trend ~ std_normal();
  eta_trend ~ std_normal();
  
  int grainsize = 1;
  
  target += reduce_sum(
    partial_sum_lupdf, covid,
    
    grainsize,
    
    alpha,
    sigma,
    f_trend[1:N_train]
  );
}
generated quantities {
  array[N_train + N_predict] real covid_predicted;
  
  covid_predicted = double_exponential_rng(alpha + f_trend, sigma);
}

前処理

まず、コロナウイルスの感染状況のデータを読み込みます。

covid_df <- readr::read_csv("newly_confirmed_cases_daily.csv")

次に、Stanが利用するデータに整形するが、ここでは少しデータの切り方の話をします。

  • 学習データ:2020/1/16 ~ 2023/4/8
  • 検証データ:2023/4/9 ~ 2023/5/8
  • 予測期間:2023/5/9 ~ 2023/7/7

また、実際にデータをbase Rのplotで簡単に可視化すればわかるように、

plot(covid_df$Tokyo)

row.png

コロナの感染者数データはスケールの変動が大きすぎるので、まず対数を取った上で標準化する工夫もしています。

実際に前処理されたデータはこんな形になります:

plot(((log(covid_df$Tokyo + 1) - mean(log(covid_df$Tokyo + 1)))/sd(log(covid_df$Tokyo + 1))))

transformed.png

より変動が見やすくなったと思います。

次に、データを整理して

data_list <- list(
  N_train = nrow(covid_df) - 30,
  N_predict = 90,
  
  N_seq = seq_len(nrow(covid_df) + 60),
  
  covid = ((log(covid_df$Tokyo + 1) - mean(log(covid_df$Tokyo + 1)))/sd(log(covid_df$Tokyo + 1)))[seq_len(nrow(covid_df) - 30)]
)

推定を開始します。

> m_gp_estimate <- m_gp_init$variational(
     seed = 12345,
     threads = 16,
     iter = 10000,
     data = data_list
 )
------------------------------------------------------------ 
EXPERIMENTAL ALGORITHM: 
  This procedure has not been thoroughly tested and may be unstable 
  or buggy. The interface is subject to change. 
------------------------------------------------------------ 
Gradient evaluation took 0.10868 seconds 
1000 transitions using 10 leapfrog steps per transition would take 1086.8 seconds. 
Adjust your expectations accordingly! 
Begin eta adaptation. 
Iteration:   1 / 250 [  0%]  (Adaptation) 
Iteration:  50 / 250 [ 20%]  (Adaptation) 
Iteration: 100 / 250 [ 40%]  (Adaptation) 
Iteration: 150 / 250 [ 60%]  (Adaptation) 
Iteration: 200 / 250 [ 80%]  (Adaptation) 
Success! Found best value [eta = 1] earlier than expected. 
Begin stochastic gradient ascent. 
  iter             ELBO   delta_ELBO_mean   delta_ELBO_med   notes  
   100        -1211.741             1.000            1.000 
   200         -684.854             0.885            1.000 
   300         -537.415             0.681            0.769 
   400         -466.635             0.549            0.769 
   500         -643.440             0.494            0.275 
   600         -464.546             0.476            0.385 
   700         -484.010             0.414            0.275 
   800         -450.159             0.371            0.275 
   900         -478.353             0.337            0.274 
  1000         -501.325             0.308            0.274 
  1100         -420.527             0.227            0.192 
  1200         -426.806             0.151            0.152 
  1300         -393.466             0.132            0.085 
  1400         -398.765             0.118            0.075 
  1500         -410.097             0.094            0.059 
  1600         -371.420             0.066            0.059 
  1700         -496.720             0.087            0.075 
  1800         -431.418             0.095            0.085 
  1900         -358.470             0.109            0.104 
  2000         -381.587             0.110            0.104 
  2100         -421.901             0.101            0.096 
  2200         -547.550             0.122            0.104 
  2300         -506.376             0.122            0.104 
  2400         -391.353             0.150            0.151 
  2500         -361.666             0.155            0.151 
  2600         -355.207             0.147            0.151 
  2700         -391.529             0.131            0.096 
  2800         -332.371             0.134            0.096 
  2900         -465.207             0.142            0.096 
  3000         -371.930             0.161            0.178 
  3100         -337.925             0.161            0.178 
  3200         -343.423             0.140            0.101 
  3300         -329.295             0.136            0.101 
  3400         -302.571             0.116            0.093 
  3500         -298.738             0.109            0.093 
  3600         -301.200             0.108            0.093 
  3700         -483.964             0.136            0.101 
  3800         -295.876             0.182            0.101 
  3900         -301.637             0.155            0.088 
  4000         -306.301             0.132            0.043 
  4100         -287.142             0.128            0.043 
  4200         -376.748             0.150            0.067 
  4300         -293.739             0.174            0.088 
  4400         -288.763             0.167            0.067 
  4500         -286.570             0.167            0.067 
  4600         -405.854             0.195            0.238 
  4700         -377.994             0.165            0.074 
  4800         -267.883             0.143            0.074 
  4900         -285.193             0.147            0.074 
  5000         -251.815             0.158            0.133 
  5100         -275.516             0.160            0.133 
  5200         -330.354             0.153            0.133 
  5300         -532.783             0.163            0.133 
  5400         -304.159             0.236            0.166 
  5500         -258.032             0.253            0.179 
  5600         -250.857             0.227            0.166 
  5700         -317.625             0.241            0.179 
  5800         -490.972             0.235            0.179 
  5900         -255.467             0.321            0.210 
  6000         -225.316             0.321            0.210 
  6100         -240.005             0.319            0.210 
  6200         -234.323             0.304            0.210 
  6300         -228.640             0.269            0.179 
  6400         -236.288             0.197            0.134 
  6500         -255.966             0.187            0.077 
  6600         -269.181             0.189            0.077 
  6700         -315.333             0.182            0.077 
  6800         -296.702             0.153            0.063 
  6900         -204.903             0.106            0.063 
  7000         -277.256             0.119            0.063 
  7100         -180.391             0.166            0.077 
  7200         -241.599             0.189            0.146 
  7300         -399.495             0.226            0.253 
  7400         -217.944             0.306            0.261 
  7500         -241.037             0.308            0.261 
  7600         -221.285             0.312            0.261 
  7700         -286.551             0.320            0.261 
  7800         -188.636             0.366            0.395 
  7900         -186.027             0.323            0.261 
  8000         -179.955             0.300            0.253 
  8100         -278.231             0.281            0.253 
  8200         -247.340             0.269            0.228 
  8300         -180.937             0.266            0.228 
  8400         -316.503             0.225            0.228 
  8500         -402.381             0.237            0.228 
  8600         -188.846             0.341            0.353 
  8700         -215.414             0.331            0.353 
  8800         -175.064             0.302            0.230 
  8900         -238.166             0.327            0.265 
  9000         -211.313             0.336            0.265 
  9100         -265.940             0.322            0.230 
  9200         -193.818             0.346            0.265 
  9300         -194.630             0.310            0.230 
  9400         -226.776             0.281            0.213 
  9500         -250.551             0.269            0.205 
  9600         -219.809             0.170            0.142 
  9700         -168.479             0.189            0.205 
  9800         -178.103             0.171            0.142 
  9900         -447.049             0.205            0.142 
  10000         -168.561             0.357            0.205 
Informational Message: The maximum number of iterations is reached! The algorithm may not have converged. 
This variational approximation is not guaranteed to be meaningful. 
Drawing a sample of size 1000 from the approximate posterior...  
COMPLETED. 
Finished in  1334.6 seconds.

10000回のイテレーションでも収束しなかったようですが、とりあえず結果を見てみましょう。

推定結果

まず、推定されたトレンドの形から確認しましょう。

m_gp_summary |>
  dplyr::filter(stringr::str_detect(variable, "m_full")) |>
  dplyr::bind_cols(
    date = seq(as.Date(covid_df$Date[1]), length = nrow(covid_df) + 60, by = "day")
  ) |>
  ggplot2::ggplot() +
  ggplot2::geom_line(ggplot2::aes(x = date, y = mean), color = "black") +
  ggplot2::geom_ribbon(ggplot2::aes(x = date, ymin = q5, ymax = q95), 
                       fill = ggplot2::alpha("blue", 0.3)) + 
  ggplot2::xlab("日付") + 
  ggplot2::ylab("トレンド") + 
  ggplot2::ggtitle("東京都コロナ感染者数予測") + 
  ggplot2::theme_gray(base_family = "HiraKakuPro-W3") + 
  ggplot2::scale_x_date(date_breaks = "90 day")

long_trend.png

$\lambda_{2}$の範囲を0から1までに制限したモデル定式化の影響もありますが、緩やかな上昇傾向にあることがわかります。

続いては、前処理されたデータとモデルの予測結果を比較しましょう:

m_gp_summary |>
  dplyr::filter(stringr::str_detect(variable, "covid_predicted")) |>
  dplyr::bind_cols(
    covid = c(
      ((log(covid_df$Tokyo + 1) - mean(log(covid_df$Tokyo + 1)))/sd(log(covid_df$Tokyo + 1))),
      rep(0, 60)
      ),
    date = seq(as.Date(covid_df$Date[1]), length = nrow(covid_df) + 60, by = "day")
    ) |>
  dplyr::mutate(
    id = dplyr::row_number(),
    データ種類 = dplyr::case_when(
      id <= nrow(covid_df) - 30 ~ "学習データ",
      TRUE ~ "検証データ"
    )
  ) |> 
  ggplot2::ggplot() +
  ggplot2::geom_line(ggplot2::aes(x = date, y = mean), color = "black") +
  ggplot2::geom_point(ggplot2::aes(x = date, y = covid, color = データ種類)) +
  ggplot2::geom_ribbon(ggplot2::aes(x = date, ymin = q5, ymax = q95), 
                       fill = ggplot2::alpha("blue", 0.3)) + 
  ggplot2::xlab("日付") + 
  ggplot2::ylab("感染者数(対数標準化)") + 
  ggplot2::ggtitle("東京都コロナ感染者数予測") + 
  ggplot2::theme_gray(base_family = "HiraKakuPro-W3") + 
  ggplot2::scale_colour_manual(
    values = c(
      学習データ = "blue",
      検証データ = "red"
    )
  ) + 
  ggplot2::scale_x_date(date_breaks = "90 day")

logged_pred.png

ただし、ゼロになっている赤い点は検証データにも含まれていない、本当の意味での予測区間になります。

まず、学習データの期間のデータの動きをガウス過程がうまく表現できているだけでなく、2023/4/9 ~ 2023/5/8の検証データ期間の着地(ビジネス用語すぎてごめんなさい、、、)もおおむねベイズ信用区間に入っています。

さらに、トレンドを定式化しているため、最後の学習データから離れたところにある2023/7/7あたりの予測結果も、きちんと上昇傾向になっていて、普通のガウス過程なら起こりそうなゼロへの回帰は起きていません。ゼロの代わりに推定されたトレンドのところに回帰しています。

なので、このモデルは完璧とは言えないが、信憑性のある予測を提供できると思われます。

最後に、モデルの予測結果を前処理前に戻し、生データと比較してみます:

m_gp_summary |>
  dplyr::filter(stringr::str_detect(variable, "covid_predicted")) |>
  dplyr::mutate(
    mean = exp(mean * sd(log(covid_df$Tokyo + 1)) + mean(log(covid_df$Tokyo + 1))),
    q5 = exp(q5 * sd(log(covid_df$Tokyo + 1)) + mean(log(covid_df$Tokyo + 1))),
    q95 = exp(q95 * sd(log(covid_df$Tokyo + 1)) + mean(log(covid_df$Tokyo + 1)))
  ) |>
  dplyr::bind_cols(covid = c(
    covid_df$Tokyo,
    rep(0, 60)
    ),
    date = seq(as.Date(covid_df$Date[1]), length = nrow(covid_df) + 60, by = "day")
    ) |>
  dplyr::mutate(
    id = dplyr::row_number(),
    データ種類 = dplyr::case_when(
      id <= nrow(covid_df) - 30 ~ "学習データ",
      TRUE ~ "検証データ"
    )
  ) |> 
  ggplot2::ggplot() +
  ggplot2::geom_line(ggplot2::aes(x = date, y = mean), color = "black") +
  ggplot2::geom_point(ggplot2::aes(x = date, y = covid, color = データ種類)) +
  ggplot2::geom_ribbon(ggplot2::aes(x = date, ymin = q5, ymax = q95), 
                       fill = ggplot2::alpha("blue", 0.3)) + 
  ggplot2::xlab("日付") + 
  ggplot2::ylab("感染者数") + 
  ggplot2::ggtitle("東京都コロナ感染者数予測") + 
  ggplot2::theme_gray(base_family = "HiraKakuPro-W3") + 
  ggplot2::scale_colour_manual(
    values = c(
      学習データ = "blue",
      検証データ = "red"
    )
  ) + 
  ggplot2::scale_x_date(date_breaks = "90 day")

raw_pred.png

ここでも性能が悪くないことがわかります。

結論

ここでは長期トレンド付きのガウス過程の性能を見てみました。もちろん、厳密にいうとトレンドなしのガウス過程、さらには観測値が正規分布に従うガウス過程とも比較しないといけないですが、これは今後別の記事でまとめます。

また、Stanの変分推論は収束していなくても、推定されたパラメーターが実務上使える場合があることも確認できました。

皆さんもぜひガウス過程という柔軟なベイズ機械学習の手法を活用してください!

参考文献

Dew, Ryan and Asim Ansari (2018) Bayesian Nonparametric Customer Base Analysis with Model-Based Visualizations. Marketing Science 37(2):216-235.

0
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
1