はじめに
ludwigはUberがオープンソースで公開している深層学習のツールです。公開当初触ってみたときに時系列データの予測の仕方がどうしても理解できなくて諦めてたら、久々にサイトを覗いたらExampleが追加されてたので通してみました。
Time series forecasting (weather data example)
ludwigのインストール方法はこちら
やってみた
まずKaggleのHistorical Hourly Weather Dataからtemperature.csv
をダウンロードする。19MBくらいある。
前処理する
temperature.csv.zip
を展開して、同ディレクトリで以下を実行しludwigのTimeseries data教師データのフォーマットに変換する。実行後temperature_la.csv
が作成される
import pandas as pd
from ludwig.utils.data_utils import add_sequence_feature_column
df = pd.read_csv(
'./temperature.csv',
usecols=['Los Angeles']
).rename(
columns={"Los Angeles": "temperature"}
).fillna(method='backfill').fillna(method='ffill')
# normalize
df.temperature = ((df.temperature-df.temperature.mean()) /
df.temperature.std())
train_size = int(0.6 * len(df))
vali_size = int(0.2 * len(df))
# train, validation, test split
df['split'] = 0
df.loc[
(
(df.index.values >= train_size) &
(df.index.values < train_size + vali_size)
),
('split')
] = 1
df.loc[
df.index.values >= train_size + vali_size,
('split')
] = 2
# prepare timeseries input feature colum
# (here we are using 20 preceeding values to predict the target)
add_sequence_feature_column(df, 'temperature', 20)
df.to_csv('./temperature_la.csv')
ざっと内容を解説すると、データの中からLos Angelesの気温データを抽出して正規化し、訓練用データと検証用データ、テスト用データに分けている。temperature_la.csv
を抜粋すると以下のようなデータになっている
,temperature,split,temperature_feature
0,0.15852827495691552,0,0.15852827495691552 0.15852827495691552 0.1582474327317841 0.15742075088857002 0.1565940688905703 0.155767387047365 0.15494070504936527 0.1541140232061512 0.15328734136294592 0.15246065936494618 0.1516339775217321 0.15080729552374117 0.1499806136805271 0.14915393168253613 0.14832724983932208 0.147500567996108 0.14667388599811704 0.145847204154903 0.14502052215691202 0.14419384031369797
1,0.15852827495691552,0,0.15852827495691552 0.15852827495691552 0.1582474327317841 0.15742075088857002 0.1565940688905703 0.155767387047365 0.15494070504936527 0.1541140232061512 0.15328734136294592 0.15246065936494618 0.1516339775217321 0.15080729552374117 0.1499806136805271 0.14915393168253613 0.14832724983932208 0.147500567996108 0.14667388599811704 0.145847204154903 0.14502052215691202 0.14419384031369797
2,0.1582474327317841,0,0.15852827495691552 0.15852827495691552 0.1582474327317841 0.15742075088857002 0.1565940688905703 0.155767387047365 0.15494070504936527 0.1541140232061512 0.15328734136294592 0.15246065936494618 0.1516339775217321 0.15080729552374117 0.1499806136805271 0.14915393168253613 0.14832724983932208 0.147500567996108 0.14667388599811704 0.145847204154903 0.14502052215691202 0.14419384031369797
...
29998,-0.7448274402494238,1,-1.0734021380921992 -0.8134808504515311 -0.7001221002050634 -0.2624693080225413 -0.13037097901473493 0.34581028797350705 0.38301615692731394 0.6306027705772593 0.6020077334323776 0.49130209973846456 0.4546906470231156 0.18005122243440005 0.12292921463144427 -0.013290250274656678 -0.5407643632787822 -0.35997977760969324 -0.4616719149634197 -0.6556067568342304 -1.0316981952210769 -0.7407349445690479
29999,-0.7608561525790946,1,-0.8134808504515311 -0.7001221002050634 -0.2624693080225413 -0.13037097901473493 0.34581028797350705 0.38301615692731394 0.6306027705772593 0.6020077334323776 0.49130209973846456 0.4546906470231156 0.18005122243440005 0.12292921463144427 -0.013290250274656678 -0.5407643632787822 -0.35997977760969324 -0.4616719149634197 -0.6556067568342304 -1.0316981952210769 -0.7407349445690479 -0.7448274402494238
30000,-0.773413311100598,1,-0.7001221002050634 -0.2624693080225413 -0.13037097901473493 0.34581028797350705 0.38301615692731394 0.6306027705772593 0.6020077334323776 0.49130209973846456 0.4546906470231156 0.18005122243440005 0.12292921463144427 -0.013290250274656678 -0.5407643632787822 -0.35997977760969324 -0.4616719149634197 -0.6556067568342304 -1.0316981952210769 -0.7407349445690479 -0.7448274402494238 -0.7608561525790946
...
45250,0.7915724346576326,2,0.25139538884944534 0.05637444967513269 -0.06590058361668798 -0.23306138862324166 -0.2748515898748757 -0.375457629925109 -0.5611918577101599 -0.6648934682234834 -0.7159703808643704 -0.8707489040185808 -0.9744505145319043 -1.0255274271727914 -1.1168467558337805 -1.1818537355585514 -1.225191722041726 -0.8057419242938189 -0.282590516032588 0.18329283866158427 0.41855619385599024 0.748234448174458
45251,0.743591092479827,2,0.05637444967513269 -0.06590058361668798 -0.23306138862324166 -0.2748515898748757 -0.375457629925109 -0.5611918577101599 -0.6648934682234834 -0.7159703808643704 -0.8707489040185808 -0.9744505145319043 -1.0255274271727914 -1.1168467558337805 -1.1818537355585514 -1.225191722041726 -0.8057419242938189 -0.282590516032588 0.18329283866158427 0.41855619385599024 0.748234448174458 0.7915724346576326
45252,0.6321505558088,2,-0.06590058361668798 -0.23306138862324166 -0.2748515898748757 -0.375457629925109 -0.5611918577101599 -0.6648934682234834 -0.7159703808643704 -0.8707489040185808 -0.9744505145319043 -1.0255274271727914 -1.1168467558337805 -1.1818537355585514 -1.225191722041726 -0.8057419242938189 -0.282590516032588 0.18329283866158427 0.41855619385599024 0.748234448174458 0.7915724346576326 0.743591092479827
カンマ区切り左から、番号
,目標値
,データ区切り
,特徴データ
となっている。特徴データは1行ごとにスライドしていて、データ区切りは0
が訓練用データ、1
が検証用データ2
がテスト用データとなっているようだ。
データ整形処理は以下で行われていて特徴データを20列のスライドにしているようだ。
add_sequence_feature_column(df, 'temperature', 20)
モデル定義ファイルを作成する
model_definition.yaml
を同ディレクトリに作成する
input_features:
-
name: temperature_feature
type: timeseries
encoder: rnn
embedding_size: 32
state_size: 32
output_features:
-
name: temperature
type: numerical
RNNで入力データ形式がtimeseries
(temperature_la.csv
のデータフォーマット)、出力が数値。
学習する
以下で学習とテストが実行される。結果はresults
ディレクトリが作成されてその中に保存される
ludwig experiment --data_csv ./temperature_la.csv --model_definition_file model_definition.yaml
ログはこんな感じ
███████████████████████
█ █ █ █ ▜█ █ █ █ █ █
█ █ █ █ █ █ █ █ █ █ ███
█ █ █ █ █ █ █ █ █ ▌ █
█ █████ █ █ █ █ █ █ █ █
█ █ ▟█ █ █ █
███████████████████████
ludwig v0.2 - Experiment
Experiment name: experiment
Model name: run
Output path: results/experiment_run_0
ludwig_version: '0.2'
...(中略)...
╒══════════╕
│ TRAINING │
╘══════════╛
Epoch 1
Training: 100%|██████████████████████████████| 213/213 [00:01<00:00, 112.75it/s]
Evaluation train: 100%|██████████████████████| 213/213 [00:00<00:00, 230.19it/s]
Evaluation vali : 100%|████████████████████████| 71/71 [00:00<00:00, 292.57it/s]
Evaluation test : 100%|████████████████████████| 71/71 [00:00<00:00, 297.32it/s]
Took 3.3149s
╒═══════════════╤════════╤══════════════════════╤═══════════════════════╤════════╤═════════╕
│ temperature │ loss │ mean_squared_error │ mean_absolute_error │ r2 │ error │
╞═══════════════╪════════╪══════════════════════╪═══════════════════════╪════════╪═════════╡
│ train │ 0.0868 │ 0.0868 │ 0.2130 │ 0.0072 │ 0.0080 │
├───────────────┼────────┼──────────────────────┼───────────────────────┼────────┼─────────┤
│ vali │ 0.1023 │ 0.1023 │ 0.2330 │ 0.0064 │ 0.0383 │
├───────────────┼────────┼──────────────────────┼───────────────────────┼────────┼─────────┤
│ test │ 0.0545 │ 0.0545 │ 0.1719 │ 0.0066 │ 0.0517 │
╘═══════════════╧════════╧══════════════════════╧═══════════════════════╧════════╧═════════╛
Validation loss on combined improved, model saved
...(中略)...
Epoch 97
Training: 100%|██████████████████████████████| 213/213 [00:01<00:00, 142.55it/s]
Evaluation train: 100%|██████████████████████| 213/213 [00:00<00:00, 318.74it/s]
Evaluation vali : 100%|████████████████████████| 71/71 [00:00<00:00, 321.90it/s]
Evaluation test : 100%|████████████████████████| 71/71 [00:00<00:00, 305.79it/s]
Took 2.6192s
╒═══════════════╤════════╤══════════════════════╤═══════════════════════╤════════╤═════════╕
│ temperature │ loss │ mean_squared_error │ mean_absolute_error │ r2 │ error │
╞═══════════════╪════════╪══════════════════════╪═══════════════════════╪════════╪═════════╡
│ train │ 0.0211 │ 0.0211 │ 0.0902 │ 0.0077 │ 0.0082 │
├───────────────┼────────┼──────────────────────┼───────────────────────┼────────┼─────────┤
│ vali │ 0.0451 │ 0.0451 │ 0.1303 │ 0.0072 │ 0.0098 │
├───────────────┼────────┼──────────────────────┼───────────────────────┼────────┼─────────┤
│ test │ 0.0152 │ 0.0152 │ 0.0801 │ 0.0075 │ 0.0086 │
╘═══════════════╧════════╧══════════════════════╧═══════════════════════╧════════╧═════════╛
Last improvement of loss on combined happened 5 epochs ago
EARLY STOPPING due to lack of validation improvement, it has been 5 epochs since last validation accuracy improvement
Best validation model epoch: 92
Best validation model loss on validation set combined: 0.045013064149334946
Best validation model loss on test set combined: 0.015312512894541722
Finished: experiment_run
Saved to: results/experiment_run_0
╒═════════╕
│ PREDICT │
╘═════════╛
Evaluation: 100%|██████████████████████████████| 71/71 [00:00<00:00, 193.66it/s]
===== temperature =====
error: 0.008600296025459501
loss: 0.015234892950008466
mean_absolute_error: 0.08010127209373875
mean_squared_error: 0.015234892950008466
r2: 0.007505640815540167
Finished: experiment_run
Saved to: results/experiment_run_0
なんかできたっぽい。学習済みモデルは./results/experiment_run_0
に保存されたようだ
予測してみる
以下のようなデータで予測してみる
,temperature,split,temperature_feature
0,0,2,0.1516339775217321 0.15080729552374117 0.1499806136805271 0.14915393168253613 0.14832724983932208 0.147500567996108 0.14667388599811704 0.145847204154903 0.14502052215691202 0.14419384031369797 0.1433671584704839 0.14254047647249293 0.14171379462927888 0.14088711263128792 0.14006043078807387 0.1392337487900829 0.13840706694686883 0.4835631735807611 0.734304381090574 1.0964861252714315
ludwig predict --data_csv=./test.csv --model_path=./results/experiment_run_0/model/
███████████████████████
█ █ █ █ ▜█ █ █ █ █ █
█ █ █ █ █ █ █ █ █ █ ███
█ █ █ █ █ █ █ █ █ ▌ █
█ █████ █ █ █ █ █ █ █ █
█ █ ▟█ █ █ █
███████████████████████
ludwig v0.2 - Predict
Dataset path: ./test.csv
Model path: ./results/experiment_run_0/model/
Output path: results_0
Loading metadata from: ./results/experiment_run_0/model/train_set_metadata.json
╒═══════════════╕
│ LOADING MODEL │
╘═══════════════╛
╒═════════╕
│ PREDICT │
╘═════════╛
Use standard file APIs to check for files with this prefix.
Evaluation: 100%|█████████████████████████████████| 1/1 [00:00<00:00, 10.38it/s]
Saved to: results_0
results_0
ディレクトリが作成されてtemperature_predictions.csv
が保存される
1.1143011
なんかそれっぽい予測値が入ってた
おわりに
リリースされた当初はそこそこ使われるんじゃないかなーと思っていたludwigですが、全然使われていないっぽい。tfとかtorchとかでガリガリ書かなくていい用途やそこそこの推論処理をシステムに組み込むにはちょうどいいツールじゃないかと思ってるんだけど、あんまりそういう用途がないのだろうか。あとAutoML Tablesとか便利なGUIも出てきたしそっちのほうが便利なのかな。うん、便利だよね。
でもなんかludwigは個人的に無視できないです。なんとなく。