Overview
This is a basic explanation and summary of PyTorch that doubles as my personal study notes. It is the third installment of predictions on the Nikkei 225 using an LSTM. As before, we predict the opening price of the Nikkei 225, but unlike installments 9 and 10, this time the goal is to predict several steps ahead. Using an LSTM-based sequence-to-sequence model, we will build a model that predicts up to 5 steps ahead at each point. The finished result looks like the graph below.
Figure: graph connecting the 5-step-ahead predictions
Policy
- Code first, as much as possible
- As concise as possible (fine details are omitted)
- Parts such as feature counts are written with explicit numbers (so it is easy to see how they change)
Exercise files
- Data: nikkei_225.csv
- Code: sample_11.ipynb
1. 🔹 RNNs and sequence-to-sequence
1.1 RNN review
From the history $h_{t-1}$, the feature summarizing the information up to time $t-1$, and the data $x_t$ at the current time, the feature summarizing the information up to time $t$ is derived according to
$$h_t=\tanh(W_x x_t + W_h h_{t-1} + b)$$
This is the basic recurrence of a recurrent network.
- $x_t$: the new input data at time $t$
- $h_{t-1}$: the result computed at the previous time step (so to speak, the "past information")
- $h_t$: the result being computed now (it becomes the past information $h_{t-1}$ at the next time step)
The key point is that computing the feature at time $t$ requires $h_{t-1}$, the information from time $t-1$. In the case of an LSTM in particular, both the history and the memory cell from time $t-1$,
$$(h_{t-1}, c_{t-1})$$
are required.
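As a concrete check of the recurrence above, here is a minimal sketch of the plain RNN step. The sizes are illustrative only (4 input features, hidden size 3), not the ones used later in the article.

```python
import torch

# One recurrence step: h_t = tanh(W_x x_t + W_h h_{t-1} + b)
# Illustrative sizes: 4 input features, hidden size 3.
input_size, hidden_size = 4, 3
torch.manual_seed(0)
W_x = torch.randn(hidden_size, input_size)
W_h = torch.randn(hidden_size, hidden_size)
b = torch.randn(hidden_size)

def rnn_step(x_t, h_prev):
    # The new history h_t mixes the current input with the previous history
    return torch.tanh(W_x @ x_t + W_h @ h_prev + b)

h = torch.zeros(hidden_size)        # h_0: no history yet
xs = torch.randn(5, input_size)     # a length-5 input sequence
for x_t in xs:
    h = rnn_step(x_t, h)            # h now summarizes the sequence up to time t
print(h.shape)  # torch.Size([3])
```

Because the same `h` is fed back in every iteration, the final `h` depends on the whole sequence, which is exactly the "history" the text describes.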
1.2 ç³»å倿ã¢ãã« (sequence to sequence model)
ç³»å倿ã¢ãã«ã¯ãè±èªè¡šèšã§è¡šçŸãããŠããããã«ãããç³»åãå
¥åããããšãäœãããã®ç³»åãåºåããããšããã¢ãã«ã«ãªããŸãããããã«ã¡ã¯ãã®å
¥åã«å¯ŸããŠãããïŒããšè¿äºãããã®ãšäŒŒãŠããŸãã
äŸ
æç³»åããŒã¿ã®(1, 2, 3, 4, 5)ãå
¥åããããšãæç³»åã®(6, 7, 8)ã®ããã«ç¶ããåºåããããããªã¿ã€ããç³»å倿ãšåŒãã§ããããã§ãã
å ¥åãããç³»å | åºåãããç³»å |
---|---|
1, 2, 3, 4, 5 | 6, 7, 8 |
ç³»å倿ã¢ãã«ã¯
- ãšã³ã³ãŒã㌠(encoder)ïŒå ¥åããŒã¿ããç¹åŸŽéãæœåºããéšå
- ãã³ãŒã㌠(decoder)ïŒæœåºããç¹åŸŽéããæç³»åã®ããŒã¿ãåºåããéšå
ãšããïŒçš®é¡ã®ãããã¯ãŒã¯æ§é ããæ§æãããŸããä»åã¯ãšã³ã³ãŒããŒã«LSTMããã³ãŒããŒã«ãLSTMãå©çšããã¿ã€ãã®ãããã¯ãŒã¯ãäœæããŠãããŸãã
So far
- From the data for steps 1 through 5, predict step 6 (the next step)
This time (sequence-to-sequence model)
- From the data for steps 1 through 5, predict the 5 steps from step 6 through step 10
- The network structure looks like the figure below.

Let's check the flow of a PyTorch program. It basically follows the five steps below. I recommend actually typing it in as you go, in Jupyter Lab or similar.
- Loading the data and converting it to torch tensors (2.1)
- Defining and creating the network model (2.2)
- Choosing the loss function and the optimization method (2.3)
- The parameter-update loop (2.4)
- Validation (2.5)
2. 🤖 Code and explanation
2.0 About the data
We obtain the Nikkei 225 data with yfinance, pandas_datareader, or the like. The same data as in installment 8 is used.
Date | Open | High | Low | Close | Volume |
---|---|---|---|---|---|
2021-01-04 | 27575.57 | 27602.11 | 27042.32 | 27258.38 | 51500000 |
2021-01-05 | 27151.38 | 27279.78 | 27073.46 | 27158.63 | 55000000 |
2021-01-06 | 27102.85 | 27196.40 | 27002.18 | 27055.94 | 72700000 |
… | … | … | … | … | … |
2025-06-18 | 38364.16 | 38885.15 | 38364.16 | 38885.15 | 110000000 |
2025-06-19 | 38858.52 | 38870.55 | 38488.34 | 38488.34 | 89300000 |
We proceed with the exercise by predicting the opening price (Open). Let's plot it: the blue line is the line graph of the Nikkei 225 opening price.
We split the data into training data and test data. The 100 steps to the right of the red line in the graph are used as test data, and the remaining left part as training data. The main goal is to see whether a model trained on the training data can predict those rightmost 100 steps.
Since 2021, the Nikkei 225 has mostly taken values around 30,000 yen. So that the loss values during error computation do not become too large and the parameter updates go smoothly, we divide by 10,000 to shrink the values. Most values should then fall between roughly 2.5 and 4.
The code summarizing the content so far is as follows.
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
# Load the CSV file
data = pd.read_csv("./data/nikkei_225.csv")
# Divide the Nikkei 225 values by 10,000 to make them smaller
scaling_factor = 10_000
x_open = data["Open"]/scaling_factor
x_high = data["High"]/scaling_factor
x_low = data["Low"]/scaling_factor
x_close = data["Close"]/scaling_factor
2.1 Building the dataset and converting to torch tensors
We preprocess the data into a form the sequence-to-sequence model can learn from. We build a dataset so that the model predicts the 5 opening prices of steps 6 through 10 from the data of steps 1 through 5. Concretely, the input data is made by cutting the open, high, low, and close price series into windows of size 5 and sliding each window forward one step at a time. Since the teacher data should cover the 5 steps we want predicted, it consists of opening-price windows of size 5. The number of predicted steps (5) and the window size are the same.
Input data | Teacher data |
---|---|
step 1, step 2, step 3, step 4, step 5 | step 6, step 7, step 8, step 9, step 10 |
$x_1,~x_2,~x_3,~x_4,~x_5$ | $t_1,~t_2,~t_3,~t_4,~t_5$ |
There are two points to note:
- The teacher data is created by shifting the size-5 window forward so that it starts at step 6
- Because we predict 5 steps ahead, be careful about where the input windows must end
Note on the input data
Before the change
- XO = [... for start in range(len(data)-win_size)]
After the change
- XO = [... for start in range(len(data)-win_size-dec_win_size)]
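The reason for the extra `- dec_win_size` in the range can be checked on a toy series: every input window must leave room for a full 5-step target window after it. A small sketch, using made-up data instead of the CSV:

```python
# Toy series standing in for the price data
data = list(range(20))
win_size, dec_win_size = 5, 5

# Without subtracting dec_win_size, the last starts leave no room for targets:
starts_bad = range(len(data) - win_size)                # 15 starts
# With the subtraction, every start has a full input AND target window:
starts_ok = range(len(data) - win_size - dec_win_size)  # 10 starts

for s in starts_ok:
    x_win = data[s:s + win_size]                            # input: 5 steps
    t_win = data[s + win_size:s + win_size + dec_win_size]  # teacher: next 5 steps
    assert len(x_win) == win_size and len(t_win) == dec_win_size

# The last "bad" start has only 1 of the 5 target steps left:
s = max(starts_bad)
print(len(data[s + win_size:s + win_size + dec_win_size]))  # 1
```

Without the fix, the final windows would produce short (or empty) teacher windows and misaligned arrays.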
Here is the code from loading the CSV file up to splitting into windows. Rather than converting everything in one clever step, it is written as plain, steady repetition.
win_size = 5 # window size of the input data
dec_win_size = 5 # window size of the teacher data
XO = [x_open[start:start+win_size] for start in range(len(data)-win_size-dec_win_size)]
XH = [x_high[start:start+win_size] for start in range(len(data)-win_size-dec_win_size)]
XL = [x_low[start:start+win_size] for start in range(len(data)-win_size-dec_win_size)]
XC = [x_close[start:start+win_size] for start in range(len(data)-win_size-dec_win_size)]
# Teacher data: starts win_size (= 5) steps later
TO = [x_open[start:start+dec_win_size] for start in range(win_size, len(data)-dec_win_size)]
# Input data
xo = np.array(XO)
xh = np.array(XH)
xl = np.array(XL)
xc = np.array(XC)
# Teacher data
t = np.array(TO)
xo = xo.reshape(xo.shape[0], xo.shape[1], 1)
xh = xh.reshape(xh.shape[0], xh.shape[1], 1)
xl = xl.reshape(xl.shape[0], xl.shape[1], 1)
xc = xc.reshape(xc.shape[0], xc.shape[1], 1)
t = t.reshape(t.shape[0], t.shape[1]) # final teacher data
x = np.concatenate([xo, xh, xl, xc], axis=2) # final input data
The four windowed series (open, high, low, close) are used as the input data. Each of them (XO, XH, XL, XC) is a collection of size-5 windows with a single feature, so we combine them and convert them to the shape (batch size, 5, 4).
If you actually display them, you will find that with the code above XO and TO contain mixed types. Since everything only needs to end up as torch.FloatTensor() in the end, there is no need to be clever: the code pushes through by brute force, going through numpy arrays once to normalize the format.
You can confirm that the shape of the input data x is (batch size, sequence length 5, features 4) and the shape of the teacher data is (batch size, sequence length 5). Feeding x into the LSTM starts the network. Before that, we convert x and t to FloatTensor and split them into training data and test data: the front part for training and the rear part for testing.
device = "cuda" if torch.cuda.is_available() else "cpu"
x = torch.FloatTensor(x).to(device)
t = torch.FloatTensor(t).to(device)
period = 100
x_train = x[:-period]
x_test = x[-period:]
t_train = t[:-period]
t_test = t[-period:]
# The input features are 4-dimensional
# x_train.shape : torch.Size([987, 5, 4])
# x_test.shape : torch.Size([100, 5, 4])
# t_train.shape : torch.Size([987, 5])
# t_test.shape : torch.Size([100, 5])
2.2 Defining and creating the network model
Notation
$x_1,~x_2,~x_3,~x_4,~x_5$: the input data; the open, high, low, and close prices of steps 1 through 5
$t_1,~t_2,~t_3,~t_4,~t_5$: the teacher data corresponding to that input; the opening prices of steps 6 through 10
Note the subscripts of the teacher data: $t_1$ denotes the opening price of step 6.
We handle the prediction up to 5 steps ahead with an LSTM-based sequence-to-sequence model. Into the part enclosed by the orange-ish parallelogram, the 4-dimensional values $x_1=(xo_1, xh_1, xl_1, xc_1)$, ... are fed in order. Since 4 features enter the LSTM, input_size=4. The final outputs of the LSTM, h5 and c5, are features that recurrently take the past 5 days of information into account. We use h5 and c5 as the initial values of the decoder's LSTM.

Let me summarize the main points of writing an LSTM in PyTorch. The LSTM has two kinds of output, the history h and the cell c, and we use them as the encoder-side output.
How to write an LSTM layer
nn.LSTM(input_size, hidden_size, num_layers, batch_first)
- input_size : the dimensionality of the input features
- hidden_size : the dimensionality of the output hidden-state features
- num_layers : the number of stacked (recurrent) LSTM layers; the default is num_layers=1
- batch_first : if True, the shape is (batch size, sequence length, features)
Output values of an LSTM with batch_first=True
o, (h, c) = lstm(x)
- o : the hidden-state output of the final (very last) layer at every time step
- h : the output of every hidden layer at the last time step
- c : the cell state at the last time step
The details are described in the official PyTorch documentation.
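A quick shape check of these three outputs, using the same settings as the encoder (the batch size of 2 here is just for illustration):

```python
import torch
import torch.nn as nn

# Same layer settings as the encoder: 4 features in, hidden size 100
lstm = nn.LSTM(input_size=4, hidden_size=100, num_layers=1, batch_first=True)
x = torch.randn(2, 5, 4)   # (batch size, sequence length, features)
o, (h, c) = lstm(x)
print(o.shape)  # torch.Size([2, 5, 100]) : last layer's hidden state at every step
print(h.shape)  # torch.Size([1, 2, 100]) : (num_layers, batch, hidden) at the last step
print(c.shape)  # torch.Size([1, 2, 100]) : cell state at the last step
# For a 1-layer LSTM, the last step of o equals the last layer of h
assert torch.allclose(o[:, -1, :], h[-1])
```

Note that with batch_first=True only x and o use the (batch, sequence, features) layout; h and c always keep the batch dimension second.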
For the decoder we take a different approach from the earlier models that predicted only 1 step ahead. Since we now predict up to 5 steps ahead, the LSTM generates the predictions incrementally. Concretely, it first predicts 1 step ahead, feeds that prediction back in as the next input to predict 2 steps ahead, uses that result to predict 3 steps ahead, and so on, chaining the predictions to produce a 5-step time-series forecast.
The hidden state h5 and cell state c5 inherited from the encoder are used as the initial values of the decoder's first LSTM step, and the opening price of step 5, $xo_5$, is used as the input data. The LSTM output is computed as
$$o, (h, c) = \mathrm{lstm}(xo_5, (h_5, c_5))$$
From o or h, a fully connected layer (and so on) computes $y_1$, the prediction for step 6.[1] We take the error between the prediction $y_1$ and $t_1$, the teacher value for step 6. Then $y_1$ becomes the input data of the next LSTM step and the recurrent computation proceeds: the prediction $y_2$ for step 7 is compared with the teacher value $t_2$ to get its error, and so on, until finally the error between $y_5$, the prediction for step 10, and the teacher value $t_5$ is obtained. The parameters are then updated so that these errors become small.
The final predictions are the five values ($y_1$, $y_2$, $y_3$, $y_4$, $y_5$), i.e., the 5-step forecast from step 6 through step 10.

Here is the actual network structure. The prediction loop in the decoder makes it look complicated at first glance, but the structure itself is just two LSTMs and fully connected layers, so it is essentially simple. The encoder side recurs 5 times over the 4-dimensional input values and passes the history $h_5$ and cell state $c_5$ to the decoder's LSTM; the decoder then predicts the next step of the Nikkei 225 from 1-dimensional input values. Summarized as a figure, it is the familiar sequence-to-sequence diagram.
Figure: the sequence-to-sequence model
The encoder part is just an LSTM, named enc_lstm. The decoder consists of an LSTM, fully connected layers, and an activation function; the decoder's LSTM is named dec_lstm.
class DNN(nn.Module):
    def __init__(self):
        super().__init__()
        # Encoder
        self.enc_lstm = nn.LSTM(input_size=4, hidden_size=100, num_layers=1, batch_first=True)
        # Decoder
        self.dec_lstm = nn.LSTM(input_size=1, hidden_size=100, num_layers=1, batch_first=True)
        self.fc1 = nn.Linear(in_features=100, out_features=50)
        self.act = nn.LeakyReLU()
        self.fc2 = nn.Linear(in_features=50, out_features=1)
    def forward(self, x):
        # Encoder pass
        _, (h, c) = self.enc_lstm(x)
        # Initial decoder input: the last time step of x (step 5 if the window is (1,2,3,4,5))
        decoder_input = x[:, -1:, 0].unsqueeze(2) # (batch_size, 1, 1)
        # Alternative: use the day-5 Close price as the initial input
        #decoder_input = x[:, -1:, 3].unsqueeze(2)
        hidden = h[-1,:,:].unsqueeze(0)
        cell = c[-1,:,:].unsqueeze(0)
        # Store the predictions
        outputs = []
        # Decoder prediction loop (dec_win_size is the global from 2.1)
        for t in range(dec_win_size):
            o, (hidden, cell) = self.dec_lstm(decoder_input, (hidden, cell))
            last_output = hidden[-1] # output of the last step; o[:,-1,:] works too
            h = self.fc1(last_output)
            h = self.act(h)
            y = self.fc2(h)
            outputs.append(y)
            decoder_input = y.unsqueeze(1) # (batch_size, 1, 1)
        # Concatenate the predictions
        outputs = torch.cat(outputs, dim=1)
        return outputs
# Create the model (instantiation is needed before section 2.3)
model = DNN().to(device)
Let's take a closer look at the for statement in forward. It looks convoluted, but it is actually simple. dec_win_size is the window size of the teacher data, so it is 5; the for statement merely repeats the lstm-plus-fc prediction computation 5 times.
...
for t in range(dec_win_size):
    o, (hidden, cell) = self.dec_lstm(decoder_input, (hidden, cell))
    last_output = hidden[-1] # output of the last step; o[:,-1,:] works too
    h = self.fc1(last_output)
    h = self.act(h)
    y = self.fc2(h)
    outputs.append(y)
    decoder_input = y.unsqueeze(1) # (batch_size, 1, 1)
...
o, (hidden, cell) = self.dec_lstm(decoder_input, (hidden, cell))
How this line behaves is the key point.
- Loop 1
  - decoder_input is the opening price of step 5; hidden and cell are the final outputs of the encoder
  - The LSTM output (hidden, or equally o) is fed into the fully connected layers to obtain the prediction y
- Loop 2
  - decoder_input = y.unsqueeze(1), i.e., the prediction for step 6
  - hidden and cell are now the outputs of dec_lstm
  - The LSTM output is fed into the fully connected layers to obtain the prediction y
- Loop 3
  - decoder_input = y.unsqueeze(1), i.e., the prediction for step 7
  - hidden and cell are the outputs of dec_lstm
  - The LSTM output is fed into the fully connected layers to obtain the prediction y
- ...
- Loop 5
  - decoder_input = y.unsqueeze(1), i.e., the prediction for step 9
  - hidden and cell are the outputs of dec_lstm
  - The LSTM output is fed into the fully connected layers to obtain the prediction y
- outputs.append(y) collects the predictions as we go
- The unsqueeze() calls scattered around are there to align the shapes of the data
The forward part is complicated, but the result of print(model) is very simple.
DNN(
  (enc_lstm): LSTM(4, 100, batch_first=True)
  (dec_lstm): LSTM(1, 100, batch_first=True)
  (fc1): Linear(in_features=100, out_features=50, bias=True)
  (act): LeakyReLU(negative_slope=0.01)
  (fc2): Linear(in_features=50, out_features=1, bias=True)
)
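Before training, it is worth running the model once on a dummy batch to confirm that forward really returns dec_win_size predictions per sample. A self-contained smoke test (the class is the same as above; the batch of 3 random windows is made up for the check):

```python
import torch
import torch.nn as nn

dec_win_size = 5  # same teacher-window size as in the article

class DNN(nn.Module):
    def __init__(self):
        super().__init__()
        # Encoder: 4 features (Open/High/Low/Close) in, hidden size 100
        self.enc_lstm = nn.LSTM(input_size=4, hidden_size=100, num_layers=1, batch_first=True)
        # Decoder: feeds its own 1-dim prediction back in at each step
        self.dec_lstm = nn.LSTM(input_size=1, hidden_size=100, num_layers=1, batch_first=True)
        self.fc1 = nn.Linear(100, 50)
        self.act = nn.LeakyReLU()
        self.fc2 = nn.Linear(50, 1)

    def forward(self, x):
        _, (h, c) = self.enc_lstm(x)
        decoder_input = x[:, -1:, 0].unsqueeze(2)        # last Open value, (B, 1, 1)
        hidden, cell = h[-1:].contiguous(), c[-1:].contiguous()
        outputs = []
        for _ in range(dec_win_size):
            o, (hidden, cell) = self.dec_lstm(decoder_input, (hidden, cell))
            y = self.fc2(self.act(self.fc1(hidden[-1])))  # (B, 1)
            outputs.append(y)
            decoder_input = y.unsqueeze(1)                # feed prediction back, (B, 1, 1)
        return torch.cat(outputs, dim=1)                  # (B, dec_win_size)

model = DNN()
dummy = torch.randn(3, 5, 4)   # batch of 3, window of 5 steps, 4 features
out = model(dummy)
print(out.shape)               # torch.Size([3, 5]) : five predictions per sample
```

If the shape comes out as (3, 5), the loop, the unsqueeze calls, and the final torch.cat are all wired correctly.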
2.3 Choosing the loss function and the optimization method
Since this is a regression problem, we train by reducing the squared error between the predictions y and the measured values (the teacher data) t.
# Define the loss function and the optimizer
criterion = nn.MSELoss() # mean squared error
optimizer = torch.optim.AdamW(model.parameters())
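To see exactly what nn.MSELoss computes, here is a tiny example with made-up numbers:

```python
import torch
import torch.nn as nn

# MSELoss averages the squared difference over every element
criterion = nn.MSELoss()
y = torch.tensor([[3.0, 4.0]])   # toy predictions for one 2-step window
t = torch.tensor([[1.0, 2.0]])   # toy teacher values
loss = criterion(y, t)
print(loss.item())  # 4.0 : ((3-1)^2 + (4-2)^2) / 2
```

In our setting y and t both have the shape (batch size, 5), so the loss averages the squared errors of all 5 predicted steps across the whole batch at once.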
2.4 The parameter-update loop
For the number of iterations specified by LOOP, we repeat:
- y = model(x) to compute the predictions
- criterion(y, t_train) to compute, with the chosen loss function, the error between the predictions and the teacher data
- updating the weights and biases of the network (the LSTMs and the fully connected layers) according to the optimizer so that the error shrinks
LOOP = 5_000
model.train()
for epoch in range(LOOP):
    optimizer.zero_grad()
    y = model(x_train)
    loss = criterion(y, t_train)
    if (epoch+1)%1000 == 0:
        print(f"{epoch}\tloss: {loss.item()}")
    loss.backward()
    optimizer.step()
The for loop updates the parameters. While watching the loss decrease, adjust the number of iterations and the learning rate as needed. With this, the basic training is done.
2.5 📈 Validation
Let's test the training result using the test data x_test and t_test created by the split in 2.1. Feeding x_test into the model, y_test = model(x_test), gives the predictions. y_test is a list of 5-step predictions, one per step, shaped like the table below.
 | values for the 5 steps |
---|---|
y_test[0] | 3.83, 3.83, 3.85, 3.85, 3.85 |
y_test[1] | 3.85, 3.85, 3.86, 3.87, 3.87 |
… | … |
y_test[99] | 3.85, 3.84, 3.83, 3.82, 3.82 |
Now for drawing the graph. We plot it in the style of predicting 5 steps, then predicting the next 5 steps, and so on: every 5 steps, a new 5-step prediction curve is overlaid.[2]
import matplotlib.pyplot as plt
import japanize_matplotlib
# List of predictions
with torch.inference_mode():
    output = model(x_test)
y_test = output.cpu().detach().numpy()
# List of measured values
# Here, collecting the first element of each dec_win_size window gives the measured series
real_list = [item[0].detach().cpu().numpy() for item in t_test]
e = 100 # period: display 100 steps
plt.figure(figsize=(15,10))
plt.title(f"Prediction up to {dec_win_size} steps ahead, drawn every 5 steps")
plt.plot(real_list[:e], label="real", marker="^")
for i in range(0, e-dec_win_size+1, 5):
    plt.plot(range(i, i + len(y_test[i])), y_test[i], linestyle="dotted", label="prediction" if i == 0 else "", marker="*", color="red", alpha=0.5)
plt.legend()
plt.grid()
plt.show()
Displaying all 100 steps at once makes the graph hard to see, so let's split it into halves!
Figure: steps 0-49, graph connecting the 5-step-ahead predictions
Hmm, is it really this far off?[3] It keeps up with the trend, but the predictions a few steps ahead seem almost useless. Thinking that errors might be accumulating toward 5 steps ahead, I tried incorporating Teacher Forcing[4] during training, but it only improved things a little, and in places it actually made them worse.
Just because it is a sequence-to-sequence model does not necessarily mean the performance beats a plain LSTM model. Since the encoder has a shape similar to the structure in installment 9, perhaps being able to predict several steps ahead is the genuinely interesting part of sequence-to-sequence models.
Figure: steps 50-99, graph connecting the 5-step-ahead predictions
Predicting 5 days ahead really is difficult! Still, in the sense of wanting a rough tendency, it feels useful (provided it is right, of course). Up to around step 50, the ups and downs seem to be captured. After step 50 the graph tends to be off more often; the last 20 steps are a somewhat disappointing result, with the predictions running a bit on the high side.
In an LSTM sequence-to-sequence model without an attention mechanism, the encoder's information is used only as the initial values ($h_5$, $c_5$) of the decoder's LSTM, so the encoder's information may not be conveyed to the decoder side well enough. The idea of actively referring back to the encoder's information while decoding leads to sequence-to-sequence models with an attention mechanism.
Sequence-to-sequence / encoder-decoder models using RNNs
- Cho et al. (2014), "Learning Phrase Representations using RNN Encoder-Decoder for Statistical Machine Translation"
- Sutskever et al. (2014), "Sequence to Sequence Learning with Neural Networks"
The idea itself seems to go back further, and machine translation appears to be the most common task.
Next time
- Having come this far, moving on to the attention mechanism feels like the natural next step, but let's take a short break first.
- Next time I would like to touch briefly on the 1-D convolution and causal convolution layers that appear in time-series analysis.
Table of contents page
Notes
1. What should be used as the decoder's first input value? Since we want to predict the opening price of step 6, I used the opening price of step 5 as the input data. If you feed in all 4 features (open, high, low, close), the decoder must also predict all 4 of them. In that case the decoder's LSTM takes 4-dimensional input, as in self.dec_lstm = nn.LSTM(input_size=4, hidden_size=100, num_layers=1, batch_first=True), and likewise the fully connected layer that determines the final output becomes 4-dimensional, as in self.fc2 = nn.Linear(in_features=50, out_features=4). You would also need to prepare 4 kinds of teacher data (open, high, low, close), which gets rather tedious. ↩
2. A style that draws a 5-step prediction at every single step makes the graph extremely hard to read. ↩
3. The graph above shows a case where training went reasonably well. When the loss has not come down enough, the predictions often came out sloping up to the right or U-shaped. ↩
4. Using not only the predictions but also the actual values (the teacher data) as the decoder's inputs during training is collectively called Teacher Forcing. ↩