
[PyTorch] A Simple Regression Problem Using an MLP

Posted at 2024-08-03

What this article covers

A simple regression model using an MLP
Model: MLP (PyTorch)

Script

d01_Pytorch_MLP_Simple_Regression.ipynb
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
torch.manual_seed(0)

# Generate data
data_size = 5000
x = torch.randn(data_size, 1)  # values drawn from a standard normal distribution
y = 2 * x + 3 + torch.randn(data_size, 1) * 0  # noise is set to 0 for a sanity check; training also worked when noise was added

# Split the data into training and validation sets
train_size = int(0.8 * len(x))
x_train, x_vali = torch.split(x, [train_size, len(x) - train_size])
y_train_act, y_vali_act = torch.split(y, [train_size, len(y) - train_size])

# Define the MLP model
class MLP(nn.Module):
    def __init__(self):
        super(MLP, self).__init__()
        self.fc1 = nn.Linear(1, 64)
        self.fc2 = nn.Linear(64, 64)
        self.fc3 = nn.Linear(64, 1)
    
    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x

# Instantiate the model, loss function, and optimizer
model = MLP()
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.01)

# Train the model
num_epochs = 500
for epoch in range(num_epochs):
    model.train()
    optimizer.zero_grad()
    outputs = model(x_train)
    loss = criterion(outputs, y_train_act)
    loss.backward()
    optimizer.step()
    
    if (epoch+1) % 100 == 0:
        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')

# Predict on the training and validation data
model.eval()
with torch.no_grad():
    y_train_pred = model(x_train)
    y_vali_pred = model(x_vali)

# Plot predicted vs. actual values
plt.figure(figsize=(5, 5))
plt.scatter( y_train_act.numpy(), y_train_pred.numpy(), label='Train Data')
plt.scatter( y_vali_act.numpy(), y_vali_pred.numpy(), label='Vali Data')
plt.plot([-10, 10], [-10, 10], color='gray', label='1:1 line')
plt.title('Prediction result')
plt.xlabel('Y act')
plt.ylabel('Y pred')
plt.grid(True, linestyle='dotted')
plt.legend()

plt.show()
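
As a quick numerical check beyond the plot, the validation loss can be printed and a few predictions compared against the true function y = 2x + 3. This is not part of the original notebook; it is a minimal sketch that assumes the variables defined above (model, criterion, x_vali, y_vali_act) are still in scope.

# Validation MSE (sketch; assumes the script above has been run)
with torch.no_grad():
    vali_loss = criterion(model(x_vali), y_vali_act)
print(f'Validation MSE: {vali_loss.item():.6f}')

# Spot-check the learned mapping against the true function y = 2x + 3
x_check = torch.tensor([[-2.0], [0.0], [2.0]])
with torch.no_grad():
    print(model(x_check))   # model predictions
print(2 * x_check + 3)      # ground-truth values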

Script output

出力例.png (example output: scatter plot of predicted vs. actual Y for the training and validation data, with the 1:1 line for reference)
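
The comment in the data-generation step mentions that training also worked when noise was added. A minimal sketch of that variant follows; the noise scale 0.1 is an assumed value, not one given in the article, and only the target generation changes.

# Variant: add observation noise to the targets (noise_std = 0.1 is an assumed value)
noise_std = 0.1
y = 2 * x + 3 + torch.randn(data_size, 1) * noise_std
# The rest of the script (split, model definition, training, plotting) is unchanged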
