PyTorch -> PyTorchLightning
見たので、動画とdocumentationを見ながら、手元の image classification model を refactoring してみました。
PyTorchLightningを活用しきれていないです。
%pip install pytorch-lightning
import pytorch_lightning as pl
print(pl.__version__)
# '1.7.1'(2022.8)
Grasp the picture of my project
- make CONFIG class
- import libraries
- define
- make_directory, make_pathlist
- transform
- Dataset
- PyTorch_Lightning
- Trainer
1. make CONFIG class
import torch
import pprint
import datetime
dt_now = datetime.datetime.now()
class CFG:
setting = {"MODEL_NAME" : "efficientnet_b3", "model_library" : "torchvision", "pretrained" : True, "Version" : "trial"}
seed = 0
num_workers = 4
n_class = 0 # 後で上書き
resize = 224
batch_size = 32
epochs = 5
model_lr = 1e-3
weight_decay = 1e-4
beta1 = 0.9
beta2 = 0.999
mean = (0.485, 0.456, 0.406)
std = (0.229, 0.224, 0.225)
device = (torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"))
n_fold = 5
criterion = "nn.CrossEntropyLoss()"
optimizer = "optim.Adam(self.parameters(), lr=CFG.model_lr)"
scheduler = "lr_scheduler.CosineAnnealingLR(optimizer, T_max=CFG.epochs)"
print(dt_now)
print(f"MODEL_NAME : {CFG.setting['MODEL_NAME']}")
print(f"model_library : {CFG.setting['model_library']} / pretrained : {CFG.setting['pretrained']}")
print(f"Version : {CFG.setting['Version']}")
print(f"batch_size : {CFG.batch_size}")
print(f"optimizer : {CFG.optimizer}")
print(f"scheduler : {CFG.scheduler}")
2022-08-23 00:40:12.595750
MODEL_NAME : efficientnet_b3
model_library : torchvision / pretrained : True
Version : trial
batch_size : 32
optimizer : optim.Adam(self.parameters(), lr=CFG.model_lr)
scheduler : lr_scheduler.CosineAnnealingLR(optimizer, T_max=CFG.epochs)
optimizer とか scheduler とか変更する可能性があるものはここに文字列として入れてます。
optimizer = eval(CFG.optimizer)
CFGのクラス内変数は文字列ですが、eval関数を使って、参照させてます。
2. import libraries
今回使ってないものも入ってます...闇鍋状態...
import os
import glob
import sys
import pandas as pd
import numpy as np
import random
from PIL import Image
import matplotlib.pyplot as plt
import seaborn as sns
sns.set()
# sns.set_style("ticks")
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import torch.optim as optim
from torch.optim import lr_scheduler
from torchvision import datasets, models, transforms
import timm
import pytorch_lightning as pl
from pytorch_lightning import loggers as pl_loggers
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.callbacks import ModelCheckpoint, RichProgressBar
from pytorch_lightning.callbacks.progress.rich_progress import RichProgressBarTheme
from pytorch_lightning.callbacks.progress import ProgressBarBase
from sklearn.model_selection import train_test_split, StratifiedKFold
import torchmetrics
from torchmetrics import Accuracy, MetricCollection, Precision, Recall
import warnings
warnings.filterwarnings('ignore')
3. define
def torch_seed(seed=0):
random.seed(seed)
os.environ['PYTHONHASHSEED'] = str(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.backends.cudnn.deterministic = True
torch.use_deterministic_algorithms = True
def load_model():
if CFG.setting["model_library"] == "torchvision":
# torchvision
net = eval(f'models.{CFG.setting["MODEL_NAME"]}')(pretrained=CFG.setting["pretrained"])
# resnet
if "resne" in CFG.setting["MODEL_NAME"]:
fc_in_features = net.fc.in_features
net.fc = nn.Linear(fc_in_features, CFG.n_class)
# vgg
if "vgg" in CFG.setting["MODEL_NAME"]:
fc_in_features = net.classifier[6].in_features
net.classifier[6] = nn.Linear(fc_in_features, CFG.n_class)
# efficient
if "efficientnet_" in CFG.setting["MODEL_NAME"]:
fc_in_features = net.classifier[1].in_features
net.classifier[1] = nn.Linear(fc_in_features, CFG.n_class)
# vit
if "vit_" in CFG.setting["MODEL_NAME"]:
fc_in_features = net.heads.head.in_features
net.heads.head = nn.Linear(fc_in_features, CFG.n_class)
elif CFG.model_library == "timm":
net = timm.create_model(CFG.MODEL_NAME, pretrained = CFG.pretrained, num_classes = CFG.n_class)
return net
def visualize_process(train_loss_log, valid_loss_log, train_acc_log, valid_acc_log):
"""
train_loss_log, valid_loss_log, train_acc_log, valid_acc_log : list
"""
fig = plt.figure(figsize=(15,5))
ax = plt.subplot(1, 2, 1)
plt.plot(train_loss_log, label="train")
plt.plot(valid_loss_log, label="valid")
plt.title("Loss_log")
plt.xlabel("epoch")
plt.ylabel("Loss")
plt.legend()
ax = plt.subplot(1, 2, 2)
plt.plot(train_acc_log, label="train")
plt.plot(valid_acc_log, label="valid")
plt.title("Acc_log")
plt.xlabel("epoch")
plt.ylabel("accuracy")
plt.legend()
plt.show()
return fig
torch_seed : seed固定
load_model : CFGに基づいてmodelをloadします。
visualize_process : 訓練・検証ログのlistからグラフを描きます
4. make_directory , make_pathlist
rootpath = "/root/workspace/Pytorch_Lightning/data"
rootpath
├── class_1
│ ├── class_1の画像1.jpg
│ ├── class_1の画像2.jpg
│ └── class_1の画像n.jpg
├── class_2
│ ├── class_2の画像1.jpg
│ ├── class_2の画像2.jpg
│ └── class_2の画像n.jpg
├── class_3
│ ├── class_3の画像1.jpg
│ ├── class_3の画像2.jpg
│ └── class_3の画像n.jpg
└── class_...
├── class_...の画像1.jpg
├── class_...の画像2.jpg
└── class_...の画像n.jpg
ありがちなdierctoryを想定してます。
rootpath 以下のフォルダを数えて、分類クラスの数も確定させてます。
# rootpath
rootpath = "/root/workspace/Pytorch_Lightning/data"
# param_save_path
parampath = f"/root/workspace/Pytorch_Lightning/params/{CFG.setting['MODEL_NAME']}_{CFG.setting['Version']}"
os.makedirs(parampath, exist_ok=True)
print(f"rootpath : {rootpath}")
print(f"parampath : {parampath}")
# listdirの下に .DS_Store が勝手に生成されてることがあるので対策してます...
class_name_list = [class_name for class_name in os.listdir(rootpath) if class_name != ".DS_Store"]
CFG.n_class = len(class_name_list) # クラスの数
filepath = os.path.join(rootpath , "**/*.jpg")
path_list = [path for path in glob.glob(filepath)]
print(f"n_path : {len(path_list)}")
print(f"n_class : {CFG.n_class}")
rootpath : /root/workspace/Pytorch_Lightning/data
parampath : /root/workspace/Pytorch_Lightning/params/efficientnet_b3_trial
n_path : 144
n_class : 3
pathの指定は環境によって変更が必要です。
5. transform
class train_ImageTransform():
def __init__(self, resize, mean, std):
self.data_transform = transforms.Compose([
transforms.ToTensor(),
transforms.Resize((resize, resize)),
transforms.RandomRotation(degrees=(-10, 10)), # rotate
transforms.RandomHorizontalFlip(p=0.5), # Flip
transforms.Normalize(mean, std), # Normalize
transforms.RandomErasing(p=0.8, scale=(0.06, 0.20), ratio=(0.1, 4.0)), # mask
])
def __call__(self, img):
return self.data_transform(img)
class valid_ImageTransform():
def __init__(self, resize, mean, std):
self.data_transform = transforms.Compose([
transforms.ToTensor(),
transforms.Resize((resize, resize)),
transforms.Normalize(mean, std),
])
def __call__(self, img):
return self.data_transform(img)
6. Dataset
画像ファイルへの path と label の2列からなる DataFrame を作って、そこから Dataset を作ることにします。
# make_pathlist
filepath = os.path.join(rootpath , "**/*.jpg")
path_list = [path for path in glob.glob(filepath)]
class_name_list = [class_name for class_name in os.listdir(rootpath) if class_name != ".DS_Store"]
# make_class / class_name(str) -> label(int) の変換を行うdictを作ります
label_map = dict()
for i, class_name in enumerate(class_name_list):
label_map[class_name] = i
# 作ったdictでlabel_listにします。intでリストに格納されます
label_list = [label_map[path.split("/")[-2]] for path in path_list]
# DataFrame にする
df = pd.DataFrame({
"path" : path_list,
"label" : label_list
})
# print(df.shape) # (len(file_list),2)
この df から Dataset を作るクラスを作ります。
class ImageDataset(Dataset):
def __init__(self, df, transform=None):
super().__init__()
self.df = df
self.transform = transform
def __len__(self):
return len(self.df)
def __getitem__(self, index):
row = self.df.iloc[index]
img_path = row.path
img = self.transform(Image.open(img_path).convert("RGB"))
label = row.label
return img, label
動作確認
train_transform = train_ImageTransform(CFG.resize, CFG.mean, CFG.std)
train_dataset = ImageDataset(df, transform=train_transform)
print(len(train_dataset)) # n_pathlist
print(train_dataset[0][0].shape, train_dataset[0][1]) # torch.Size([3, 224, 224]) labelを表すint
この df は StratifiedKFold して使うことになります。
7. PyTorch_Lightning
datasetができたら PyTorchLightning を使います。
- DataLoaderの作成
- 訓練・検証ループの作成
の2つを作ります。
train_datasetとvalid_datasetを渡すとこのクラスででDataLoaderが作れます。
# make DataLoader for Pytorch_Lightning
class PlDataModule(pl.LightningDataModule):
def __init__(self, train_dataset, valid_dataset):
super().__init__()
self.train_dataset = train_dataset
self.valid_dataset = valid_dataset
def train_dataloader(self):
return DataLoader(self.train_dataset, batch_size=CFG.batch_size, shuffle=True, num_workers=CFG.num_workers)
def val_dataloader(self):
return DataLoader(self.valid_dataset, batch_size=CFG.batch_size, shuffle=False, num_workers=CFG.num_workers)
PyTorchLightningでLoop部分を作ります。
# Loop
class Net(pl.LightningModule):
# make_model
def __init__(self, cv):
super(Net, self).__init__()
# model
self.model = load_model()
# cv
self.cv = cv
# other
self.lr = CFG.model_lr
self.criterion = eval(CFG.criterion)
# metrics
self.classes = CFG.n_class
self.accuracy = torchmetrics.Accuracy()
self.F1score = torchmetrics.F1Score(num_classes=self.classes, average='macro')
# log_list
self.train_loss_log, self.train_acc_log = list(), list()
self.valid_loss_log, self.valid_acc_log = list(), list()
self.F1_log = list()
# forward
def forward(self, x):
x = self.model(x)
return x
# optimizer & scheduler
def configure_optimizers(self):
optimizer = eval(CFG.optimizer)
scheduler = eval(CFG.scheduler)
return [optimizer],[scheduler]
# 1.training_step
def training_step(self, batch, batch_idx):
# x:image y:label
imgs, labels = batch # imgs:[B, C, H, W], labels:[B]
logits = self.forward(imgs) # [B, class]
loss = self.criterion(logits, labels)
# display log
self.log("train_acc", self.accuracy(logits, labels),prog_bar=False,logger=True,
on_epoch=True, on_step=False)
print(".", end='')
return {"loss" : loss, "logits" : logits,
"labels" : labels, "batch_loss" : loss.item()*imgs.size(0)}
# returnのDict内の"loss"を基にモデルがアップデートされる
# 2.validation_step
def validation_step(self, batch, batch_idx):
imgs, labels = batch
logits = self.forward(imgs)
loss = self.criterion(logits, labels)
# display log
self.log("valid_acc", self.accuracy(logits, labels),prog_bar=False,logger=True,
on_epoch=True, on_step=False)
return {"valid_loss" : loss, "logits" : logits,
"labels" : labels, "batch_loss" : loss.item()*imgs.size(0)}
# 3.validation_epoch_end
def validation_epoch_end(self, val_step_outputs):
logits = torch.cat([o["logits"] for o in val_step_outputs], dim=0) # [n_train_dataset, class]
labels = torch.cat([o["labels"] for o in val_step_outputs], dim=0) # [n_train]
epoch_loss = sum([o["batch_loss"] for o in val_step_outputs]) / logits.size(0)
acc = self.accuracy(logits, labels)
F1 = self.F1score(logits, labels)
self.valid_loss_log.append(epoch_loss)
self.valid_acc_log.append(acc.item())
self.F1_log.append(F1.item())
# display log
self.log("valid_acc_epoch", acc, prog_bar=True,logger=True,
on_epoch=True, on_step=False)
# # save記述
if np.argmax(self.valid_acc_log) == self.current_epoch:
torch.save(self.model.state_dict(), os.path.join(parampath, f"best_acc_{self.cv}.pth"))
# 4.training_epoch_end
def training_epoch_end(self, train_step_outputs):
# リストの中に return の dict が入ってる
# print(train_step_outputs) # [ [training_step_outputs], [training_step_outputs],...,[training_step_outputs] ]
logits = torch.cat([o["logits"] for o in train_step_outputs], dim=0) # [n_train_dataset, class]
labels = torch.cat([o["labels"] for o in train_step_outputs], dim=0) # [n_train]
epoch_loss = sum([o["batch_loss"] for o in train_step_outputs]) / logits.size(0)
acc = self.accuracy(logits, labels)
self.train_loss_log.append(epoch_loss)
self.train_acc_log.append(acc.item())
# display log
self.log("train_acc_epoch", acc, prog_bar=True,logger=True,
on_epoch=True, on_step=False)
print()
print(f"epoch : {self.current_epoch} >>> ", end='')
print(f"train_loss : {self.train_loss_log[-1]:.4f} / train_acc : {self.train_acc_log[-1]:.4f} ", end='')
print(f"/ valid_loss : {self.valid_loss_log[-1]:.4f} / valid_acc : {self.valid_acc_log[-1]:.4f} / F1 : {self.F1_log[-1]:.4f}")
print()
-
__init__の部分でmodel, learning_rate, criterion, 評価指標の計算機とかを決めます。
-
forwardは [B, C, H, W]の画像のテンソルを渡した時の挙動で、モデルに突っ込むよということです。
-
configure_optimizers で optimizerとscheduler を設定
-
training_step -> validation_step -> validation_epoch_end -> training_epoch_end という順番で動いていきます。記述の順番に制約はありません。
-
訓練回数は self.current_epoch で取得できます。
-
self.logはlogが残せます。この辺りがまだあまり理解できていません。
progress_barを表示した時に載せるような設定にできます。
あとは tensor_board のような logger に送ることができたりという設定ができるようです。
あまり理解ができてない故、強引に表示してます。ここはもっとスマートにできると思われます...。 -
PyTorch Lightningでは net.train(),net.eval(),to(device),optimizer.zero_grad()とかもろもろの記述が不要、というか書かないでも中で自動でやってくれているようです。
-
CrossValidationで回したいので、cvという変数をそのまま追加しています。訓練中に検証時のacc最高のモデルを保存するようにしています。
-
validation_epoch_endの引数のval_step_outputs の部分は
printすれば分かりますが、リストの中に validation_step の return が dict で入ってます。
epoch_endの中ではこの dict を処理する記述をしてます。
import torch
d1 = {'loss': 1, 'labels': torch.tensor([0,1,2]), 'logits': torch.tensor([[1,1,1], [2,2,2], [3,3,3]])}
d2 = {'loss': 10, 'labels': torch.tensor([1,1,1]), 'logits': torch.tensor([[10,10,10], [20,20,20], [30,30,30]])}
d3 = {'loss': 100, 'labels': torch.tensor([2,2,0]), 'logits': torch.tensor([[0.1,0.1,0.1], [0.2,0.2,0.2], [0.3,0.3,0.3]])}
outputs = [d1, d2, d3]
print(outputs)
for o in outputs:
print("*****")
print(o)
print()
sum_ = sum([o["loss"] for o in outputs])
print(sum_)
cat_ = torch.cat([ o["logits"] for o in outputs ], dim=0)
print(cat_)
print(cat_.shape)
こんな感じのことをしてるはずです。
8. Trainer
PyTorchLightningのdocumentでは以下のような記述が例として記述されています。
model = MyLightningModule()
trainer = Trainer()
trainer.fit(model, train_dataloader, val_dataloader)
今回は MyLightningModule() は Net() と記述、dataloaderは PlDataModuleクラスで作成、cvの回数をそのまま記述したので、
data_module = PlDataModule(train_dataset, valid_dataset)
# Trainer
model = Net(cv=cv)
trainer = pl.Trainer(max_epochs=CFG.epochs, gpus=1, callbacks=progress, num_sanity_val_steps=0)
trainer.fit(model,datamodule=data_module)
のように書くことにします。
max_epochsは訓練回数。並列・分散しないのであれば、gpus=1。
callbacksの部分では progress_barを消す設定をしました。progress_barの制御がVSCodeで難しかったので...
# ProgressBar setting
class NotebookProgress(ProgressBarBase):
def __init__(self):
super().__init__()
self.enable = True
def disable(self):
self.enable = True
progress = NotebookProgress()
callbacksでは他にも様々な設定ができるようです。checkpointとかearlystoppingとか。
num_sanity_val_stepsの指定ですが、ここを指定しないと、training_step -> validation_step -> validation_epoch_end -> training_epoch_end のループに入る前に、訓練していない状態で validation が先に1回だけ回るという設定がデフォルトになっています。
こちらに設定の意図などとても分かりやすく書いてありました。
では最後に、StratifiedKFoldの枠組みで Trainer を回します。
# fix seed
pl.seed_everything(CFG.seed)
# ProgressBar setting
progress = NotebookProgress()
# transform
train_transform = train_ImageTransform(CFG.resize, CFG.mean, CFG.std)
valid_transform = valid_ImageTransform(CFG.resize, CFG.mean, CFG.std)
skf = StratifiedKFold(CFG.n_fold, shuffle=True, random_state=CFG.seed)
cv = 0
cv_log = {}
for train_index, valid_index in skf.split(df, df["label"]):
cv +=1
print(f"************** fold: {cv} ***************")
# df => train_df, valid_df
train_df, valid_df = df.iloc[train_index], df.iloc[valid_index]
# dataset
train_dataset = ImageDataset(train_df, transform=train_transform)
valid_dataset = ImageDataset(valid_df, transform=valid_transform)
# data_module
data_module = PlDataModule(train_dataset, valid_dataset)
model = Net(cv=cv)
trainer = pl.Trainer(max_epochs=CFG.epochs, gpus=1, callbacks=progress, num_sanity_val_steps=0)
trainer.fit(model,datamodule=data_module)
# plot
fig = visualize_process(model.train_loss_log, model.valid_loss_log, model.train_acc_log, model.valid_acc_log)
print(f"***** fold:{cv}:best_acc *****")
print(max(model.valid_acc_log))
print()
# cvごとのlog
# cv_logはmapになっていて、cvの数字によって log を保存
cv_log[f"cv{cv}_train_loss_log"] = model.train_loss_log
cv_log[f"cv{cv}_train_acc_log"] = model.train_acc_log
cv_log[f"cv{cv}_valid_loss_log"] = model.valid_loss_log
cv_log[f"cv{cv}_valid_acc_log"] = model.valid_acc_log
cv_log[f"cv{cv}_valid_F1_log"] = model.F1_log
# 訓練終了後に一番良い値を表示する
for i in range(CFG.n_fold):
print(f"********** cv{i+1} **********")
best_acc = max(cv_log[f"cv{i+1}_valid_acc_log"])
best_acc_index = np.argmax(cv_log[f"cv{i+1}_valid_acc_log"])
print("best_acc_index: {}, best_acc: {:.04f}".format(best_acc_index, best_acc))
best_loss = min(cv_log[f"cv{i+1}_valid_loss_log"])
best_loss_index = np.argmin(cv_log[f"cv{i+1}_valid_loss_log"])
print("best_loss_index: {}, best_loss: {:.04f}".format(best_loss_index, best_loss))
best_F1 = max(cv_log[f"cv{i+1}_valid_F1_log"])
best_F1_index = np.argmax(cv_log[f"cv{i+1}_valid_F1_log"])
print("best_F1_index: {}, best_F1: {:.04f}".format(best_F1_index, best_F1))
poorな部分
- 今回、訓練結果を強引に表示していますが、これはスマートなやり方がありそうです。
- CrossValidation 自体もこのように for 文で記述しなくても、どこかで定義できる方法があると思われます。
またdocumentを見て改善します...
参考