はじめに
この記事ではpytorch-lightningの使い方を、いい感じのプログラムで紹介します
まず結論:ソースコードの書き方
pytorch_lightning_sample.py
import os
import pickle
import numpy as np
from PIL import Image # 画像を取り扱うために使用
import matplotlib.pyplot as plt # 画像のサンプル表示のために使用
import torch # pytorch本体
import torch.nn as nn # ニューラルネットを構成する際の基本的なモジュールが入っている
from torchvision import transforms as transforms # 画像前処理のために使用
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning import loggers as pl_loggers
# ==== データローダの作成など ============================================================================================
# データセットに対して、idxで指定された際に読み込み方法を指定するためのラッパークラス
class Dataset(torch.utils.data.Dataset):
def __init__(self, data, labels):
# 画像を変換して整形して保持
self.data = np.array(data) # numpy形式に変換
self.data = self.data.reshape(len(data), 3, 32, 32) # dataを整形
self.data = self.data.transpose(0, 2, 3, 1) # data[ミニバッチのindex][チャンネル][画像縦位置][画像横位置]と指定できるように順序交換
# ラベルを保持
self.labels = labels
# 画像を前処理するための関数たちを登録
self.transform = transforms.Compose(
[
transforms.ToTensor(),
transforms.Normalize(mean=(0.5,0.5,0.5), std=(0.5,0.5,0.5)),
]
)
# 指定されたindexのデータを辞書形式で返却するように設定する
def __getitem__(self, index):
img, label = self.data[index], self.labels[index] # 指定のデータを取得
img = Image.fromarray(img) # 画像に変換
img = self.transform(img) # transformをかける(tensor型に変換してから、正規化)
return {'inputs':img, 'targets':label} # 辞書形式で返却(辞書のkeyはニューラルネットのforwardの引数と同じ名前にする)
# データセットの個数を返すように設定する
def __len__(self):
return len(self.data)
# サンプル表示用の関数。一般には作らなくてOK。
def plot(self, index):
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
img, label = self.data[index], self.labels[index]
plt.imshow(img) # 指定されたindexの画像を描画
plt.title(f'label={classes[label]}') # titleを"label=クラス名"という形式で設定
plt.show() # 表示
class DataModule(pl.LightningDataModule):
def __init__(self, dataset_path, split_ratio=(0.7, 0.1, 0.2), batch_size=32, thread_num=4):
super().__init__()
self.dataset_path = dataset_path
self.split_ratio = split_ratio
self.batch_size = batch_size
self.thread_num = thread_num
def setup(self, stage=None):
# 成形されたデータセットの読み込み
print(f'loading {self.dataset_path}')
with open(self.dataset_path, 'rb') as f:
dataset_raw = pickle.load(f, encoding='bytes')
print('loading completed')
# データセット形式に変換
ds = Dataset(dataset_raw[b'data'], dataset_raw[b'labels'])
ds.plot(index=np.random.randint(0, len(ds))) # サンプルとして、ランダムに1件選んで描画
# データセットを指定された比率に合わせて分割
total_size = len(ds)
train_size = int(total_size * self.split_ratio[0]) # 学習で使用するデータ個数
valid_size = int(total_size * self.split_ratio[1]) # 検証で使用するデータ個数
test_size = int(total_size * self.split_ratio[2]) # テストで使用するデータ個数
self.train_dataset = torch.utils.data.dataset.Subset(ds, range(0, train_size)) # 指定された部分のデータのみを取り出す
self.valid_dataset = torch.utils.data.dataset.Subset(ds, range(train_size, train_size+valid_size)) # 指定された部分のデータのみを取り出す
self.test_dataset = torch.utils.data.dataset.Subset(ds, range(train_size+valid_size, total_size)) # 指定された部分のデータのみを取り出す
print(f'dataset size: total {total_size}, train {train_size}, validation {valid_size}, test {test_size}')
def train_dataloader(self):
if len(self.train_dataset) != 0:
return torch.utils.data.DataLoader(
self.train_dataset, # データセット
batch_size=self.batch_size, # イテレート時のバッチサイズ
shuffle=True, # イテレート前にデータをシャッフルするか
num_workers=self.thread_num, # イテレート時に使用するスレッド数
pin_memory=True, # メモリを固定して高速化をするか
drop_last=True, # 最後の端数部を落とすか
)
else:
raise Exception('length of dataset is zero.')
def val_dataloader(self):
if len(self.valid_dataset) != 0:
return torch.utils.data.DataLoader(
self.valid_dataset, # データセット
batch_size=self.batch_size, # イテレート時のバッチサイズ
shuffle=False, # イテレート前にデータをシャッフルするか
num_workers=self.thread_num, # イテレート時に使用するスレッド数
pin_memory=True, # メモリを固定して高速化をするか
drop_last=False, # 最後の端数部を落とすか
)
else:
raise Exception('length of dataset is zero.')
def test_dataloader(self):
if len(self.test_dataset) != 0:
return torch.utils.data.DataLoader(
self.test_dataset, # データセット
batch_size=self.batch_size, # イテレート時のバッチサイズ
shuffle=False, # イテレート前にデータをシャッフルするか
num_workers=self.thread_num, # イテレート時に使用するスレッド数
pin_memory=True, # メモリを固定して高速化をするか
drop_last=False, # 最後の端数部を落とすか
)
else:
raise Exception('length of dataset is zero.')
# ====================================================================================================================
class LeNet(nn.Module):
def __init__(self):
# ここにはニューラルネットの構成で使用するモジュールを一通り書き出す。
# Conv2dは畳み込み層、AvgPool2dは平均プーリング層、Flattenは数値を一列に並べてベクトル化する層、Linearは線形層(全結合層)を表している。使い方はググろう。
super(LeNet, self).__init__()
self.conv1 = nn.Conv2d(3, 6, kernel_size=5, padding=0, stride=1) # 32*32 3チャンネル入力 → 28*28 6チャンネル出力
self.pool1 = nn.AvgPool2d(kernel_size=2, stride=2) # 28*28 6チャンネル入力 → 14*14 6チャンネル出力
self.conv2 = nn.Conv2d(6, 16, kernel_size=5, padding=0, stride=1) # 14*14 6チャンネル入力 → 10*10 16チャンネル出力
self.pool2 = nn.AvgPool2d(kernel_size=2, stride=2) # 10*10 16チャンネル入力 → 5*5 16チャンネル出力
self.flatten = nn.Flatten() # 5*5 25チャンネル入力 → 400(=5*5*25)出力
self.fc1 = nn.Linear(400, 120) # 400入力, 120出力
self.fc2 = nn.Linear(120, 84) # 120入力, 84出力
self.fc3 = nn.Linear(84, 10) # 84入力, 10出力
self.softmax = nn.Softmax(dim=1) # ソフトマックス関数
self.loss = nn.CrossEntropyLoss() # クロスエントロピー損失
def forward(self, inputs, targets):
# ここには順伝搬のやり方を書く。
h = inputs
# 作用
h = torch.sigmoid(self.conv1(h))
h = self.pool1(h)
h = torch.sigmoid(self.conv2(h))
h = self.pool2(h)
h = self.flatten(h)
h = torch.sigmoid(self.fc1(h))
h = torch.sigmoid(self.fc2(h))
h = self.softmax(self.fc3(h))
# 損失を計算
loss = self.loss(h, targets)
# データ形式をdictに包んで出力
return {
"loss":loss, # 損失の値を記録
"hidden_states":h, # 最終層の値を記録
}
# ====================================================================================================================
class LitModule(pl.LightningModule):
# ネットワークモジュールなどの定義
def __init__(self, learning_rate):
super().__init__()
self.model = LeNet()
self.learning_rate = learning_rate
self.save_hyperparameters()
# 順伝搬の処理
def forward(self, **x):
return self.model(**x)
# オプティマイザの定義
def configure_optimizers(self):
optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1, end_factor=0.1, total_iters=self.trainer.max_epochs)
return [optimizer], [scheduler]
# ==================================================================
# 学習のバッチ実行処理
def training_step(self, batch, batch_index):
outputs = self.model(**batch)
loss = outputs['loss']
return {'loss': loss, 'correct': (torch.argmax(outputs["hidden_states"], dim=-1)==batch['targets']).to(torch.float32)}
# 学習の全バッチ終了時の処理
def training_epoch_end(self, outputs):
train_loss = torch.hstack([dict_['loss'] for dict_ in outputs]).mean()
train_accuracy = torch.hstack([dict_['correct'] for dict_ in outputs]).mean()
self.log_dict({"train_loss": train_loss, "train_accuracy": train_accuracy})
self.print({"train_loss": train_loss, "train_accuracy": train_accuracy})
# 検証のバッチ実行処理
def validation_step(self, batch, batch_index):
with torch.no_grad():
outputs = self.model(**batch)
loss = outputs['loss']
return {'loss': loss, 'correct': (torch.argmax(outputs["hidden_states"], dim=-1)==batch['targets']).to(torch.float32)}
# 検証の全バッチ終了時の処理
def validation_epoch_end(self, outputs):
val_loss = torch.hstack([dict_['loss'] for dict_ in outputs]).mean()
val_accuracy = torch.hstack([dict_['correct'] for dict_ in outputs]).mean()
self.log_dict({"val_loss": val_loss, "val_accuracy": val_accuracy})
self.print({"val_loss": val_loss, "val_accuracy": val_accuracy})
# テストのバッチ実行処理
def test_step(self, batch, batch_index):
with torch.no_grad():
outputs = self.model(**batch)
loss = outputs['loss']
return {'loss': loss, 'correct': (torch.argmax(outputs["hidden_states"], dim=-1)==batch['targets']).to(torch.float32)}
# テストの全バッチ終了時の処理
def test_epoch_end(self, outputs):
test_loss = torch.hstack([dict_['loss'] for dict_ in outputs]).mean()
test_accuracy = torch.hstack([dict_['correct'] for dict_ in outputs]).mean()
self.print({"test_loss": test_loss, "test_accuracy": test_accuracy})
# ==================================================================
# ====================================================================================================================
def main():
save_dir = './result'
data_module = DataModule(
dataset_path='./cifar-10-batches-py/data_batch_1',
split_ratio=(0.7, 0.1, 0.2),
batch_size=32,
)
model = LitModule(learning_rate=0.0001)
callbacks = [
ModelCheckpoint(
dirpath=save_dir,
filename='epoch{epoch:02d}-val_loss{val_loss:.2f}', # チェックポイントのファイル名の形式
monitor='val_loss', # 基準とする量
mode="min", # 最小となるところを探す
), # modelのチェックポイントを作成
# EarlyStopping(
# monitor="val_loss", # 基準とする量
# mode="min", # 最小となるところを探す
# ), # early-stoppingを利用
]
trainer = pl.Trainer(
max_epochs=10,
logger=[pl_loggers.TensorBoardLogger(save_dir=save_dir)],
callbacks=callbacks,
accelerator='gpu',
devices=[0], # 使用するGPUのIDのリスト
# auto_lr_find=True, # learning rateを自動で設定するか
# accumulate_grad_batches=1, # 勾配を累積して一度に更新することでバッチサイズを仮想的にN倍にする際のN
# gradient_clip_val=1, # 勾配クリッピングの値
# fast_dev_run=True, # デバッグ時にonにすると、1回だけtrain,validを実行する
# overfit_batches=1.0, # デバッグ時にonにすると、train = validで学習が進み、過学習できているかを確認できる
# deterministic=True, # 再現性のために乱数シードを固定するか
# resume_from_checkpoint='bbb/aaa.ckpt', # チェックポイントから再開する場合に利用
# precision=16, # 小数を何ビットで表現するか
# amp_backend="apex", # 少数の混合方式を使用するかどうか。nvidiaのapexがインストールされている必要あり。
# benchmark=True, # cudnn.benchmarkを使用して高速化するか(determisticがTrueの場合はFalseに上書きされる)
)
# trainer.tune(model, datamodule=data_module) # 「auto_lr_find=True」を指定した場合に実行する
trainer.fit(model, datamodule=data_module)
trainer.test(model, datamodule=data_module)
if __name__ == '__main__':
main()
終わりに
- とりあえずメモ書き程度で残しておきます
- 今後もう少し説明を書くかも