
[PyTorch] Comparing Time-Series Prediction Accuracy of MLP, RNN, and LSTM

Posted at 2024-08-17

Purpose

Compare the time-series prediction accuracy of MLP, RNN, and LSTM models.

Conclusion

The LSTM gives the highest prediction accuracy with the least noise. Because it mitigates the vanishing-gradient problem, the sequence length can likely be shortened; this in turn reduces the number of parameters, which is likely to help suppress overfitting.
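
As a rough sanity check on the parameter-count argument, the sketch below (not part of the experiment script; the layer sizes simply mirror the models defined later in this article) counts the trainable parameters of an MLP and an LSTM at sequence lengths 16 and 1. Shortening the window shrinks the input layer of both models, so a model that stays accurate at a short sequence length can get by with fewer parameters.

```python
import torch.nn as nn

def count_params(m):
    # Total number of trainable parameters in a module
    return sum(p.numel() for p in m.parameters() if p.requires_grad)

hidden_size, output_size = 32, 1
for seq_len in (16, 1):  # 16 as used for the MLP below, 1 as suggested for the LSTM
    mlp = nn.Sequential(
        nn.Linear(seq_len, hidden_size), nn.ReLU(),
        nn.Linear(hidden_size, hidden_size), nn.ReLU(),
        nn.Linear(hidden_size, output_size),
    )
    # Output head (hidden_size*output_size + output_size params) omitted for brevity
    lstm = nn.LSTM(input_size=seq_len, hidden_size=hidden_size,
                   num_layers=1, batch_first=True)
    print(f"seq_len={seq_len}: MLP params={count_params(mlp)}, LSTM params={count_params(lstm)}")
```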

Comparison conditions

Sequence length: 16 (a length at which the MLP achieves reasonable accuracy; the LSTM is accurate even with a length of 1)
Model used to generate the training and validation data: transfer function (first-order lag; see below)
Input signal for generating the training data: chirp signal
Input signal for generating the validation data: step signal
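
For reference, the first-order lag used to generate both data sets (defined in the script below as P = 1/(0.5*s+1)) is

```math
P(s) = \frac{1}{0.5\,s + 1}
```

i.e. a system with a time constant of 0.5 s and unity DC gain.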

Results

■ MLP
Noise appears in the prediction.
MLP.png

■ RNN
Some noise appears, but the result is better than the MLP.
RNN.png

■ LSTM
Compared with the MLP and RNN, the noise is small and the overall behavior is reasonable.
LSTM.png

Script used in the experiment

d09_Pytorch_LSTM_Time_Sequence.ipynb
from control.matlab import *
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from torchinfo import summary
from torch.utils.data import DataLoader, TensorDataset
torch.manual_seed(0);

# Transfer-function model used to generate the training data (prediction accuracy drops as the system becomes more complex)
s = tf('s')
P = 1/(0.5*s+1)  # first-order lag
#P = 1/(1*s**2+2*s+1)  # second-order lag

# Select the model type to train
# 1: MLP, 2: RNN, 3: LSTM
model_type = 3

# Sequence length (has a large effect on prediction accuracy)
sequence_length = 16

# Type of input signal used to build the training and validation data (has a large effect on prediction accuracy)
# 1: step signal, 2: chirp signal, 3: white noise
train_data_type = 2
vali_data_type = 1

# ------- Input signal settings ------------
duration = 100  # signal duration (s)
sampling_rate = 10  # sampling rate (Hz)

# Create the time axis
time = np.linspace(0, duration, int(sampling_rate * duration), endpoint=False)
print('time.shape', time.shape)

#------------- Input signal generation ---------
step_signal_type = 2  # not used directly; if_elif receives the signal type as an argument
def if_elif(step_signal_type):
    if step_signal_type == 1:  # --- Step signal ---
        step_time = duration/20  # toggle the step every duration/20 seconds
        u = np.zeros_like(time)
        for i in range(len(time)):
            if (i // int(step_time * sampling_rate)) % 2 == 1:
                u[i] = 1
        return u
    elif step_signal_type == 2:  # --- Chirp signal ---
        frequency_start = 0.01  # chirp start frequency (Hz)
        frequency_end = 1  # chirp end frequency (Hz)
        u = np.sin(2 * np.pi * np.cumsum(np.linspace(frequency_start, frequency_end, len(time)) / sampling_rate)) / 2 + 0.5
        return u
    else:  # --- White noise ---
        A = 1.0  # amplitude
        u = 2*A*(np.random.rand(round(sampling_rate*duration))-0.5)
        return u
        
u_train = if_elif(train_data_type) 
u_vali =  if_elif(vali_data_type) 

#----------- Compute the output (response) with lsim ---------
def calc_y(P, u, t):
    x0 = 0  # initial state
    y, time_tmp, x_tmp = lsim(P, U=u, T=t, X0=x0)
    return y
    
y_act_train = calc_y(P,u_train,time)
y_act_vali = calc_y(P,u_vali,time)

#--------- Plot the training data ---------
plt.figure(figsize=(10, 5))

# Full view
plt.subplot(2, 1, 1)
plt.axhline(0, color="gray", linestyle="--", linewidth=1)
plt.plot(time, u_train, label="Input", color='b', linewidth=1)
plt.plot(time, y_act_train, label="Output", color='r', linewidth=1)
#plt.xlabel("time (s)")
plt.ylabel("Signal [-]")
plt.title("Train data check")
plt.legend()
plt.grid(True)

# Zoomed view
plt.subplot(2, 1, 2)
plt.axhline(0, color="gray", linestyle="--", linewidth=1)
plt.plot(time, u_train, label="Input", color='b', linewidth=1)
plt.plot(time, y_act_train, label="Output", color='r', linewidth=1)
plt.xlim(80,100)
plt.xlabel("time (s)")
plt.ylabel("Signal [-]")
#plt.title("Train data")
plt.legend()
plt.grid(True)

#--------- Plot the validation data ---------
plt.figure(figsize=(10, 5))

# Full view
plt.subplot(2, 1, 1)
plt.axhline(0, color="gray", linestyle="--", linewidth=1)
plt.plot(time, u_vali, label="Input", color='b', linewidth=1)
plt.plot(time, y_act_vali, label="Output", color='r', linewidth=1)
plt.ylabel("Signal [-]")
plt.title("Validation data check")
plt.legend()
plt.grid(True)

# Zoomed view
plt.subplot(2, 1, 2)
plt.axhline(0, color="gray", linestyle="--", linewidth=1)
plt.plot(time, u_vali, label="Input", color='b', linewidth=1)
plt.plot(time, y_act_vali, label="Output", color='r', linewidth=1)
plt.xlim(0,25)
plt.xlabel("time (s)")
plt.ylabel("Signal [-]")
#plt.title("Train data")
plt.legend()
plt.grid(True)

#--------- Build the training and validation data for the time-series models (each input is a vector of length sequence_length) ---------
def create_sequences(x_org, y_org, seq_length):
    xs, ys = [], []
    for i in range(len(y_org) - seq_length + 1):
        x = x_org[i:i+seq_length]
        y = y_org[i+seq_length-1]
        xs.append(x)
        ys.append(y)
    return np.array(xs), np.array(ys)
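# For example (hypothetical values): with seq_length = 3 and x_org = [0, 1, 2, 3, 4],
# the windows are xs = [[0,1,2], [1,2,3], [2,3,4]], and each target ys[i] is y_org at
# the last index of the corresponding window, i.e. ys = [y_org[2], y_org[3], y_org[4]].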

# Training data
x_train, y_train = create_sequences(u_train, y_act_train,sequence_length)
y_train = y_train.reshape(-1, 1)
time_train = time[sequence_length-1:]

# Validation data
x_vali, y_vali = create_sequences(u_vali, y_act_vali,sequence_length)
y_vali = y_vali.reshape(-1, 1)
time_vali = time[sequence_length-1:]

#--------- Convert to tensors ---------
x_train = torch.tensor(x_train, dtype=torch.float32)
y_train =  torch.tensor(y_train, dtype=torch.float32)
time_train = torch.tensor(time_train, dtype=torch.float32)
x_vali = torch.tensor(x_vali, dtype=torch.float32)
y_vali =  torch.tensor(y_vali, dtype=torch.float32)
time_vali = torch.tensor(time_vali, dtype=torch.float32)

#--------- Check the training sequence data ---------
plt.figure(figsize=(10, 5))

plt.subplot(2, 1, 1)
plt.plot(time_train.numpy(), x_train.numpy()[:,0], label='x_train[:,0]', linewidth=1, color='cyan', linestyle="--")
plt.plot(time_train.numpy(), x_train.numpy()[:,sequence_length-1], label='x_train[:,sequence_length-1]', linewidth=1, color='b')
plt.plot(time_train.numpy(), y_train.numpy(), label='y_train', linewidth=1, color='r')
plt.title('Train data check')
plt.ylabel('X, Y')
plt.grid(True, linestyle='dotted')
plt.legend()

plt.subplot(2, 1, 2)
plt.plot(time_train.numpy(), x_train.numpy()[:,0], label='x_train[:,0]', linewidth=1, color='cyan', linestyle="--")
plt.plot(time_train.numpy(), x_train.numpy()[:,sequence_length-1], label='x_train[:,sequence_length-1]', linewidth=1, color='b')
plt.plot(time_train.numpy(), y_train.numpy(), label='y_train', linewidth=1, color='r')
plt.xlabel('Time[s]')
plt.ylabel('X, Y')
plt.grid(True, linestyle='dotted')
plt.legend()
plt.xlim(0,20)

plt.show()

# --------- Check the validation sequence data ---------
plt.figure(figsize=(10, 5))

plt.subplot(2, 1, 1)
plt.plot(time_vali.numpy(), x_vali.numpy()[:,0], label='x_vali[:,0]', linewidth=1, color='cyan', linestyle="--")
plt.plot(time_vali.numpy(), x_vali.numpy()[:,sequence_length-1], label='x_vali[:,sequence_length-1]', linewidth=1, color='b')
plt.plot(time_vali.numpy(), y_vali.numpy(), label='y_vali', linewidth=1, color='r')
plt.title('Validation data check')
plt.ylabel('X, Y')
plt.grid(True, linestyle='dotted')
plt.legend()

plt.subplot(2, 1, 2)
plt.plot(time_vali.numpy(), x_vali.numpy()[:,0], label='x_vali[:,0]', linewidth=1, color='cyan', linestyle="--")
plt.plot(time_vali.numpy(), x_vali.numpy()[:,sequence_length-1], label='x_vali[:,sequence_length-1]', linewidth=1, color='b')
plt.plot(time_vali.numpy(), y_vali.numpy(), label='y_vali', linewidth=1, color='r')
plt.xlabel('Time[s]')
plt.ylabel('X, Y')
plt.grid(True, linestyle='dotted')
plt.legend()
plt.xlim(0,20)

plt.show()

# --------- MLP definition ---------
class MLP(nn.Module):
    def __init__(self):
        super(MLP, self).__init__()
        self.fc1 = nn.Linear(sequence_length, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.fc3 = nn.Linear(hidden_size, output_size)
    
    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x

#------------ RNN definition -------------
class RNN(nn.Module):
    def __init__(self):
        super(RNN, self).__init__()
        self.rnn = nn.RNN(input_size=input_size, hidden_size=hidden_size, num_layers=num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size) 

    def forward(self, x):
        x, h = self.rnn(x, None)
        x = self.fc(x.squeeze(1)) 
        return x

#------------ LSTM definition -------------
class LSTM(nn.Module):
    def __init__(self):
        super().__init__()
        self.lstm = nn.LSTM(input_size=input_size, hidden_size=hidden_size, num_layers=num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)
    def forward(self, x):
        x, _ = self.lstm(x)
        x = self.fc(x)
        return x
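
# Note: x_train / x_vali have shape (N, sequence_length). A 2-D input to nn.RNN / nn.LSTM
# is treated as a single unbatched sequence of length N with sequence_length features per
# step, so the recurrent output has shape (N, hidden_size) and the final Linear maps it to (N, 1).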

# Create the model instance
input_size = sequence_length
hidden_size = 32
num_layers = 1 
output_size = 1

if model_type == 1:
    model = MLP()
elif model_type == 2:
    model = RNN()
else:
    model = LSTM()

criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.01)

# Show the model summary
summary(model=model)
#summary(model=model,input_size=(1,sequence_length))

# --------- Train the model ---------
num_epochs = 1000
for epoch in range(num_epochs):
    model.train()
    optimizer.zero_grad()
    outputs = model(x_train)
    loss = criterion(outputs, y_train)
    loss.backward()
    optimizer.step()
    
    if (epoch+1) % (num_epochs/10) == 0:
        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.10f}')

# --------- Predict on the training and validation data ---------
model.eval()
with torch.no_grad():
    y_train_pred = model(x_train)
    y_vali_pred = model(x_vali)

# --------- Plot the prediction results on the training data (time series) ---------
plt.figure(figsize=(10, 5))

plt.subplot(2, 1, 1)
plt.plot(time_train.numpy(), y_train.numpy(), label='y_train_Act', linewidth=2, color='b')
plt.plot(time_train.numpy(), y_train_pred.numpy(), label='y_train_Pred', linewidth=2, linestyle="--", color='r')
plt.title('Prediction result with train data')
plt.ylabel('Y')
plt.grid(True, linestyle='dotted')
plt.legend()

plt.subplot(2, 1, 2)
plt.plot(time_train.numpy(), y_train.numpy(), label='y_train_Act', linewidth=2, color='b')
plt.plot(time_train.numpy(), y_train_pred.numpy(), label='y_train_Pred', linewidth=2, linestyle="--", color='r')
plt.xlabel('Time[s]')
plt.ylabel('Y')
plt.grid(True, linestyle='dotted')
plt.legend()
plt.xlim(70,90)

plt.show()

# --------- Plot the prediction results on the validation data (time series) ---------
plt.figure(figsize=(10, 5))

plt.subplot(2, 1, 1)
plt.plot(time_vali.numpy(), y_vali.numpy(), label='y_vali', linewidth=2, color='b')
plt.plot(time_vali.numpy(), y_vali_pred.numpy(), label='y_vali_Pred', linewidth=2, color='r', linestyle="--")

plt.title('Prediction result with validation data')
plt.ylabel('Y')
plt.grid(True, linestyle='dotted')
plt.legend()

plt.subplot(2, 1, 2)
plt.plot(time_vali.numpy(), y_vali.numpy(), label='y_vali', linewidth=2, color='b')
plt.plot(time_vali.numpy(), y_vali_pred.numpy(), label='y_vali_Pred', linewidth=2, color='r', linestyle="--")
plt.xlabel('Time[s]')
plt.ylabel('Y')
plt.grid(True, linestyle='dotted')
plt.legend()
plt.xlim(0,25)

plt.show()

# --------- Plot the prediction results (scatter plot) ---------
plt.figure(figsize=(5, 5))
plt.scatter( y_train.numpy(), y_train_pred.numpy(), label='Train Data',s=10)
plt.scatter( y_vali.numpy(), y_vali_pred.numpy(), label='Vali Data',s=10)
plt.plot([0, 1], [0, 1], color='gray', label='1:1 line')
plt.title('Prediction result')
plt.xlabel('Y act')
plt.ylabel('Y pred')
plt.grid(True, linestyle='dotted')
plt.legend()

plt.show()
