1. はじめに
この記事では、共変量を取り込んだマルチホライズン予測(複数ステップ予測)が可能であり、また、解釈可能性を持つ Temporal Fusion Transformers(TFT)を紹介します。
まず、TFT を理解するための基本要素として、2章では時系列予測の概要、3章では時系列予測で用いられている既存の DNN モデルについて簡単に解説します(この分野に明るい方は読み飛ばしていただいて構いません)。その後、4章にて TFT を紹介したのち、5章では Darts というライブラリを利用して TFT の実装を行います。
2. 時系列予測
2-1. 時系列予測とは
時系列予測とは、時間経過に沿って変化するデータを解析し、将来を予測する分析手法です。身近な例では、天気予報などが挙げられます。時系列予測は、株価や為替の変動、市場規模の変化、感染症の感染者の拡大傾向把握など、様々な分野のデータ分析にも用いられています。課題に対して将来どのような対策を実施するのか、その意思決定に寄与する重要な分析手法の一つです。
2-2. ディープラーニングによる時系列予測
時系列予測の歴史は古く、従来から、自己回帰(AR; Auto Regression)モデルや状態空間モデルなどを活用して時系列予測が行われてきています。2010年代のディープラーニングの登場により、時系列予測においてもディープラーニングを用いたアルゴリズムが用いられるようになり、性能向上が急速に進んでいます。
3. 時系列予測に使われる既存のディープラーニング手法
機械学習プロジェクトでよく使われる問題設定では、入力されるデータの各サンプルは独立していることを前提としているため、各サンプルが過去の値に依存している時系列データでは、適切な予測をすることが困難となる場合が多いです。そのため、時系列予測(より一般的には系列データ)に特化したディープラーニングのアルゴリズムが提案されてきました。以降では、ディープラーニング登場以降に時系列予測で用いられてきた代表的なアルゴリズムを紹介します。
3-1. RNN(Recurrent Neural Network)
RNN は、時間方向につながりを持たせられるニューラルネットワークです。図1はRNNの構造を示しています。$X_{t}$は時刻$t$の入力を表しており、$X_{t}$がRNNに入力され、$h_{t}$が出力されるという構造をしています。そして、$h_{t}$が再度 RNN レイヤに入力されるループ構造となっています。
図1: RNNのループ構造(Understanding LSTM Networks より引用)
このループ構造を展開すると図2のようになります。
図2: RNNのループ構造の展開(Understanding LSTM Networks より引用)
$X_{0}, X_{1}, \cdots, X_{t}$ と $h_{0}, h_{1}, \cdots, h_{t}$ は、それぞれ各時刻での RNN レイヤへの入力と出力を表します。各時刻の RNN レイヤは、同じ時刻の入力とひとつ前の RNN レイヤの出力を入力としています。そして、各時刻の RNN レイヤからの出力を、さらに次の時刻の RNN レイヤへ渡す構造をしています。これにより、過去の情報を保持することを可能としています。
3-2. LSTM(Long Short-Term Memory)
RNNは、前の時刻の情報を、次の時刻に渡すアーキテクチャでした。しかし、この仕組みだと、古い時刻における入力情報を適切に考慮できない問題が発生してしまいます。この問題を解決するために考えられたのが、LSTMというアルゴリズムです。
LSTM のアーキテクチャを図3に示します。LSTMは、新たに記憶セルとゲートとよばれる情報を取捨選択する機構を持っています。記憶セルには、過去から時刻 $t$ までに必要な情報がすべて格納されています。この記憶セルを $\tanh$ 関数によって変換したものを出力し、次の時刻に渡しています。ゲートは、シグモイド関数 $\sigma$ を用いており、どの情報を記憶セルから出力するか(Outputゲート)、どの情報を記憶セルから削除するか(Forgetゲート)、どの情報を入力から記憶セルに追加するか(Inputゲート)を選択しています。これにより、古い時刻における情報も保持することができます。
図3: LSTM の構造(Understanding LSTM Networks より引用) 左右のAは前後時刻のLSTMレイヤを表し、$\sigma$はシグモイド関数、$\tanh$は$\tanh$関数を表す。
3-3. Transformer
3-3-1. Transformerの概要
Transformer は、自然言語処理の分野で発展した有名なアルゴリズムです。系列データを扱うのが得意であり、時系列予測でも利用されています。図4のように、Multi Head Attention という Attention 機構を含んだエンコーダーとデコーダーから成り立っています。エンコーダーおよびデコーダーは、図4のそれぞれ左半分、右半分の構造であり、Encoder input、Decoder input を入力として、Decoder output を出力します。
図4: Transformer のアーキテクチャ(Wu et al., 2020 Fig.1 より引用) 時刻 T1, T2, T3, T4 での入力を Encoder input、T4, T5での入力を Decoder input に与え、Decoder output から T5, T6 の結果を学習する。
3-3-2. Multi Head Attention
Q(Query)K(Key)V(Value)計算ベースのスケール化ドット積アテンション(ベクトル間類似性によるAttention)によって構成された Attention 機構です。Attention 機構は、注意表現(どの情報が重要であるか)を学習する仕組みで、膨大な情報の中から必要な情報に焦点を当て、注意を向けることが可能です。Multi Head Attention は、この Attenion 機構を並列に多数配置することで、さまざまな注意表現の学習を可能にしています。
4. Temporal Fusion Transformers (TFT)
4-1. TFTの概要
TFT は、Transformer をベースとしたマルチホライズン予測を行うアルゴリズムであり、優れた予測精度と解釈可能性を実現しています。従来のマルチホライズン予測では、静的共変量(例えば、月、曜日、祝日)を考慮した時系列予測を行うことが困難でした。TFT は、静的共変量を考慮し、時系列予測をすることが可能です。
4-2. TFTのアーキテクチャ
TFT のアーキテクチャを図5に示します。
図5: TFTのアーキテクチャ(Lim et al., 2020 Fig.2 より引用)
4-2-1. ゲーティング機構
ゲーティング機構は、データから学習を行い、柔軟な非線形処理を施すアーキテクチャです。これにより、様々なデータセットに対してネットワークの深さや複雑さを提供します。
4-2-2. 変数選択ネットワーク
各時間ステップで、ゲーティング機構を用いて関連する入力変数の選択を行います。これにより、予測において重要な変数を選択し、予測精度を低下させるような入力を削除することができます。
4-2-3. 静的説明変数エンコーダ
ゲーティング機構によって静的説明変数を統合し、4つの異なるコンテキストベクトルを生成します。このコンテキストベクトルによって表現された静的特徴量をネットワークに渡します。
4-2-4. Temporal Processing
観測された入力と既知の時間変化する入力の両方から、長期および短期の両方の時間的関係を学習するため、sequence-to-sequence レイヤーと Attention レイヤーを用いています。局所的な処理に関しては、sequence-to-sequence レイヤーを用いており、長期的依存関係には、解釈可能なMulti-Head Attention(後述)を使用しています。
4-2-5. 解釈可能なMulti-Head Attention
通常の Multi Head Attention では、各ヘッドで異なる注意表現の学習を行うため、Attention の重みだけでは各特徴量の重要度を理解することができませんでした。そこで、各ヘッドで値を共有するために、TFTではすべてのヘッドの加算集約を行っています。これによって、TFT の Multi Head Attention は単一のアテンションレイヤーと見做すことができ、その重みを分析することによって、特徴量の重要性を示すことが可能となっています。
4-3. TFTの利点
TFTは、静的共変量を取り込むことができ、解釈可能性のある Transformer ベースのアルゴリズムです。解釈性を保ちながら高精度な予測ができるアルゴリズムであるため、以下のようなケースで有用です。
- 複数の共変量からの情報を統合する必要がある時系列予測タスク
- 時間的依存関係とイベント(休日祝日など)を考慮する必要がある時系列予測タスク
- 解釈のしやすさや説明のしやすさが必要な時系列予測タスク
5. TFTを用いた時系列予測の実装例
今回はDartsというライブラリを使用し、TFTによる時系列予測を実装します。以降の実装は、Darts TFTチュートリアルを参考にしています。
5-1. 予測タスクおよび使用するデータセット
今回対象とする時系列データは Darts にサンプルとして組み込まれているWineDataset
を用います。1980年1月から1994年8月までのオーストラリアのワインメーカーによるワインの総売上高のデータセットです(Darts に組み込まれているサンプルデータの詳細については Darts.Datasets をご参照ください)。
from darts.datasets import WineDataset
data = WineDataset().load()
data.plot()
#dataのワイン総売上高をグラフで可視します。
こちらの時系列データを用いて、ワインの総売上高の時系列予測を行います。予測対象の期間は後半の2年間(1993年1月~1994年8月)とします。
5-2. ライブラリのインポート
必要なライブラリをインポートします。
import numpy as np
import pandas as pd
from tqdm import tqdm_notebook as tqdm
import matplotlib.pyplot as plt
from darts.dataprocessing.transformers import Scaler #時系列データの正規化を行います。
from darts.models import TFTModel #TFTモデル本体
from darts.metrics import mape #評価指標
from darts import TimeSeries
from darts.utils.timeseries_generation import datetime_attribute_timeseries #TimeSeriesの日付インデックスから年、月などの情報を取得します。
from darts.utils.likelihood_models import QuantileRegression #分位点予測に使用します。
import warnings
warnings.filterwarnings("ignore")
import logging
logging.disable(logging.CRITICAL)
5-3. 特徴量の作成
予測対象の時系列データと、共変量の時系列データを作成・加工します。予測対象はワインの総売上高、共変量として年と月を用います。
#1993、1994年の2年分を予測します。
#1992年12月までのデータを学習データ、それ以降のデータをテスト用データとして分割します。TimeSeries型のデータセットはsplit_afterで分割ができます。
cutoff = pd.Timestamp("19921201")
target_train, target_val = data.split_after(cutoff)
#予測ターゲットのデータを正規化します。
target_scaler = Scaler()
target_train_transformed = target_scaler.fit_transform(target_train)
target_val_transformed = target_scaler.transform(target_val)
cov_series = datetime_attribute_timeseries(data, attribute = "year", one_hot = False)
cov_series = cov_series.stack(
datetime_attribute_timeseries(data, attribute = "month", one_hot = False)
)
#共変量のデータセットも正規化します。
cov_train, cov_val = covariates_series.split_after(cutoff)
cov_scaler = Scaler()
cov_scaler.fit(cov_train)
cov_series_transformed = cov_scaler.transformed(cov_series)
5-4. モデル構築
過去24か月の情報をもとに将来12か月の予測をするモデルを構築します。TFTは分位点出力のため、分位点を指定します。
quantiles = [
0.01,
0.05,
0.1,
0.15,
0.2,
0.25,
0.3,
0.4,
0.5,
0.6,
0.7,
0.75,
0.8,
0.85,
0.9,
0.95,
0.99,
]
input_chunk_length = 24
forecast_horizon = 12
my_model = TFTModel(
input_chunk_length=input_chunk_length,
output_chunk_length=forecast_horizon,
hidden_size=64,
lstm_layers=1,
num_attention_heads=4,
dropout=0.1,
batch_size=16,
n_epochs=300,
likelihood=QuantileRegression(
quantiles=quantiles
),
random_state=42,
)
"""
TFTMoldelの引数の解説
input_chunk_length:予測パターンを検索する期間
output_chunk_length:input_chunk_lengthに対して予測を返す期間
hidden_size:隠れ層の数
lstm_layers:LSTMレイヤーの数
num_attention_heads:マルチヘッドアテンションのヘッド数
dropout:ドロップアウト数、過学習を抑えます。
batch_size:ミニバッチサイズ
n_epochs:エポック数
likelihood:尤度関数 今回は分位点予測を指定。
"""
my_model.fit(train_transformed, future_covariates=cov_series_transformed, verbose=True)
"""
model.fitの引数の解説
model.fit(data, future_covariates, past_covariates, verbose = True)
data: 予測ターゲットのデータセット
future_covariates:将来まで既知である共変量(例えば月、曜日、祝日など)
past_covariates:過去まで既知・測定可能である共変量(例えば天気など)
"""
5-5. 精度評価
構築したモデルを用いて、1993年から1994年8月までのワインの総売上高を予測します。
Darts での予測はmodel.predict()
で出力できます。評価指標としては、MAPE(平均絶対パーセント誤差)を用います。MAPE は、時系列予測や回帰でよく用いられる指標であり、誤差率を表します。この値が0に近いほど、精度が良いことを示しています。
pred = my_model.predict(n = 20, #予測期間分
num_samples = 200
)
予測結果と実際の値をプロットすると、図7のようになります。
pred_inverse = target_scaler.inverse_transform(pred)
# 予測結果のスケールを元に戻します。
target_series.plot(label = "acutual")
pred_inverse.plot(label = "predict")
pred_inverse.plot(low_quantile = 0.1, high_quantile = 0.9, label = f"{int(0.1 * 100)}-{int(0.9 * 100)}th percentiles")
グラフを概観すると、予測値と実際の売上高がおおむね一致しているように見えます。MAPEで見たTFTの精度は8.75%であり、ワインの売上の推移をおおむね捉えられていることがわかりました。
print("MAPE:{:.2f}%".format(mape(target_series, pred_inverse)))
MAPE:8.75%
5-6. 特徴量の重要度
TFT は、解釈可能な時系列予測のアルゴリズムです。具体的には、予測に用いた特徴量の重要度を出力することができ、Darts では TFTExplainer というモジュールで特徴量の重要度を確認することができます。
from darts.explainability import TFTExplainer
explainer = TFTExplainer(my_model)
explainability_result = explainer.explain()
explainer.plot_variable_selection(explainability_result)
出力された図8を見ると、年の情報が重要であるということがわかりました。今回の予測では、共変量が2つのみでしたが、共変量や説明変数の数が多い時系列予測では、特徴量重要度はモデルを解釈するだけでなく、モデルの精度を改善していくための示唆を得ることにも役立ちます。
6. まとめ
共変量を用いた時系列予測のアルゴリズム、Temporal Fusion Transformers(TFT)について、解説と実装例を示しました。共変量を用いた時系列予測を実装したい方のご参考になれば幸いです。
7. 参考文献
[1] Lim et al., 2020, Temporal Fusion Transformers for Interpretable Multi-horizon Time Series Forecasting
[2] Pascanu et al., 2013, On the difficulty of training Recurrent Neural Networks
[3] Understanding LSTM Networks
[4] Wu et al., 2020, Deep Transformer Models for Time Series Forecasting: The Influenza Prevalence Case
[5] ゼロから作るDeep Learning2 -自然言語処理編
[6] Classification with Gated Residual and Variable Selection Networks
[7] Darts公式チュートリアル
[8] Darts公式チュートリアル データセット一覧