0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

PyTorchLightning my poor Refactoring

Last updated at Posted at 2022-08-22

PyTorch -> PyTorchLightning

見たので、動画とdocumentationを見ながら、手元の image classification model を refactoring してみました。
PyTorchLightningを活用しきれていないです。

%pip install pytorch-lightning
import pytorch_lightning as pl
print(pl.__version__)
# '1.7.1'(2022.8)

Grasp the picture of my project

  1. make CONFIG class
  2. import libraries
  3. define
  4. make_directory, make_pathlist
  5. transform
  6. Dataset
  7. PyTorch_Lightning
  8. Trainer

1. make CONFIG class

import torch
import pprint
import datetime
dt_now = datetime.datetime.now()

class CFG:
    setting = {"MODEL_NAME" : "efficientnet_b3", "model_library" : "torchvision", "pretrained" : True, "Version" : "trial"}
    seed = 0
    num_workers = 4
    n_class = 0 # 後で上書き
    resize = 224
    batch_size = 32
    epochs = 5
    model_lr = 1e-3
    weight_decay = 1e-4
    beta1 = 0.9
    beta2 = 0.999
    mean = (0.485, 0.456, 0.406)
    std = (0.229, 0.224, 0.225)
    device = (torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"))
    n_fold = 5
    criterion = "nn.CrossEntropyLoss()"
    optimizer = "optim.Adam(self.parameters(), lr=CFG.model_lr)"
    scheduler = "lr_scheduler.CosineAnnealingLR(optimizer, T_max=CFG.epochs)"

print(dt_now)
print(f"MODEL_NAME : {CFG.setting['MODEL_NAME']}")
print(f"model_library : {CFG.setting['model_library']} / pretrained : {CFG.setting['pretrained']}")
print(f"Version : {CFG.setting['Version']}")
print(f"batch_size : {CFG.batch_size}")
print(f"optimizer : {CFG.optimizer}")
print(f"scheduler : {CFG.scheduler}")

2022-08-23 00:40:12.595750
MODEL_NAME : efficientnet_b3
model_library : torchvision / pretrained : True
Version : trial
batch_size : 32
optimizer : optim.Adam(self.parameters(), lr=CFG.model_lr)
scheduler : lr_scheduler.CosineAnnealingLR(optimizer, T_max=CFG.epochs)

optimizer とか scheduler とか変更する可能性があるものはここに文字列として入れてます。

optimizer = eval(CFG.optimizer)

CFGのクラス内変数は文字列ですが、eval関数を使って、参照させてます。

2. import libraries

今回使ってないものも入ってます...闇鍋状態...

import os
import glob
import sys
import pandas as pd
import numpy as np
import random

from PIL import Image
import matplotlib.pyplot as plt
import seaborn as sns
sns.set()
# sns.set_style("ticks")

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import torch.optim as optim
from torch.optim import lr_scheduler

from torchvision import datasets, models, transforms
import timm

import pytorch_lightning as pl
from pytorch_lightning import loggers as pl_loggers
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.callbacks import ModelCheckpoint, RichProgressBar
from pytorch_lightning.callbacks.progress.rich_progress import RichProgressBarTheme
from pytorch_lightning.callbacks.progress import ProgressBarBase

from sklearn.model_selection import train_test_split, StratifiedKFold

import torchmetrics
from torchmetrics import Accuracy, MetricCollection, Precision, Recall

import warnings
warnings.filterwarnings('ignore')

3. define

def torch_seed(seed=0):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.use_deterministic_algorithms = True

def load_model():
    if CFG.setting["model_library"] == "torchvision":
    # torchvision
        net = eval(f'models.{CFG.setting["MODEL_NAME"]}')(pretrained=CFG.setting["pretrained"])
        
        # resnet
        if "resne" in CFG.setting["MODEL_NAME"]:
            fc_in_features = net.fc.in_features
            net.fc = nn.Linear(fc_in_features, CFG.n_class)
        # vgg
        if "vgg" in CFG.setting["MODEL_NAME"]:
            fc_in_features = net.classifier[6].in_features
            net.classifier[6] = nn.Linear(fc_in_features, CFG.n_class)
        # efficient
        if "efficientnet_" in CFG.setting["MODEL_NAME"]:
            fc_in_features = net.classifier[1].in_features
            net.classifier[1] = nn.Linear(fc_in_features, CFG.n_class)
        # vit
        if "vit_" in CFG.setting["MODEL_NAME"]:
            fc_in_features = net.heads.head.in_features
            net.heads.head = nn.Linear(fc_in_features, CFG.n_class)


    elif CFG.model_library == "timm":
        net = timm.create_model(CFG.MODEL_NAME, pretrained = CFG.pretrained, num_classes = CFG.n_class)

    return net

def visualize_process(train_loss_log, valid_loss_log, train_acc_log, valid_acc_log):
    """
    train_loss_log, valid_loss_log, train_acc_log, valid_acc_log : list
    """
    fig = plt.figure(figsize=(15,5))
    ax = plt.subplot(1, 2, 1)
    plt.plot(train_loss_log, label="train")
    plt.plot(valid_loss_log, label="valid")
    plt.title("Loss_log")
    plt.xlabel("epoch")
    plt.ylabel("Loss")
    plt.legend()

    ax = plt.subplot(1, 2, 2)
    plt.plot(train_acc_log, label="train")
    plt.plot(valid_acc_log, label="valid")
    plt.title("Acc_log")
    plt.xlabel("epoch")
    plt.ylabel("accuracy")
    plt.legend()

    plt.show()
    return fig

torch_seed : seed固定
load_model : CFGに基づいてmodelをloadします。
visualize_process : 訓練・検証ログのlistからグラフを描きます

4. make_directory , make_pathlist

rootpath = "/root/workspace/Pytorch_Lightning/data"
rootpath
    ├── class_1
    │     ├── class_1の画像1.jpg
    │     ├── class_1の画像2.jpg
    │     └── class_1の画像n.jpg
    ├── class_2
    │     ├── class_2の画像1.jpg
    │     ├── class_2の画像2.jpg
    │     └── class_2の画像n.jpg
    ├── class_3
    │     ├── class_3の画像1.jpg
    │     ├── class_3の画像2.jpg
    │     └── class_3の画像n.jpg
    └── class_...
          ├── class_...の画像1.jpg
          ├── class_...の画像2.jpg
          └── class_...の画像n.jpg

ありがちなdierctoryを想定してます。
rootpath 以下のフォルダを数えて、分類クラスの数も確定させてます。

# rootpath
rootpath = "/root/workspace/Pytorch_Lightning/data"

# param_save_path
parampath = f"/root/workspace/Pytorch_Lightning/params/{CFG.setting['MODEL_NAME']}_{CFG.setting['Version']}"
os.makedirs(parampath, exist_ok=True)

print(f"rootpath : {rootpath}")
print(f"parampath : {parampath}")

# listdirの下に .DS_Store が勝手に生成されてることがあるので対策してます...
class_name_list = [class_name for class_name in os.listdir(rootpath) if class_name != ".DS_Store"]
CFG.n_class = len(class_name_list) # クラスの数

filepath = os.path.join(rootpath , "**/*.jpg")
path_list = [path for path in glob.glob(filepath)]

print(f"n_path : {len(path_list)}")
print(f"n_class : {CFG.n_class}")

rootpath : /root/workspace/Pytorch_Lightning/data
parampath : /root/workspace/Pytorch_Lightning/params/efficientnet_b3_trial
n_path : 144
n_class : 3

pathの指定は環境によって変更が必要です。

5. transform

class train_ImageTransform():
    def __init__(self, resize, mean, std):
        self.data_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Resize((resize, resize)),
            transforms.RandomRotation(degrees=(-10, 10)), # rotate
            transforms.RandomHorizontalFlip(p=0.5), # Flip
            transforms.Normalize(mean, std), # Normalize
            transforms.RandomErasing(p=0.8, scale=(0.06, 0.20), ratio=(0.1, 4.0)), # mask
            ])
    def __call__(self, img):
        return self.data_transform(img)

class valid_ImageTransform():
    def __init__(self, resize, mean, std):
        self.data_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Resize((resize, resize)),
            transforms.Normalize(mean, std),
            ])
    def __call__(self, img):
        return self.data_transform(img)

6. Dataset

画像ファイルへの path と label の2列からなる DataFrame を作って、そこから Dataset を作ることにします。

# make_pathlist
filepath = os.path.join(rootpath , "**/*.jpg")
path_list = [path for path in glob.glob(filepath)]
class_name_list = [class_name for class_name in os.listdir(rootpath) if class_name != ".DS_Store"]
# make_class / class_name(str) -> label(int) の変換を行うdictを作ります
label_map = dict()
for i, class_name in enumerate(class_name_list):
    label_map[class_name] = i
# 作ったdictでlabel_listにします。intでリストに格納されます
label_list = [label_map[path.split("/")[-2]] for path in path_list]
# DataFrame にする
df = pd.DataFrame({
    "path" : path_list,
    "label" : label_list
})

# print(df.shape) # (len(file_list),2)

この df から Dataset を作るクラスを作ります。

class ImageDataset(Dataset):
    def __init__(self, df, transform=None):
        super().__init__()
        self.df = df
        self.transform = transform 
    
    def __len__(self):
        return len(self.df) 

    def __getitem__(self, index):
        row = self.df.iloc[index]
        img_path = row.path
        img = self.transform(Image.open(img_path).convert("RGB"))
        label = row.label
        return img, label

動作確認

train_transform = train_ImageTransform(CFG.resize, CFG.mean, CFG.std)
train_dataset = ImageDataset(df, transform=train_transform)
print(len(train_dataset)) # n_pathlist
print(train_dataset[0][0].shape, train_dataset[0][1]) # torch.Size([3, 224, 224]) labelを表すint

この df は StratifiedKFold して使うことになります。

7. PyTorch_Lightning

datasetができたら PyTorchLightning を使います。

  • DataLoaderの作成
  • 訓練・検証ループの作成
    の2つを作ります。

train_datasetとvalid_datasetを渡すとこのクラスででDataLoaderが作れます。

# make DataLoader for Pytorch_Lightning
class PlDataModule(pl.LightningDataModule):
    def __init__(self, train_dataset, valid_dataset):
        super().__init__()
        self.train_dataset = train_dataset
        self.valid_dataset = valid_dataset

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=CFG.batch_size, shuffle=True, num_workers=CFG.num_workers)

    def val_dataloader(self):
        return DataLoader(self.valid_dataset, batch_size=CFG.batch_size, shuffle=False, num_workers=CFG.num_workers)

PyTorchLightningでLoop部分を作ります。

# Loop
class Net(pl.LightningModule):
    # make_model
    def __init__(self, cv):
        super(Net, self).__init__()
        # model
        self.model = load_model()
        # cv
        self.cv = cv
        # other
        self.lr = CFG.model_lr
        self.criterion  = eval(CFG.criterion)
        # metrics
        self.classes = CFG.n_class
        self.accuracy = torchmetrics.Accuracy()
        self.F1score = torchmetrics.F1Score(num_classes=self.classes, average='macro')
        # log_list
        self.train_loss_log, self.train_acc_log = list(), list()
        self.valid_loss_log, self.valid_acc_log = list(), list()
        self.F1_log = list()
    
    # forward
    def forward(self, x):
        x = self.model(x)
        return x

    # optimizer & scheduler
    def configure_optimizers(self):
        optimizer = eval(CFG.optimizer)
        scheduler = eval(CFG.scheduler)
        return [optimizer],[scheduler]
    
    # 1.training_step
    def training_step(self, batch, batch_idx):
        # x:image y:label
        imgs, labels = batch # imgs:[B, C, H, W], labels:[B]
        logits = self.forward(imgs) # [B, class]
        loss = self.criterion(logits, labels)
        # display log
        self.log("train_acc", self.accuracy(logits, labels),prog_bar=False,logger=True,
                            on_epoch=True, on_step=False)
        print(".", end='')
        return {"loss" : loss, "logits" : logits, 
                    "labels" : labels, "batch_loss" : loss.item()*imgs.size(0)}
        # returnのDict内の"loss"を基にモデルがアップデートされる

    # 2.validation_step
    def validation_step(self, batch, batch_idx):
        imgs, labels = batch
        logits = self.forward(imgs)
        loss = self.criterion(logits, labels)
        # display log
        self.log("valid_acc", self.accuracy(logits, labels),prog_bar=False,logger=True,
                            on_epoch=True, on_step=False)
        return {"valid_loss" : loss, "logits" : logits,
                    "labels" : labels, "batch_loss" : loss.item()*imgs.size(0)}

    # 3.validation_epoch_end
    def validation_epoch_end(self, val_step_outputs):
        logits = torch.cat([o["logits"] for o in val_step_outputs], dim=0) # [n_train_dataset, class]
        labels = torch.cat([o["labels"] for o in val_step_outputs], dim=0) # [n_train]
        epoch_loss = sum([o["batch_loss"] for o in val_step_outputs]) / logits.size(0)
        acc = self.accuracy(logits, labels)
        F1 = self.F1score(logits, labels)
        self.valid_loss_log.append(epoch_loss)
        self.valid_acc_log.append(acc.item())
        self.F1_log.append(F1.item())
        # display log
        self.log("valid_acc_epoch", acc, prog_bar=True,logger=True,
                            on_epoch=True, on_step=False)
        # # save記述
        if np.argmax(self.valid_acc_log) == self.current_epoch:
            torch.save(self.model.state_dict(), os.path.join(parampath, f"best_acc_{self.cv}.pth"))

    # 4.training_epoch_end
    def training_epoch_end(self, train_step_outputs):
        # リストの中に return の dict が入ってる
        # print(train_step_outputs) # [ [training_step_outputs], [training_step_outputs],...,[training_step_outputs] ]
        logits = torch.cat([o["logits"] for o in train_step_outputs], dim=0) # [n_train_dataset, class]
        labels = torch.cat([o["labels"] for o in train_step_outputs], dim=0) # [n_train]
        epoch_loss = sum([o["batch_loss"] for o in train_step_outputs]) / logits.size(0)
        acc = self.accuracy(logits, labels)
        self.train_loss_log.append(epoch_loss)
        self.train_acc_log.append(acc.item())
        # display log
        self.log("train_acc_epoch", acc, prog_bar=True,logger=True,
                            on_epoch=True, on_step=False)
        print()
        print(f"epoch : {self.current_epoch} >>> ", end='')
        print(f"train_loss : {self.train_loss_log[-1]:.4f} / train_acc : {self.train_acc_log[-1]:.4f} ", end='') 
        print(f"/ valid_loss : {self.valid_loss_log[-1]:.4f} / valid_acc : {self.valid_acc_log[-1]:.4f} / F1 : {self.F1_log[-1]:.4f}")
        print()
  • __init__の部分でmodel, learning_rate, criterion, 評価指標の計算機とかを決めます。

  • forwardは [B, C, H, W]の画像のテンソルを渡した時の挙動で、モデルに突っ込むよということです。

  • configure_optimizers で optimizerとscheduler を設定

  • training_step -> validation_step -> validation_epoch_end -> training_epoch_end という順番で動いていきます。記述の順番に制約はありません。

  • 訓練回数は self.current_epoch で取得できます。

  • self.logはlogが残せます。この辺りがまだあまり理解できていません。
    progress_barを表示した時に載せるような設定にできます。
    あとは tensor_board のような logger に送ることができたりという設定ができるようです。
    あまり理解ができてない故、強引に表示してます。ここはもっとスマートにできると思われます...。

  • PyTorch Lightningでは net.train(),net.eval(),to(device),optimizer.zero_grad()とかもろもろの記述が不要、というか書かないでも中で自動でやってくれているようです。

  • CrossValidationで回したいので、cvという変数をそのまま追加しています。訓練中に検証時のacc最高のモデルを保存するようにしています。

  • validation_epoch_endの引数のval_step_outputs の部分は
    printすれば分かりますが、リストの中に validation_step の return が dict で入ってます。
    epoch_endの中ではこの dict を処理する記述をしてます。

import torch

d1 = {'loss': 1, 'labels': torch.tensor([0,1,2]), 'logits': torch.tensor([[1,1,1], [2,2,2], [3,3,3]])}
d2 = {'loss': 10, 'labels': torch.tensor([1,1,1]), 'logits': torch.tensor([[10,10,10], [20,20,20], [30,30,30]])}
d3 = {'loss': 100, 'labels': torch.tensor([2,2,0]), 'logits': torch.tensor([[0.1,0.1,0.1], [0.2,0.2,0.2], [0.3,0.3,0.3]])}
outputs = [d1, d2, d3]
print(outputs)

for o in outputs:
    print("*****")
    print(o)

print()
sum_ = sum([o["loss"] for o in outputs])
print(sum_)
cat_ = torch.cat([ o["logits"] for o in outputs ], dim=0)
print(cat_)
print(cat_.shape)

こんな感じのことをしてるはずです。

8. Trainer

PyTorchLightningのdocumentでは以下のような記述が例として記述されています。

model = MyLightningModule()
trainer = Trainer()
trainer.fit(model, train_dataloader, val_dataloader)

今回は MyLightningModule() は Net() と記述、dataloaderは PlDataModuleクラスで作成、cvの回数をそのまま記述したので、

data_module = PlDataModule(train_dataset, valid_dataset)
# Trainer
model = Net(cv=cv)
trainer = pl.Trainer(max_epochs=CFG.epochs, gpus=1, callbacks=progress, num_sanity_val_steps=0)
trainer.fit(model,datamodule=data_module)

のように書くことにします。
max_epochsは訓練回数。並列・分散しないのであれば、gpus=1。
callbacksの部分では progress_barを消す設定をしました。progress_barの制御がVSCodeで難しかったので...

# ProgressBar setting
class NotebookProgress(ProgressBarBase):
    def __init__(self):
        super().__init__()
        self.enable = True
    def disable(self):
        self.enable = True

progress = NotebookProgress()

callbacksでは他にも様々な設定ができるようです。checkpointとかearlystoppingとか。

num_sanity_val_stepsの指定ですが、ここを指定しないと、training_step -> validation_step -> validation_epoch_end -> training_epoch_end のループに入る前に、訓練していない状態で validation が先に1回だけ回るという設定がデフォルトになっています。

こちらに設定の意図などとても分かりやすく書いてありました。

では最後に、StratifiedKFoldの枠組みで Trainer を回します。

# fix seed
pl.seed_everything(CFG.seed)
# ProgressBar setting
progress = NotebookProgress()
# transform
train_transform = train_ImageTransform(CFG.resize, CFG.mean, CFG.std)
valid_transform = valid_ImageTransform(CFG.resize, CFG.mean, CFG.std)

skf = StratifiedKFold(CFG.n_fold, shuffle=True, random_state=CFG.seed)
cv = 0
cv_log = {}

for train_index, valid_index in skf.split(df, df["label"]):
    cv +=1
    print(f"************** fold: {cv} ***************")
    # df => train_df, valid_df
    train_df, valid_df = df.iloc[train_index], df.iloc[valid_index]
    # dataset
    train_dataset = ImageDataset(train_df, transform=train_transform)
    valid_dataset = ImageDataset(valid_df, transform=valid_transform)
    # data_module
    data_module = PlDataModule(train_dataset, valid_dataset)
    model = Net(cv=cv)
    trainer = pl.Trainer(max_epochs=CFG.epochs, gpus=1, callbacks=progress, num_sanity_val_steps=0)
    trainer.fit(model,datamodule=data_module)
    # plot
    fig = visualize_process(model.train_loss_log, model.valid_loss_log, model.train_acc_log, model.valid_acc_log)

    print(f"***** fold:{cv}:best_acc *****")
    print(max(model.valid_acc_log))
    print()
    # cvごとのlog
    # cv_logはmapになっていて、cvの数字によって log を保存
    cv_log[f"cv{cv}_train_loss_log"] = model.train_loss_log
    cv_log[f"cv{cv}_train_acc_log"] = model.train_acc_log
    cv_log[f"cv{cv}_valid_loss_log"] = model.valid_loss_log
    cv_log[f"cv{cv}_valid_acc_log"] = model.valid_acc_log
    cv_log[f"cv{cv}_valid_F1_log"] = model.F1_log

# 訓練終了後に一番良い値を表示する
for i in range(CFG.n_fold):
    print(f"********** cv{i+1} **********")
    best_acc = max(cv_log[f"cv{i+1}_valid_acc_log"])
    best_acc_index = np.argmax(cv_log[f"cv{i+1}_valid_acc_log"])
    print("best_acc_index: {}, best_acc: {:.04f}".format(best_acc_index, best_acc))
    best_loss = min(cv_log[f"cv{i+1}_valid_loss_log"])
    best_loss_index = np.argmin(cv_log[f"cv{i+1}_valid_loss_log"])
    print("best_loss_index: {}, best_loss: {:.04f}".format(best_loss_index, best_loss))
    best_F1 = max(cv_log[f"cv{i+1}_valid_F1_log"])
    best_F1_index = np.argmax(cv_log[f"cv{i+1}_valid_F1_log"])
    print("best_F1_index: {}, best_F1: {:.04f}".format(best_F1_index, best_F1))

poorな部分

  • 今回、訓練結果を強引に表示していますが、これはスマートなやり方がありそうです。
  • CrossValidation 自体もこのように for 文で記述しなくても、どこかで定義できる方法があると思われます。

またdocumentを見て改善します...

参考

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?