サマリ
- 日付情報の整数値をDL(
torch.Tensor
)に通したあと、復元すると微妙にずれる現象の概説 - 微妙にズレた日付を復元するサンプルを提供
環境
- jupyter notebook
- たぶん、Colab でも動くと思う
- pytorch: 1.7.0
現象
準備
import numpy
import pandas
import torch
index_date = pandas.date_range("2016-01-01", "2018-12-31")
index_date
DatetimeIndex(['2016-01-01', '2016-01-02', '2016-01-03', '2016-01-04',
'2016-01-05', '2016-01-06', '2016-01-07', '2016-01-08',
'2016-01-09', '2016-01-10',
...
'2018-12-22', '2018-12-23', '2018-12-24', '2018-12-25',
'2018-12-26', '2018-12-27', '2018-12-28', '2018-12-29',
'2018-12-30', '2018-12-31'],
dtype='datetime64[ns]', length=1096, freq='D')
df = pandas.DataFrame([])
df["time_index"] = index_date.astype(numpy.int64)
df.time_index[:5]
0 1451606400000000000
1 1451692800000000000
2 1451779200000000000
3 1451865600000000000
4 1451952000000000000
Name: time_index, dtype: int64
現象の確認
series/numpy のデータを to_datetime
すると正しく復元できる
pandas.to_datetime(df.time_index)[:5]
0 2016-01-01
1 2016-01-02
2 2016-01-03
3 2016-01-04
4 2016-01-05
Name: time_index, dtype: datetime64[ns]
一方で、torch.Tensor
で、テンソルに変換してから復元しようと、to_datetime
をすると微妙にずれる(型変換時の精度誤差っぽい)
pandas.to_datetime(torch.Tensor(df.time_index))[:5]
DatetimeIndex(['2016-01-01 00:00:49.632313344',
'2016-01-01 23:59:21.295093760',
'2016-01-03 00:00:10.396827648',
'2016-01-04 00:00:59.498561536',
'2016-01-04 23:59:31.161341952'],
dtype='datetime64[ns]', freq=None)
対策例
対策として、以下のような関数を作成する
import datetime
def to_date(ti: numpy.array):
_ti = pandas.to_datetime(ti)
ti_ser = pandas.Series(_ti, name="time_index")
def _adjust_date(ts):
dte = ts.to_pydatetime()
dte += datetime.timedelta(hours=1)
return datetime.datetime(year=dte.year, month=dte.month, day=dte.day, hour=0, minute=0, second=0)
return ti_ser.apply(_adjust_date)
series に対しては、to_datetime
と同じ結果になる
ti = to_date(df.time_index)
ti.head()
0 2016-01-01
1 2016-01-02
2 2016-01-03
3 2016-01-04
4 2016-01-05
Name: time_index, dtype: datetime64[ns]
(ti == pandas.to_datetime(df.time_index)).all()
True
torch.Tensor
を通した後も、同じ結果になる
ti_restored = to_date(torch.Tensor(df.time_index))
ti_restored.head()
0 2016-01-01
1 2016-01-02
2 2016-01-03
3 2016-01-04
4 2016-01-05
Name: time_index, dtype: datetime64[ns]
(ti == ti_restored).all()
True
まとめ
- 微妙にめんどくさいので、参考になれば幸いです