0
1

[PyTorch] LSTMのシーケンス長による時系列予測精度の影響調査

Posted at

目的

LSTMのシーケンス長による時系列予測精度の影響を調査する。

結論

単純なプラント(1次遅れ系)の場合はシーケンス長は短い方が過学習しないので精度がよく、シーケンス長1が最も精度が良い。

尚、LSTMでも複雑なプラントの場合はシーケンス長を長くする必要があると考えられる。
また、MLPのように状態量を持たないモデルではシーケンス長1で時系列データ予測を行うと未学習になり精度が悪いので注意。

比較条件

・プラント:1次遅れ系
・学習用データ作成用入力信号:Chirp信号
・検証用データ作成用入力信号:Step信号

実験結果

・シーケンス長=1:精度問題なし。
image.png

・シーケンス長=16:大きな問題ではないが編曲点で若干ノイズが出ている。
image.png

・シーケンス長=128:パラメータ数増加に伴う明らかな精度悪化が見られる(過学習している)。
image.png

実験に用いたPythonスクリプト

d10_Pytorch_LSTM_Time_Sequence.ipynb
from control.matlab import*
import numpy as np
import matplotlib.pyplot as plt
import time
import torch
import torch.nn as nn
import torch.optim as optim
from torchinfo import summary
from torch.utils.data import DataLoader, TensorDataset
torch.manual_seed(0);

#学習データ作成用伝達関数モデルの定義・・・システムが複雑になると予測精度が落ちる
s = tf('s')
P = 1/(0.5*s+1) #1次遅れ系
#P = 1/(1*s**2+0.5*s+1) #2次遅れ系

#学習するモデルタイプの選択
#1:MLP, 2:RNN, 3:LSTM, 4:CNN : CNNは学習用データの形が合わないので動かない。
model_type = 3

#シーケンス長・・・予測精度に大きな影響があるので注意
sequence_length = 128

#学習用データと検証用データを作成するための入力信号の種類・・・予測精度に大きな影響がある
#1:Step信号、2:Chirp信号、3:白色ノイズ
train_data_type = 2
vali_data_type = 1

#入力信号
duration = 100  # 信号の継続時間 (秒)
sampling_rate = 10  # サンプリングレート (Hz)

# -------入力信号の生成------------
# 時間軸の作成
time_org = np.linspace(0, duration, int(sampling_rate * duration), endpoint=False)
print('time_org.shape', time_org.shape)

#-------------入力信号の生成---------
step_signal_type = 2
def if_elif(step_signal_type):
    if step_signal_type ==1: #---Step信号---
        step_time = duration/20 #ステップするタイミングをdurationの1/20毎とする
        u = np.zeros_like(time_org)
        for i in range(len(time_org)):
            if (i // int(step_time * sampling_rate)) % 2 == 1:
                u[i] = 1
        return u
    elif step_signal_type ==2: #---Chirp信号---
        frequency_start = 0.01  # チャープ信号の開始周波数 (Hz)
        frequency_end = 1 # チャープ信号の終了周波数 (Hz)
        u = np.sin(2 * np.pi * np.cumsum(np.linspace(frequency_start, frequency_end, len(time_org)) / sampling_rate)) / 2 + 0.5
        return u
    else: #---白色ノイズ---
        A = 1.0    # 振幅
        u = 2*A*(np.random.rand(round(sampling_rate*duration))-0.5)
        return u
        
u_train = if_elif(train_data_type) 
u_vali =  if_elif(vali_data_type) 

#-----------lsimで出力(応答)計算---------
def calc_y(P, u, t):
    x0 = 0 #初期値
    y, time_tmp, x_tmp = lsim(P, U=u, T=t, X0=x0)
    return y
    
y_act_train = calc_y(P,u_train,time_org)
y_act_vali = calc_y(P,u_vali,time_org)

#---------学習用データのグラフ表示---------
plt.figure(figsize=(10, 5))

#全体表示
plt.subplot(2, 1, 1)
plt.axhline(0, color="gray", linestyle="--", linewidth=1)
plt.plot(time_org, u_train, label="Input", color='b', linewidth=1)
plt.plot(time_org, y_act_train, label="Output", color='r', linewidth=1)
plt.ylabel("Signal [-]")
plt.title("Train data check")
plt.legend()
plt.grid(True)

#拡大表示
plt.subplot(2, 1, 2)
plt.axhline(0, color="gray", linestyle="--", linewidth=1)
plt.plot(time_org, u_train, label="Input", color='b', linewidth=1)
plt.plot(time_org, y_act_train, label="Output", color='r', linewidth=1)
plt.xlim(80,100)
plt.xlabel("time (s)")
plt.ylabel("Signal [-]")
#plt.title("Train data")
plt.legend()
plt.grid(True)

#---------検証用データのグラフ表示---------
plt.figure(figsize=(10, 5))

#全体表示
plt.subplot(2, 1, 1)
plt.axhline(0, color="gray", linestyle="--", linewidth=1)
plt.plot(time_org, u_vali, label="Input", color='b', linewidth=1)
plt.plot(time_org, y_act_vali, label="Output", color='r', linewidth=1)
plt.ylabel("Signal [-]")
plt.title("Validation data check")
plt.legend()
plt.grid(True)

#拡大表示
plt.subplot(2, 1, 2)
plt.axhline(0, color="gray", linestyle="--", linewidth=1)
plt.plot(time_org, u_vali, label="Input", color='b', linewidth=1)
plt.plot(time_org, y_act_vali, label="Output", color='r', linewidth=1)
plt.xlim(0,25)
plt.xlabel("time_org (s)")
plt.ylabel("Signal [-]")
#plt.title("Train data")
plt.legend()
plt.grid(True)

#---------時系列モデルの学習用および検証用データを作成(入力をシーケンス長毎のベクトルとする)---------
def create_sequences(x_org, y_org, seq_length):
    xs, ys = [], []
    for i in range(len(y_org) - seq_length + 1):
        x = x_org[i:i+seq_length]
        y = y_org[i+seq_length-1]
        xs.append(x)
        ys.append(y)
    return np.array(xs), np.array(ys)

# 訓練データ
x_train, y_train = create_sequences(u_train, y_act_train,sequence_length)
y_train = y_train.reshape(-1, 1)
time_train = time_org[sequence_length-1:]

#検証用データ
x_vali, y_vali = create_sequences(u_vali, y_act_vali,sequence_length)
y_vali = y_vali.reshape(-1, 1)
time_vali = time_train

#---------Tensor型へ変換する---------
x_train = torch.tensor(x_train, dtype=torch.float32)
y_train =  torch.tensor(y_train, dtype=torch.float32)
time_train = torch.tensor(time_train, dtype=torch.float32)
x_vali = torch.tensor(x_vali, dtype=torch.float32)
y_vali =  torch.tensor(y_vali, dtype=torch.float32)
time_vali = torch.tensor(time_vali, dtype=torch.float32)

#---------学習用シーケンスデータの確認---------
plt.figure(figsize=(10, 5))

plt.subplot(2, 1, 1)
plt.plot(time_train.numpy(), x_train.numpy()[:,0], label='x_train[:,0]', linewidth=1, color='cyan', linestyle="--")
plt.plot(time_train.numpy(), x_train.numpy()[:,sequence_length-1], label='x_train[:,sequence_length-1]', linewidth=1, color='b')
plt.plot(time_train.numpy(), y_train.numpy(), label='y_train', linewidth=1, color='r')
plt.title('Train data check')
plt.ylabel('X, Y')
plt.grid(True, linestyle='dotted')
plt.legend()

plt.subplot(2, 1, 2)
plt.plot(time_train.numpy(), x_train.numpy()[:,0], label='x_train[:,0]', linewidth=1, color='cyan', linestyle="--")
plt.plot(time_train.numpy(), x_train.numpy()[:,sequence_length-1], label='x_train[:,sequence_length-1]', linewidth=1, color='b')
plt.plot(time_train.numpy(), y_train.numpy(), label='y_train', linewidth=1, color='r')
plt.xlabel('Time[s]')
plt.ylabel('X, Y')
plt.grid(True, linestyle='dotted')
plt.legend()
plt.xlim(200,300)

plt.show()

# ---------検証用シーケンスデータの確認---------
plt.figure(figsize=(10, 5))

plt.subplot(2, 1, 1)
plt.plot(time_vali.numpy(), x_vali.numpy()[:,0], label='x_vali[:,0]', linewidth=1, color='cyan', linestyle="--")
plt.plot(time_vali.numpy(), x_vali.numpy()[:,sequence_length-1], label='x_vali[:,sequence_length-1]', linewidth=1, color='b')
plt.plot(time_vali.numpy(), y_vali.numpy(), label='y_vali', linewidth=1, color='r')
plt.title('Validation data check')
plt.ylabel('X, Y')
plt.grid(True, linestyle='dotted')
plt.legend()

plt.subplot(2, 1, 2)
plt.plot(time_vali.numpy(), x_vali.numpy()[:,0], label='x_vali[:,0]', linewidth=1, color='cyan', linestyle="--")
plt.plot(time_vali.numpy(), x_vali.numpy()[:,sequence_length-1], label='x_vali[:,sequence_length-1]', linewidth=1, color='b')
plt.plot(time_vali.numpy(), y_vali.numpy(), label='y_vali', linewidth=1, color='r')
plt.xlabel('Time[s]')
plt.ylabel('X, Y')
plt.grid(True, linestyle='dotted')
plt.legend()
plt.xlim(40,140)

plt.show()

# ---------MLPの定義---------
class MLP(nn.Module):
    def __init__(self):
        super(MLP, self).__init__()
        self.fc1 = nn.Linear(sequence_length, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.fc3 = nn.Linear(hidden_size, output_size)
    
    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x

#------------RNNの定義-------------
class RNN(nn.Module):
    def __init__(self):
        super(RNN, self).__init__()
        self.rnn = nn.RNN(input_size=input_size, hidden_size=hidden_size, num_layers=num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size) 

    def forward(self, x):
        x, h = self.rnn(x, None)
        x = self.fc(x.squeeze(1)) 
        return x

#------------LSTMの定義-------------
class LSTM(nn.Module):
    def __init__(self):
        super().__init__()
        self.lstm = nn.LSTM(input_size=input_size, hidden_size=hidden_size, num_layers=num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)
    def forward(self, x):
        x, _ = self.lstm(x)
        x = self.fc(x)
        return x

#------------CNNの定義(入力データの形状が異なるので実行できない)-------------
class CNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv1d(in_channels=1, out_channels=hidden_size, kernel_size=8, stride=1, padding=1)
        self.fc = nn.Linear(hidden_size, output_size)
    def forward(self, x):
        x = self.conv1(x)
        x = self.fc(x)
        return x

#----------モデルのインスタンス作成------------
input_size = sequence_length
hidden_size = 32
num_layers = 1 
output_size = 1

if model_type == 1:
    model = MLP()
elif model_type == 2:
    model = RNN()
elif model_type == 3:
    model = LSTM()
else:
    model = CNN()

criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.01)

#----------モデルのサマリー表示----------
summary(model=model)
#summary(model=model,input_size=(1,sequence_length))

# ---------モデルの学習---------
num_epochs = 1000

start = time.time()
for epoch in range(num_epochs):
    model.train()
    optimizer.zero_grad()
    
    if model_type == 4:#CNNの時
        tmp = x_train.view(1, -1, input_size)
        print(tmp.size())
        outputs = model(tmp)
    else: #CNN以外(MLP,RNN,LSTMの時)
        outputs = model(x_train)
        
    loss = criterion(outputs, y_train)
    loss.backward()
    optimizer.step()
    
    if (epoch+1) % (num_epochs/10) == 0:
        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.10f}')
end = time.time()
print('\n')
print(f'処理時間: {(end-start)/60:.1f}[min]')

# ---------学習用データと検証用データの予測---------
model.eval()
with torch.no_grad():
    y_train_pred = model(x_train)
    y_vali_pred = model(x_vali)

# ---------学習データにおける予測結果のプロット(時系列)---------
plt.figure(figsize=(10, 5))

plt.subplot(2, 1, 1)
plt.plot(time_train.numpy(), y_train.numpy(), label='y_train_Act', linewidth=2, color='b')
plt.plot(time_train.numpy(), y_train_pred.numpy(), label='y_train_Pred', linewidth=2, linestyle="--", color='r')
plt.title('Prediction result with train data')
plt.ylabel('Y')
plt.grid(True, linestyle='dotted')
plt.legend()

plt.subplot(2, 1, 2)
plt.plot(time_train.numpy(), y_train.numpy(), label='y_train_Act', linewidth=2, color='b')
plt.plot(time_train.numpy(), y_train_pred.numpy(), label='y_train_Pred', linewidth=2, linestyle="--", color='r')
plt.xlabel('Time[s]')
plt.ylabel('Y')
plt.grid(True, linestyle='dotted')
plt.legend()
plt.xlim(40,60)

plt.show()

# ---------検証データにおける予測結果のプロット(時系列)---------
plt.figure(figsize=(10, 5))

plt.subplot(2, 1, 1)
plt.plot(time_vali.numpy(), y_vali.numpy(), label='y_vali', linewidth=2, color='b')
plt.plot(time_vali.numpy(), y_vali_pred.numpy(), label='y_vali_Pred', linewidth=2, color='r', linestyle="--")

plt.title('Prediction result with validation data')
plt.ylabel('Y')
plt.grid(True, linestyle='dotted')
plt.legend()

plt.subplot(2, 1, 2)
plt.plot(time_vali.numpy(), y_vali.numpy(), label='y_vali', linewidth=2, color='b')
plt.plot(time_vali.numpy(), y_vali_pred.numpy(), label='y_vali_Pred', linewidth=2, color='r', linestyle="--")
plt.xlabel('Time[s]')
plt.ylabel('Y')
plt.grid(True, linestyle='dotted')
plt.legend()
plt.xlim(40,60)

plt.show()



0
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
1