TL;DR
- ICLR2020 に Accept されていた時系列予測アルゴリズム N-BEATS を試した
- N-BEATS は時系列予測ライブラリ GluonTS から簡単に利用可能
- Beijing PM2.5 データセットの予測タスクに使ってみたところ、ある程度の短期予測ができることを確認
- SageMaker から N-BEATS を実行するスクリプトはこちら
N-BEATS
N-BEATSは、Yoshua Bengio 先生と Element AI (カナダの AI スタートアップ) で開発された新しい時系列予測アルゴリズムで、ICLR2020 に Accept されています。
B.N. Oreshkin, D. Carpov, N. Chapados, Y. Bengio
N-BEATS: Neural basis expansion analysis for interpretable time series forecasting, ICLR 2020.
https://arxiv.org/abs/1905.10437
Abstract に書かれていますが、M4 という時系列データのコンペティションで昨年の Winner の結果よりも 3% 精度が良いらしいです。詳細は 3章に書かれていますが、ざっと説明していきましょう。以下は、論文中に示されている全体像です。
Forecast と Backcast
予測モデルを構築する際、入力 $\boldsymbol{x}$ にもとづく予測 $\hat{\boldsymbol{y}}$ と真値 $\boldsymbol{y}$ を比較して、予測がなるべく真値に近いように学習をしていくのが一般的です。それに加えて、N-BEATS では、入力 $\boldsymbol{x}$ から入力そのものに対する予測 $\hat{\boldsymbol{x}}$ を行う Backcast を採用しています。Backcast は Auto-Encoder と同じような構成ですね。$\hat{\boldsymbol{x}}$ を正しく予測できている (Backcast が上手く行っている) ということは、その Block は入力 $\boldsymbol{x}$ をよく近似しているので、その後の Block では、近似できなかった分だけを予測できるようにします。ここで導入されるのが、次に説明する Residual です。
Residual
Residual を取り入れるアプローチは ResNet 以降、DenseNet など画像認識系のネットワークでは採用されているケースが多いですね。大きなネットワークで学習していくと、精度が上がってきたときに、Residual が小さくなってくるため、勾配消失などの問題が発生しやすいです。そこで画像認識系のネットワークでは、この残差自体を予測しようとするアプローチがとられています。
N-BEATS ではちょっと違った方法をとっています。Backcast のところで説明しましたが、 Residual を予測しているわけではなく、以下の式に示すように予測値から Residual を計算しており、Block の近似の良さを表す指標として使っています。Backcast がうまくいって Residual が 0 になれば、以降の Block の予測を簡単にすることができます。予測自体は、以下の式のように、各 Block からの予測値の総和として求めます。各 Block は、それより上の Block で予測できなかった部分を補完しており、それらを総和すれば良いという考え方です。
$$
{\boldsymbol{x}}_\ell = {\boldsymbol{x}} _{\ell-1} - \hat{{\boldsymbol{x}}} _{\ell-1}, \ \ {\boldsymbol{y}} = \sum _{\ell} \hat{{\boldsymbol{y}}} _{\ell}
$$
説明可能モデル
ここは説明を省略しますが、各 Block は線形モデル (全結合) なので、その係数を見れば、どこが予測に反応しているのかを確認することができます。例えば、ちょうど1週間前の過去の数値に対する係数が大きければ、その過去の数値が予測に貢献していると考えられます。より周期的な解析をしたい場合は、フーリエ級数展開を行って、そのフーリエ係数を確認するという方法が提案されています。
Ensembling
論文の中でも書かれていますが、ぶっちゃけコンペで勝とうと思うと、コンペで多用されている Ensembling は必要不可欠とのことです。ここでは2種類の Ensembling を行っています。
- 複数種類のメトリクスを利用 (sMAPE, MASE, MAPE)
- 過去のデータを利用する期間を変える (2H, ..., 7H; H は予測の期間)
最終的には Bagging によってこれらの結果を集約して予測値を出力します。
GluonTS
GluonTS は確率的な予測もサポートした時系列データ向けのライブラリで、Deep Learning を利用した予測アルゴリズムが実装されています。
GluonTS - Probabilistic Time Series Modeling
https://gluon-ts.mxnet.io/
学習・予測までの流れは以下の3ステップです。
- Estimator を使って学習アルゴリズムやハイパーパラメータ (予測期間、学習率など)を指定する
- Estimator に対して学習データを渡して、train() で学習する
- Estimator.predict で予測する
学習アルゴリズムを変えたい場合は、1 で呼び出す Estimator を変えれば良いです。例えば以下のような Estimator が用意されています。
- NBEATSEstimator: 1つの N-BEATS で学習
- NBEATSEnsembleEstimator: N-BEATS のアンサンブル学習
- DeepAREstimator: DeepAR を利用して学習
- NPTSEstimator: NPTS を利用して学習
それぞれハイパーパラメータが異なるので、アルゴリズムを切り替える際は、Estimator の名前だけでなくて、ハイパーパラメータも修正してください。例えば、Ensembling の説明に関係した違いでいうと、NBEATSEstimator は過去のデータの期間 (H) を1つしか指定できませんが、NBEATSEnsembleEstimator は 2H から 7H などのようにいくつ指定してアンサンブルできるようになっています。
コード
今回は北京のPM2.5データセット向けに、NBEATSEstimator を使った実装を Amazon SageMaker を使って行いました。SageMakerで実行可能なノートブックはこちらです。
https://github.com/harusametime/amazon-sagemaker-examples-jp/blob/master/gluonts/pm25_gluonts.ipynb
SageMaker でノートブックを開くとあとはポチポチとShift-Enter で実行できると思います。もし SageMaker を利用せず、コードだけ知りたいということであれば、学習・推論コードを src フォルダに置いていますので見てみてください。
コードの詳細やデータ前処理などは Github の上のノートブック (上のリンク) を見てみてください。
実行結果
PM2.5 のみを予測するモデルを 500 epoch 回して学習しました。データセットの最終月である2014年12月をテストデータに、それ以前を学習データにしました。この時点で学習時のロスは 0.8 程度まで下がっていましたが、まだまだ下がりそうでした。
まずは 2014年12月の第1週 (12/1-12/7の168時間) を入力して、その後の12時間 (12/8の0時から12時) を予測した結果です。青い線が正解、オレンジの線が予測です。11時頃からの急上昇を思いっきり外してますね。12/8の0時頃に11-12時間先を予測するのは難しいでしょうか。
ちょっと問題を簡単にして、毎回12時間先を予測はしますが、その予測は1時間毎に最新の168時間のデータで見直しができると考えましょう。つまり1時間毎に予測して直近の結果だけをとってきます。以下はこれをテストデータ1ヶ月分で回した結果です。細かい差異はいろいろありそうですが、概ねスパイクにも追従できているかな?
じゃあ間をとって6時間毎に見直しができるものとしましょう。うーむ、さっきよりスパイクへの反応が遅いですね。2014/12/9 付近の遅れがやや気になります。
結論・補足
- 他と十分な比較はできていないですが、N-BEATS は短期予測なら結構使えそうです。学習に時間がかかるので試していないですが、Ensemble 版も試してみたいところです。
- 今回はデータセットの中の PM2.5 のみのデータを使いました。それ以外のデータ (DEWPなど) も同時に予測しようとすると、PM2.5 自体の予測精度が出ない傾向がありました。あまりたくさん学習を回したわけではないですが、それ以外のデータの予測でロスが下がり、学習が収束しているように見受けられたので、今回は途中で中断しています。
- 同時に予測するのではなく、予測を支援する関連情報としてDEWPなどを入れることはできるので (Dynamic Feature)、それを試す価値はあるかも。
- 手元では N-BEATS 以外に LSTNet なども試しました。GluonTS だとアルゴリズムの切り替えもそこまで苦労なかったです。