0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

Pytorch Lightning基礎実装ガイド

Last updated at Posted at 2024-09-23

PyTorch Lightning の基礎的な実装ガイド

1. はじめに

PyTorch Lightningは、PyTorchのコードをよりシンプルかつ整理された形で書くためのフレームワークです。特に深層学習モデルの訓練において、訓練ループやロギング、最適化などを自動化し、コードの可読性やメンテナンス性を向上させます。本記事では、Irisデータセットを使い、PyTorch Lightningを用いた基本的なネットワークの実装と学習の流れについて解説します。

2. ライブラリのインストール

まずは、必要なライブラリをインストールします。

!pip install pytorch_lightning torchmetrics

次に、必要なモジュールをインポートします。

import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.datasets import load_iris
import torchmetrics.functional as FM

3. データセットの準備

Irisデータセットを使用し、PyTorchのTensorDataset形式に変換します。データは訓練、検証、テストに分割されます。

# Irisデータセットの読み込み
iris = load_iris()
x = iris['data']
t = iris['target']

# Tensorに変換
x = torch.tensor(x, dtype=torch.float32)
t = torch.tensor(t, dtype=torch.int64)

# TensorDatasetの作成
dataset = torch.utils.data.TensorDataset(x, t)

# データセットの分割
n_train = int(len(dataset) * 0.6)
n_val = int(len(dataset) * 0.2)
n_test = len(dataset) - n_train - n_val
torch.manual_seed(0)  # シード固定
train, val, test = torch.utils.data.random_split(dataset, [n_train, n_val, n_test])

# DataLoaderの定義
batch_size = 10
train_loader = torch.utils.data.DataLoader(train, batch_size, shuffle=True)
val_loader = torch.utils.data.DataLoader(val, batch_size)
test_loader = torch.utils.data.DataLoader(test, batch_size)

4. モデルの定義

PyTorch LightningのLightningModuleクラスを継承し、ニューラルネットワークとその学習過程を定義します。

class Net(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(4, 4)
        self.fc2 = nn.Linear(4, 3)

    def forward(self, x):
        h = self.fc1(x)
        h = F.relu(h)
        h = self.fc2(h)
        return h

    def training_step(self, batch, batch_idx):
        x, t = batch
        y = self(x)
        loss = F.cross_entropy(y, t)
        # ログの追加
        self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True)
        self.log('train_acc', FM.accuracy(y, t, task='multiclass', num_classes=3), on_step=True, on_epoch=True, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        x, t = batch
        y = self(x)
        loss = F.cross_entropy(y, t)
        # ログの追加
        self.log('val_loss', loss, on_step=True, on_epoch=True, prog_bar=True)
        self.log('val_acc', FM.accuracy(y, t, task='multiclass', num_classes=3), on_step=True, on_epoch=True, prog_bar=True)
        return loss

    def test_step(self, batch, batch_idx):
        x, t = batch
        y = self(x)
        loss = F.cross_entropy(y, t)
        # ログの追加
        self.log('test_loss', loss, on_step=True, on_epoch=True, prog_bar=True)
        self.log('test_acc', FM.accuracy(y, t, task='multiclass', num_classes=3), on_step=True, on_epoch=True, prog_bar=True)
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.SGD(self.parameters(), lr=0.01)
        return optimizer

5. モデルの訓練

Trainerクラスを使って、学習の進行を管理します。

pl.seed_everything(0)  # シードの固定
net = Net()

trainer = pl.Trainer(max_epochs=30, accelerator='gpu', devices=1, deterministic=True)
trainer.fit(net, train_loader, val_loader)

6. テストの実行

訓練が完了したら、テストデータでモデルの性能を評価します。

results = trainer.test(dataloaders=test_loader)
print(results)

7. 結果の確認

テスト結果は、resultsに含まれています。損失や正解率を確認できます。

# テスト結果の表示
trainer.callback_metrics

8. まとめ

PyTorch Lightningを使うことで、モデルの訓練プロセスをシンプルにし、再現性のある結果を得ることができます。本記事では、Irisデータセットを使った基礎的なネットワークの実装例を紹介しました。PyTorch Lightningの強力な機能により、学習ループや評価指標の計算が自動化され、研究開発がより効率的になります。

9. イメージ図

image.png
image.png

10. 参考資料

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?