PyTorch Lightning の基礎的な実装ガイド
1. はじめに
PyTorch Lightningは、PyTorchのコードをよりシンプルかつ整理された形で書くためのフレームワークです。特に深層学習モデルの訓練において、訓練ループやロギング、最適化などを自動化し、コードの可読性やメンテナンス性を向上させます。本記事では、Irisデータセットを使い、PyTorch Lightningを用いた基本的なネットワークの実装と学習の流れについて解説します。
2. ライブラリのインストール
まずは、必要なライブラリをインストールします。
!pip install pytorch_lightning torchmetrics
次に、必要なモジュールをインポートします。
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.datasets import load_iris
import torchmetrics.functional as FM
3. データセットの準備
Irisデータセットを使用し、PyTorchのTensorDataset
形式に変換します。データは訓練、検証、テストに分割されます。
# Irisデータセットの読み込み
iris = load_iris()
x = iris['data']
t = iris['target']
# Tensorに変換
x = torch.tensor(x, dtype=torch.float32)
t = torch.tensor(t, dtype=torch.int64)
# TensorDatasetの作成
dataset = torch.utils.data.TensorDataset(x, t)
# データセットの分割
n_train = int(len(dataset) * 0.6)
n_val = int(len(dataset) * 0.2)
n_test = len(dataset) - n_train - n_val
torch.manual_seed(0) # シード固定
train, val, test = torch.utils.data.random_split(dataset, [n_train, n_val, n_test])
# DataLoaderの定義
batch_size = 10
train_loader = torch.utils.data.DataLoader(train, batch_size, shuffle=True)
val_loader = torch.utils.data.DataLoader(val, batch_size)
test_loader = torch.utils.data.DataLoader(test, batch_size)
4. モデルの定義
PyTorch LightningのLightningModule
クラスを継承し、ニューラルネットワークとその学習過程を定義します。
class Net(pl.LightningModule):
def __init__(self):
super().__init__()
self.fc1 = nn.Linear(4, 4)
self.fc2 = nn.Linear(4, 3)
def forward(self, x):
h = self.fc1(x)
h = F.relu(h)
h = self.fc2(h)
return h
def training_step(self, batch, batch_idx):
x, t = batch
y = self(x)
loss = F.cross_entropy(y, t)
# ログの追加
self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True)
self.log('train_acc', FM.accuracy(y, t, task='multiclass', num_classes=3), on_step=True, on_epoch=True, prog_bar=True)
return loss
def validation_step(self, batch, batch_idx):
x, t = batch
y = self(x)
loss = F.cross_entropy(y, t)
# ログの追加
self.log('val_loss', loss, on_step=True, on_epoch=True, prog_bar=True)
self.log('val_acc', FM.accuracy(y, t, task='multiclass', num_classes=3), on_step=True, on_epoch=True, prog_bar=True)
return loss
def test_step(self, batch, batch_idx):
x, t = batch
y = self(x)
loss = F.cross_entropy(y, t)
# ログの追加
self.log('test_loss', loss, on_step=True, on_epoch=True, prog_bar=True)
self.log('test_acc', FM.accuracy(y, t, task='multiclass', num_classes=3), on_step=True, on_epoch=True, prog_bar=True)
return loss
def configure_optimizers(self):
optimizer = torch.optim.SGD(self.parameters(), lr=0.01)
return optimizer
5. モデルの訓練
Trainer
クラスを使って、学習の進行を管理します。
pl.seed_everything(0) # シードの固定
net = Net()
trainer = pl.Trainer(max_epochs=30, accelerator='gpu', devices=1, deterministic=True)
trainer.fit(net, train_loader, val_loader)
6. テストの実行
訓練が完了したら、テストデータでモデルの性能を評価します。
results = trainer.test(dataloaders=test_loader)
print(results)
7. 結果の確認
テスト結果は、results
に含まれています。損失や正解率を確認できます。
# テスト結果の表示
trainer.callback_metrics
8. まとめ
PyTorch Lightningを使うことで、モデルの訓練プロセスをシンプルにし、再現性のある結果を得ることができます。本記事では、Irisデータセットを使った基礎的なネットワークの実装例を紹介しました。PyTorch Lightningの強力な機能により、学習ループや評価指標の計算が自動化され、研究開発がより効率的になります。