0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

初心者向けPyTorch篇:2.線形回帰モデルの基本

Last updated at Posted at 2024-10-19

PyTorchを使った線形回帰モデルの基本

線形回帰は、データ間の関係をモデル化する最も基本的な手法の一つです。この記事では、PyTorchを使用して線形回帰モデルを構築する方法を初心者向けに解説します。

データの準備

まずは、以下のリンクから収入と年齢のデータセットをダウンロードします。

Income Dataset

このデータセットには、収入と年齢のほかに教育年数や性別の情報も含まれていますが、今回は収入と年齢の関係に焦点を当てます。

import pandas as pd

data = pd.read_csv('income.csv')
print(data.head())

出力例:

   ID  Income  Age  Education  Gender
0   1     113   69         12       1
1   2      91   52         18       0
2   3     121   65         14       0
3   4      81   58         12       0
4   5      68   31         16       1

データの可視化

データを可視化することで、収入と年齢の関係を目で見て確認します。

import matplotlib.pyplot as plt

plt.scatter(data['Age'], data['Income'], alpha=0.3)
plt.xlabel("Age")
plt.ylabel("Income")
plt.show()

1.png

モデルの構築

PyTorchを使い、線形回帰モデルを構築します。

import torch
from torch import nn

# データをTensorに変換
X = torch.tensor(data['Age'].values, dtype=torch.float32).reshape(-1, 1)
Y = torch.tensor(data['Income'].values, dtype=torch.float32).reshape(-1, 1)

# モデル定義
class EIModel(nn.Module):  # nn.Moduleを継承する
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(1, 1)

    def forward(self, x):
        return self.linear(x)  # 入力xを受け取り、線形変換を適用して出力を返す

model = EIModel()  # モデルインスタンスの作成

このモデルは、入力特徴量として1つの数値を受け取り、それに基づいて1つの数値(予測値)を出力します。これにより、例えば年齢から収入を予測するようなタスクに使用することができます。
※ nn.Linear(1, 1)の引数1, 1は、入力特徴量(年齢)と出力特徴量(収入)がそれぞれ1つであることを示しています。

学習プロセス

損失関数とオプティマイザーを定義し、モデルを訓練します。

loss_fn = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.001)

# 訓練ループ
for epoch in range(100):  # ここで100回訓練する
    for x, y in zip(X, Y):
        y_pred = model(x)  # モデルによる予測
        loss = loss_fn(y_pred, y)  # 損失の計算
        opt.zero_grad()  # 前回ループの勾配の初期化する
        loss.backward()  # 逆伝播
        opt.step()  # パラメータの更新

モデルの評価

最後に、学習したモデルを使って予測結果をプロットします。

plt.scatter(data['Age'], data['Income'], alpha=0.3)
plt.plot(X.numpy(), model(X).detach().numpy(), color='red')
plt.xlabel("Age")
plt.ylabel("Income")
plt.show()

2.png

このようにPyTorchを使って線形回帰モデルを構築することで、データ間の関係を簡単に理解し、予測することができます。

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?