LoginSignup
4
5

More than 3 years have passed since last update.

ludwigで時系列データの予測をしてみる

Posted at

はじめに

ludwigはUberがオープンソースで公開している深層学習のツールです。公開当初触ってみたときに時系列データの予測の仕方がどうしても理解できなくて諦めてたら、久々にサイトを覗いたらExampleが追加されてたので通してみました。

Time series forecasting (weather data example)

ludwigのインストール方法はこちら

やってみた

まずKaggleのHistorical Hourly Weather Dataからtemperature.csvをダウンロードする。19MBくらいある。

前処理する

temperature.csv.zipを展開して、同ディレクトリで以下を実行しludwigのTimeseries data教師データのフォーマットに変換する。実行後temperature_la.csvが作成される

convert.py
import pandas as pd
from ludwig.utils.data_utils import add_sequence_feature_column

df = pd.read_csv(
    './temperature.csv',
    usecols=['Los Angeles']
).rename(
    columns={"Los Angeles": "temperature"}
).fillna(method='backfill').fillna(method='ffill')

# normalize
df.temperature = ((df.temperature-df.temperature.mean()) /
                  df.temperature.std())

train_size = int(0.6 * len(df))
vali_size = int(0.2 * len(df))

# train, validation, test split
df['split'] = 0
df.loc[
    (
        (df.index.values >= train_size) &
        (df.index.values < train_size + vali_size)
    ),
    ('split')
] = 1
df.loc[
    df.index.values >= train_size + vali_size,
    ('split')
] = 2

# prepare timeseries input feature colum
# (here we are using 20 preceeding values to predict the target)
add_sequence_feature_column(df, 'temperature', 20)
df.to_csv('./temperature_la.csv')

ざっと内容を解説すると、データの中からLos Angelesの気温データを抽出して正規化し、訓練用データと検証用データ、テスト用データに分けている。temperature_la.csvを抜粋すると以下のようなデータになっている

temperature_la.csv
,temperature,split,temperature_feature
0,0.15852827495691552,0,0.15852827495691552 0.15852827495691552 0.1582474327317841 0.15742075088857002 0.1565940688905703 0.155767387047365 0.15494070504936527 0.1541140232061512 0.15328734136294592 0.15246065936494618 0.1516339775217321 0.15080729552374117 0.1499806136805271 0.14915393168253613 0.14832724983932208 0.147500567996108 0.14667388599811704 0.145847204154903 0.14502052215691202 0.14419384031369797
1,0.15852827495691552,0,0.15852827495691552 0.15852827495691552 0.1582474327317841 0.15742075088857002 0.1565940688905703 0.155767387047365 0.15494070504936527 0.1541140232061512 0.15328734136294592 0.15246065936494618 0.1516339775217321 0.15080729552374117 0.1499806136805271 0.14915393168253613 0.14832724983932208 0.147500567996108 0.14667388599811704 0.145847204154903 0.14502052215691202 0.14419384031369797
2,0.1582474327317841,0,0.15852827495691552 0.15852827495691552 0.1582474327317841 0.15742075088857002 0.1565940688905703 0.155767387047365 0.15494070504936527 0.1541140232061512 0.15328734136294592 0.15246065936494618 0.1516339775217321 0.15080729552374117 0.1499806136805271 0.14915393168253613 0.14832724983932208 0.147500567996108 0.14667388599811704 0.145847204154903 0.14502052215691202 0.14419384031369797
...
29998,-0.7448274402494238,1,-1.0734021380921992 -0.8134808504515311 -0.7001221002050634 -0.2624693080225413 -0.13037097901473493 0.34581028797350705 0.38301615692731394 0.6306027705772593 0.6020077334323776 0.49130209973846456 0.4546906470231156 0.18005122243440005 0.12292921463144427 -0.013290250274656678 -0.5407643632787822 -0.35997977760969324 -0.4616719149634197 -0.6556067568342304 -1.0316981952210769 -0.7407349445690479
29999,-0.7608561525790946,1,-0.8134808504515311 -0.7001221002050634 -0.2624693080225413 -0.13037097901473493 0.34581028797350705 0.38301615692731394 0.6306027705772593 0.6020077334323776 0.49130209973846456 0.4546906470231156 0.18005122243440005 0.12292921463144427 -0.013290250274656678 -0.5407643632787822 -0.35997977760969324 -0.4616719149634197 -0.6556067568342304 -1.0316981952210769 -0.7407349445690479 -0.7448274402494238
30000,-0.773413311100598,1,-0.7001221002050634 -0.2624693080225413 -0.13037097901473493 0.34581028797350705 0.38301615692731394 0.6306027705772593 0.6020077334323776 0.49130209973846456 0.4546906470231156 0.18005122243440005 0.12292921463144427 -0.013290250274656678 -0.5407643632787822 -0.35997977760969324 -0.4616719149634197 -0.6556067568342304 -1.0316981952210769 -0.7407349445690479 -0.7448274402494238 -0.7608561525790946
...
45250,0.7915724346576326,2,0.25139538884944534 0.05637444967513269 -0.06590058361668798 -0.23306138862324166 -0.2748515898748757 -0.375457629925109 -0.5611918577101599 -0.6648934682234834 -0.7159703808643704 -0.8707489040185808 -0.9744505145319043 -1.0255274271727914 -1.1168467558337805 -1.1818537355585514 -1.225191722041726 -0.8057419242938189 -0.282590516032588 0.18329283866158427 0.41855619385599024 0.748234448174458
45251,0.743591092479827,2,0.05637444967513269 -0.06590058361668798 -0.23306138862324166 -0.2748515898748757 -0.375457629925109 -0.5611918577101599 -0.6648934682234834 -0.7159703808643704 -0.8707489040185808 -0.9744505145319043 -1.0255274271727914 -1.1168467558337805 -1.1818537355585514 -1.225191722041726 -0.8057419242938189 -0.282590516032588 0.18329283866158427 0.41855619385599024 0.748234448174458 0.7915724346576326
45252,0.6321505558088,2,-0.06590058361668798 -0.23306138862324166 -0.2748515898748757 -0.375457629925109 -0.5611918577101599 -0.6648934682234834 -0.7159703808643704 -0.8707489040185808 -0.9744505145319043 -1.0255274271727914 -1.1168467558337805 -1.1818537355585514 -1.225191722041726 -0.8057419242938189 -0.282590516032588 0.18329283866158427 0.41855619385599024 0.748234448174458 0.7915724346576326 0.743591092479827

カンマ区切り左から、番号,目標値,データ区切り,特徴データとなっている。特徴データは1行ごとにスライドしていて、データ区切りは0が訓練用データ、1が検証用データ2がテスト用データとなっているようだ。

データ整形処理は以下で行われていて特徴データを20列のスライドにしているようだ。

convert.py
add_sequence_feature_column(df, 'temperature', 20)

モデル定義ファイルを作成する

model_definition.yamlを同ディレクトリに作成する

model_definition.yaml
input_features:
    -
        name: temperature_feature
        type: timeseries
        encoder: rnn
        embedding_size: 32
        state_size: 32

output_features:
    -
        name: temperature
        type: numerical

RNNで入力データ形式がtimeseries(temperature_la.csvのデータフォーマット)、出力が数値。

学習する

以下で学習とテストが実行される。結果はresultsディレクトリが作成されてその中に保存される

ludwig experiment --data_csv ./temperature_la.csv --model_definition_file model_definition.yaml

ログはこんな感じ

███████████████████████
█ █ █ █  ▜█ █ █ █ █   █
█ █ █ █ █ █ █ █ █ █ ███
█ █   █ █ █ █ █ █ █ ▌ █
█ █████ █ █ █ █ █ █ █ █
█     █  ▟█     █ █   █
███████████████████████
ludwig v0.2 - Experiment

Experiment name: experiment
Model name: run
Output path: results/experiment_run_0


ludwig_version: '0.2'

...(中略)...

╒══════════╕
│ TRAINING │
╘══════════╛

Epoch   1
Training: 100%|██████████████████████████████| 213/213 [00:01<00:00, 112.75it/s]
Evaluation train: 100%|██████████████████████| 213/213 [00:00<00:00, 230.19it/s]
Evaluation vali : 100%|████████████████████████| 71/71 [00:00<00:00, 292.57it/s]
Evaluation test : 100%|████████████████████████| 71/71 [00:00<00:00, 297.32it/s]
Took 3.3149s
╒═══════════════╤════════╤══════════════════════╤═══════════════════════╤════════╤═════════╕
│ temperature   │   loss │   mean_squared_error │   mean_absolute_error │     r2 │   error │
╞═══════════════╪════════╪══════════════════════╪═══════════════════════╪════════╪═════════╡
│ train         │ 0.0868 │               0.0868 │                0.2130 │ 0.0072 │  0.0080 │
├───────────────┼────────┼──────────────────────┼───────────────────────┼────────┼─────────┤
│ vali          │ 0.1023 │               0.1023 │                0.2330 │ 0.0064 │  0.0383 │
├───────────────┼────────┼──────────────────────┼───────────────────────┼────────┼─────────┤
│ test          │ 0.0545 │               0.0545 │                0.1719 │ 0.0066 │  0.0517 │
╘═══════════════╧════════╧══════════════════════╧═══════════════════════╧════════╧═════════╛
Validation loss on combined improved, model saved

...(中略)...

Epoch  97
Training: 100%|██████████████████████████████| 213/213 [00:01<00:00, 142.55it/s]
Evaluation train: 100%|██████████████████████| 213/213 [00:00<00:00, 318.74it/s]
Evaluation vali : 100%|████████████████████████| 71/71 [00:00<00:00, 321.90it/s]
Evaluation test : 100%|████████████████████████| 71/71 [00:00<00:00, 305.79it/s]
Took 2.6192s
╒═══════════════╤════════╤══════════════════════╤═══════════════════════╤════════╤═════════╕
│ temperature   │   loss │   mean_squared_error │   mean_absolute_error │     r2 │   error │
╞═══════════════╪════════╪══════════════════════╪═══════════════════════╪════════╪═════════╡
│ train         │ 0.0211 │               0.0211 │                0.0902 │ 0.0077 │  0.0082 │
├───────────────┼────────┼──────────────────────┼───────────────────────┼────────┼─────────┤
│ vali          │ 0.0451 │               0.0451 │                0.1303 │ 0.0072 │  0.0098 │
├───────────────┼────────┼──────────────────────┼───────────────────────┼────────┼─────────┤
│ test          │ 0.0152 │               0.0152 │                0.0801 │ 0.0075 │  0.0086 │
╘═══════════════╧════════╧══════════════════════╧═══════════════════════╧════════╧═════════╛
Last improvement of loss on combined happened 5 epochs ago

EARLY STOPPING due to lack of validation improvement, it has been 5 epochs since last validation accuracy improvement

Best validation model epoch: 92
Best validation model loss on validation set combined: 0.045013064149334946
Best validation model loss on test set combined: 0.015312512894541722

Finished: experiment_run
Saved to: results/experiment_run_0

╒═════════╕
│ PREDICT │
╘═════════╛

Evaluation: 100%|██████████████████████████████| 71/71 [00:00<00:00, 193.66it/s]

===== temperature =====
error: 0.008600296025459501
loss: 0.015234892950008466
mean_absolute_error: 0.08010127209373875
mean_squared_error: 0.015234892950008466
r2: 0.007505640815540167

Finished: experiment_run
Saved to: results/experiment_run_0

なんかできたっぽい。学習済みモデルは./results/experiment_run_0に保存されたようだ

予測してみる

以下のようなデータで予測してみる

test.csv
,temperature,split,temperature_feature
0,0,2,0.1516339775217321 0.15080729552374117 0.1499806136805271 0.14915393168253613 0.14832724983932208 0.147500567996108 0.14667388599811704 0.145847204154903 0.14502052215691202 0.14419384031369797 0.1433671584704839 0.14254047647249293 0.14171379462927888 0.14088711263128792 0.14006043078807387 0.1392337487900829 0.13840706694686883 0.4835631735807611 0.734304381090574 1.0964861252714315
ludwig predict --data_csv=./test.csv --model_path=./results/experiment_run_0/model/
███████████████████████
█ █ █ █  ▜█ █ █ █ █   █
█ █ █ █ █ █ █ █ █ █ ███
█ █   █ █ █ █ █ █ █ ▌ █
█ █████ █ █ █ █ █ █ █ █
█     █  ▟█     █ █   █
███████████████████████
ludwig v0.2 - Predict

Dataset path: ./test.csv
Model path: ./results/experiment_run_0/model/
Output path: results_0

Loading metadata from: ./results/experiment_run_0/model/train_set_metadata.json

╒═══════════════╕
│ LOADING MODEL │
╘═══════════════╛

╒═════════╕
│ PREDICT │
╘═════════╛

Use standard file APIs to check for files with this prefix.
Evaluation: 100%|█████████████████████████████████| 1/1 [00:00<00:00, 10.38it/s]
Saved to: results_0

results_0ディレクトリが作成されてtemperature_predictions.csvが保存される

temperature_predictions.csv
1.1143011

なんかそれっぽい予測値が入ってた

おわりに

リリースされた当初はそこそこ使われるんじゃないかなーと思っていたludwigですが、全然使われていないっぽい。tfとかtorchとかでガリガリ書かなくていい用途やそこそこの推論処理をシステムに組み込むにはちょうどいいツールじゃないかと思ってるんだけど、あんまりそういう用途がないのだろうか。あとAutoML Tablesとか便利なGUIも出てきたしそっちのほうが便利なのかな。うん、便利だよね。

でもなんかludwigは個人的に無視できないです。なんとなく。

4
5
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
4
5