0
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

[PyTorch] MLPを用いた時系列回帰

Last updated at Posted at 2024-08-17

記事の内容

■ MLPを用いた時系列回帰
■ モデル:MLP ( PyTorch )
■ 学習データおよび検証データ
・伝達関数モデルを用いて学習用データおよび検証用データを作成
・学習用データ作成時の入力信号はChirp信号、検証用データ作成時の入力信号はStep信号

スクリプト

d08_Pytorch_MLP_Time_Sequence.ipynb
from control.matlab import*
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, TensorDataset
torch.manual_seed(0);

#伝達関数モデル(学習データ作成用)の定義・・・システムが複雑になると予測精度が落ちる
s = tf('s')
P = 1/(0.5*s+1) #1次遅れ系
#P = 1/(1*s**2+2*s+1) #2次遅れ系

#シーケンス長・・・予測精度に大きな影響がある
#sequence_length = 2
#sequence_length = 16
sequence_length = 32
#sequence_length = 64

#学習用データと検証用データを作成するための入力信号の種類・・・予測精度に大きな影響がある
#1:Step信号、2:Chirp信号、3:白色ノイズ
train_data_type = 2
vali_data_type = 1

# -------入力信号の生成------------
duration = 100  # 信号の継続時間 (秒)
sampling_rate = 10  # サンプリングレート (Hz)

# 時間軸の作成
time = np.linspace(0, duration, int(sampling_rate * duration), endpoint=False)
print('time.shape', time.shape)

#-------------入力信号の生成---------
step_signal_type = 2
def if_elif(step_signal_type):
    if step_signal_type ==1: #---Step信号---
        step_time = duration/20 #ステップするタイミングをdurationの1/20毎とする
        u = np.zeros_like(time)
        for i in range(len(time)):
            if (i // int(step_time * sampling_rate)) % 2 == 1:
                u[i] = 1
        return u
    elif step_signal_type ==2: #---Chirp信号---
        frequency_start = 0.01  # チャープ信号の開始周波数 (Hz)
        frequency_end = 1 # チャープ信号の終了周波数 (Hz)
        u = np.sin(2 * np.pi * np.cumsum(np.linspace(frequency_start, frequency_end, len(time)) / sampling_rate)) / 2 + 0.5
        return u
    else: #---白色ノイズ---
        A = 1.0    # 振幅
        u = 2*A*(np.random.rand(round(sampling_rate*duration))-0.5)
        return u
        
u_train = if_elif(train_data_type) 
u_vali =  if_elif(vali_data_type) 

#-----------lsimで出力(応答)計算---------
def calc_y(P, u, t):
    x0 = 0 #初期値
    y, time_tmp, x_tmp = lsim(P, U=u, T=t, X0=x0)
    return y
    
y_act_train = calc_y(P,u_train,time)
y_act_vali = calc_y(P,u_vali,time)

#---------学習用データのグラフ表示---------
plt.figure(figsize=(10, 5))

#全体表示
plt.subplot(2, 1, 1)
plt.axhline(0, color="gray", linestyle="--", linewidth=1)
plt.plot(time, u_train, label="Input", color='b', linewidth=1)
plt.plot(time, y_act_train, linestyle="--", label="Output", color='r', linewidth=1)
#plt.xlabel("time (s)")
plt.ylabel("Signal [-]")
plt.title("Train data check")
plt.legend()
plt.grid(True)

#拡大表示
plt.subplot(2, 1, 2)
plt.axhline(0, color="gray", linestyle="--", linewidth=1)
plt.plot(time, u_train, label="Input", color='b', linewidth=1)
plt.plot(time, y_act_train, label="Output", color='r', linewidth=1)
plt.xlim(80,100)
plt.xlabel("time (s)")
plt.ylabel("Signal [-]")
#plt.title("Train data")
plt.legend()
plt.grid(True)

#---------検証用データのグラフ表示---------
plt.figure(figsize=(10, 5))

#全体表示
plt.subplot(2, 1, 1)
plt.axhline(0, color="gray", linestyle="--", linewidth=1)
plt.plot(time, u_vali, label="Input", color='b', linewidth=1)
plt.plot(time, y_act_vali, label="Output", color='r', linewidth=1)
plt.ylabel("Signal [-]")
plt.title("Validation data check")
plt.legend()
plt.grid(True)

#拡大表示
plt.subplot(2, 1, 2)
plt.axhline(0, color="gray", linestyle="--", linewidth=1)
plt.plot(time, u_vali, label="Input", color='b', linewidth=1)
plt.plot(time, y_act_vali, label="Output", color='r', linewidth=1)
plt.xlim(0,25)
plt.xlabel("time (s)")
plt.ylabel("Signal [-]")
#plt.title("Train data")
plt.legend()
plt.grid(True)

#---------時系列モデルの学習用および検証用データを作成(入力をシーケンス長毎のベクトルとする)---------
def create_sequences(x_org, y_org, seq_length):
    xs, ys = [], []
    for i in range(len(y_org) - seq_length + 1):
        x = x_org[i:i+seq_length]
        y = y_org[i+seq_length-1]
        xs.append(x)
        ys.append(y)
    return np.array(xs), np.array(ys)

# 訓練データ
x_train, y_train = create_sequences(u_train, y_act_train,sequence_length)
y_train = y_train.reshape(-1, 1)
time_train = time[sequence_length-1:]

#検証データ
x_vali, y_vali = create_sequences(u_vali, y_act_vali,sequence_length)
y_vali = y_vali.reshape(-1, 1)
time_vali = time[sequence_length-1:]

# ---------Tensor型へ変換する---------
x_train = torch.tensor(x_train, dtype=torch.float32)
y_train =  torch.tensor(y_train, dtype=torch.float32)
time_train = torch.tensor(time_train, dtype=torch.float32)
x_vali = torch.tensor(x_vali, dtype=torch.float32)
y_vali =  torch.tensor(y_vali, dtype=torch.float32)
time_vali = torch.tensor(time_vali, dtype=torch.float32)

# ---------学習用シーケンスデータの確認---------
plt.figure(figsize=(10, 5))

plt.subplot(2, 1, 1)
plt.plot(time_train.numpy(), x_train.numpy()[:,0], label='x_train[:,0]', linewidth=1, color='cyan', linestyle="--")
plt.plot(time_train.numpy(), x_train.numpy()[:,sequence_length-1], label='x_train[:,sequence_length-1]', linewidth=1, color='b')
plt.plot(time_train.numpy(), y_train.numpy(), label='y_train', linewidth=1, color='r')
plt.title('Train data check')
plt.ylabel('X, Y')
plt.grid(True, linestyle='dotted')
plt.legend()

plt.subplot(2, 1, 2)
plt.plot(time_train.numpy(), x_train.numpy()[:,0], label='x_train[:,0]', linewidth=1, color='cyan', linestyle="--")
plt.plot(time_train.numpy(), x_train.numpy()[:,sequence_length-1], label='x_train[:,sequence_length-1]', linewidth=1, color='b')
plt.plot(time_train.numpy(), y_train.numpy(), label='y_train', linewidth=1, color='r')
plt.xlabel('Time[s]')
plt.ylabel('X, Y')
plt.grid(True, linestyle='dotted')
plt.legend()
plt.xlim(0,20)

plt.show()

# ---------検証用シーケンスデータの確認---------
plt.figure(figsize=(10, 5))

plt.subplot(2, 1, 1)
plt.plot(time_vali.numpy(), x_vali.numpy()[:,0], label='x_vali[:,0]', linewidth=1, color='cyan', linestyle="--")
plt.plot(time_vali.numpy(), x_vali.numpy()[:,sequence_length-1], label='x_vali[:,sequence_length-1]', linewidth=1, color='b')
plt.plot(time_vali.numpy(), y_vali.numpy(), label='y_vali', linewidth=1, color='r')
plt.title('Validation data check')
plt.ylabel('X, Y')
plt.grid(True, linestyle='dotted')
plt.legend()

plt.subplot(2, 1, 2)
plt.plot(time_vali.numpy(), x_vali.numpy()[:,0], label='x_vali[:,0]', linewidth=1, color='cyan', linestyle="--")
plt.plot(time_vali.numpy(), x_vali.numpy()[:,sequence_length-1], label='x_vali[:,sequence_length-1]', linewidth=1, color='b')
plt.plot(time_vali.numpy(), y_vali.numpy(), label='y_vali', linewidth=1, color='r')
plt.xlabel('Time[s]')
plt.ylabel('X, Y')
plt.grid(True, linestyle='dotted')
plt.legend()
plt.xlim(0,20)

plt.show()

# ---------MLPモデルの定義---------
hidden_units = 32
class MLP(nn.Module):
    def __init__(self):
        super(MLP, self).__init__()
        self.fc1 = nn.Linear(sequence_length, hidden_units)
        self.fc2 = nn.Linear(hidden_units, hidden_units)
        self.fc3 = nn.Linear(hidden_units, 1)
    
    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x
        
# ---------モデルのインスタンス作成---------
model = MLP()
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.01)

#---------モデルのサマリー表示---------
from torchinfo import summary
#summary(model=model)
summary(model=model,input_size=(1,sequence_length))

# ---------モデルの学習---------
num_epochs = 1000
for epoch in range(num_epochs):
    model.train()
    optimizer.zero_grad()
    outputs = model(x_train)
    loss = criterion(outputs, y_train)
    loss.backward()
    optimizer.step()
    
    if (epoch+1) % (num_epochs/10) == 0:
        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.10f}')


# ---------学習用データと検証用データの予測---------
model.eval()
with torch.no_grad():
    y_train_pred = model(x_train)
    y_vali_pred = model(x_vali)

# ---------学習データにおける予測結果のプロット(時系列)---------
plt.figure(figsize=(10, 5))

plt.subplot(2, 1, 1)
plt.plot(time_train.numpy(), y_train.numpy(), label='y_train_Act', linewidth=2, color='b')
plt.plot(time_train.numpy(), y_train_pred.numpy(), label='y_train_Pred', linewidth=2, linestyle="--", color='r')
plt.title('Prediction result with train data')
plt.ylabel('Y')
plt.grid(True, linestyle='dotted')
plt.legend()

plt.subplot(2, 1, 2)
plt.plot(time_train.numpy(), y_train.numpy(), label='y_train_Act', linewidth=2, color='b')
plt.plot(time_train.numpy(), y_train_pred.numpy(), label='y_train_Pred', linewidth=2, linestyle="--", color='r')
plt.xlabel('Time[s]')
plt.ylabel('Y')
plt.grid(True, linestyle='dotted')
plt.legend()
plt.xlim(70,90)

plt.show()

# ---------検証データにおける予測結果のプロット(時系列)---------
plt.figure(figsize=(10, 5))

plt.subplot(2, 1, 1)
plt.plot(time_vali.numpy(), y_vali.numpy(), label='y_vali', linewidth=2, color='b')
plt.plot(time_vali.numpy(), y_vali_pred.numpy(), label='y_vali_Pred', linewidth=2, color='r', linestyle="--")

plt.title('Prediction result with validation data')
plt.ylabel('Y')
plt.grid(True, linestyle='dotted')
plt.legend()

plt.subplot(2, 1, 2)
plt.plot(time_vali.numpy(), y_vali.numpy(), label='y_vali', linewidth=2, color='b')
plt.plot(time_vali.numpy(), y_vali_pred.numpy(), label='y_vali_Pred', linewidth=2, color='r', linestyle="--")
plt.xlabel('Time[s]')
plt.ylabel('Y')
plt.grid(True, linestyle='dotted')
plt.legend()
plt.xlim(0,25)
plt.show()

# ---------予測結果のプロット(散布図)---------
plt.figure(figsize=(5, 5))
plt.scatter( y_train.numpy(), y_train_pred.numpy(), label='Train Data',s=10)
plt.scatter( y_vali.numpy(), y_vali_pred.numpy(), label='Vali Data',s=10)
plt.plot([0, 1], [0, 1], color='gray', label='1:1 line')
plt.title('Prediction result')
plt.xlabel('Y act')
plt.ylabel('Y pred')
plt.grid(True, linestyle='dotted')
plt.legend()
plt.show()

スクリプト実行結果

■ 学習用データに対する予測結果
学習用データに対する予測性能は問題なし。
学習用データ.png

■ 検証用データに対する予測結果
検証用データに対する予測性能も概ね問題なしだが、低次数のシステムでは発生しないノイズが見られる
検証用データ.png

考察

学習自体は問題なくできていそう。
Step信号で作成した検証用データでは低次数システムではありえないノイズが発生している。RNNやLSTMでも同様のノイズが発生すると推定されるが確認する必要がある。CNNやNeural ODEでも同様のノイズが発生するかが今後の焦点。

0
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?