こちらの記事は 2021年6月18日に開催された 第2回分析コンペLT会 - connpass で発表に用いた資料です。
前回の発表 や 他の類似ライブラリとの比較記事 の投稿からある程度時間が経ち、PyTorch Lightning については色々と書き方も変わったのであらためてまとめました。
0. この記事について
- 対象
- 素の PyTorch を使ったことがある人
- 使ったことがない人は必ず Tutorial やってください
- MLコンペ(特に画像系)に何回か参加する予定がある人
- MLコンペは主に Kaggle を想定しています
- 一度しか参加しないなら素の PyTorch 書き下したほうがはやいです
- 素の PyTorch を使ったことがある人
- 注意事項
-
この記事は PyTorch Lightning を使う「べき」という主張ではありません
- あくまでライブラリの一つの紹介です
- 「まあそこまで言うなら触ってみよう」というきっかけになることを目指したものです
-
pytorch-lightning==1.3.1
を前提としています
-
この記事は PyTorch Lightning を使う「べき」という主張ではありません
1. 自己紹介
- ふぁむたろうです
- 現在は 株式会社Rist にて製造業周りのAI開発とか行っています
- kaggle master です (https://www.kaggle.com/yukkyo)
- https://twitter.com/fam_taro
- コンペ(Kaggle)経歴
- 主に画像系のコンペに参加しています
- 気づいたらどのメダルも PyTorch Lightning と共に獲得してました
- ですので一応使っていると言っても嘘ではないかと思います
2. PyTorch Lightning(PL) とは
名前が長いので以下 PL とします。
2.1 概要
本家ドキュメントでは以下のように紹介されています。
Organizing your code with PyTorch Lightning makes your code:
-
Keep all the flexibility (this is all pure PyTorch), but removes a ton of boilerplate
- PyTorch の柔軟性を維持しつつ大量の定型文(boilerplate)を削減する
-
More readable by decoupling the research code from the engineering
- 研究コードをエンジニアリングから切り離すことで読みやすくする
- device 設定や学習用のforループ、学習途中の結果の表示等は、本来やりたいことではない
-
Easier to reproduce
- 再現しやすくする
- seed や device 周りの吸収
- step ごとに関数を分割することによる整理
- 再現しやすくする
-
Less error-prone by automating most of the training loop and tricky engineering
- 学習ループやトリッキーなエンジニアリングを自動化することで実装ミスを減らす
- 例えば以下の全てを盛り込んでを素の PyTorch で実装しようとすると事故りやすかったりする
- Mixed Precision Training(AMP)
- Gradient Accumulation
- 自分で書くと、ループ末端での
optimizer.step()
を抜かしがち
- 自分で書くと、ループ末端での
- Data Parallel(DP) か Distributed Data Parallel(DDP)
-
Scalable to any hardware without changing your model
- コードを変えずにいろんな hardware を切り替えられる
- CPU or GPU or TPU
-
.cuda()
とかを埋め込むやつは許さない
-
- 複数GPU
- そもそもマシン自体を分散させたいときとか
雑に言うと PyTorch の学習ループ周りとかの定型文をラッピングしてくれる
ライブラリという認識で良いです。
2.2 類似ライブラリについて
PL 以外では下記のライブラリも似たようなことができます。
- pytorch/ignite: High-level library to help with training and evaluating neural networks in PyTorch flexibly and transparently.
- catalyst-team/catalyst: Accelerated deep learning R&D
- fastai/fastai: The fastai deep learning library
- pfnet/pytorch-pfn-extras: Supplementary components to accelerate research and development in PyTorch
細かい違いは以前 PyTorch 三国志(Ignite・Catalyst・Lightning) - Qiita にまとめましたので興味がある方はご参照ください。
これらの違いは雑に言えば 抽象化具合(どこをどこまでライブラリがラッピングするか)
だと思っています。
好みの問題もありますので興味がある方は各リポジトリをご確認ください。
2.2.1 Star数比較(2021年6月18日時点)
- 図は https://star-history.t9t.io/ より引用
- いずれも順調に伸びてます
- Lightning が頭一つ抜けてるかなという印象です
3. PL を使うメリット・デメリット
3.1 メリット
- 自分でループを書かなくてすむ
- ループをstepごとに強制分割させられるのも大きい
- 以下とかの定型文とおさらばできる
- seed 設定
- GPU周りの処理
- multi GPU の処理
-
.cuda(0)
や.to(device)
とおさらばできる
- Mixed Precision Training
- Nvidia-Apex と PyTorch-Native-AMP の切り替えも含む
- Gradient Accumulation
- 地味に実装を間違えやすい
- deterministic 設定
- 程よいコーディング制約がかかる
-
素の PyTorch は自由すぎる
- 真っ白なキャンパスに描き始めるのは辛い
- Catalyst は色々巻き取りすぎて取り回しが難しい
- Ignite はあまり巻き取ってないので自由度の高いコードになりやすい
- PL → 自作 PyTorch ループは比較的楽に移行できる
-
素の PyTorch は自由すぎる
- 結果の表示が楽
- tqdm とか自分で書かなくてすむ
- csv とか tensorboard に出力して自分で確認とかしなくてよい
- 他人と共有しやすい(お互い PL 知っていれば)
- どこで何をしているのかわかるため
- 「あの処理はどこでやってるんですか…?」みたいな問いが無くせる
- 新しくコンペを始めるときに手が動きやすい
- 上記について多くの人のレビューを経て洗練されたものを使うことができる
- 正直知らない人が書いた seed_everything とか信じられない
- 他人が書いたオレオレループとか読みたくない
- 自分がバグを埋め込んでしまう心配がない
- 誰でもかんたんにパイプライン作った気持ちになれる
3.2 デメリット
- 最初の移植時は手間の割にコード量が劇的に減るわけではない
- ループ周りの処理は書かなくてすむけど、PLが必要とする関数とか引数埋めとかが発生する
- 細かすぎる操作ができない or 難しい
- PL のソースコードまで確認する必要があったりする
- 素の PyTorch なら20分で実装できそうなことに2,3時間かかる場合がある
- 追加するにもうまく monkey-hack しないといけない場合がある
-
個人的な主張
- PLが対応できないほど細かいことは本質的ではない
- 自分の業務の場合これ以上細かいことは重要ではない場合が多い
- PL のソースコードまで確認する必要があったりする
- 最新版の pytorch が使えるようになるまで少しタイムラグがある
- コードコンペだとローカルとバージョン合わせる手間が発生する
-
Notebook 内容を流用したい場合、中身を噛み砕いてPLに組み込まないといけない
- 自分で Notebook の内容を把握してループ内の各stepを分割する必要がある
- 自分は学ぶ上ではメリットだと思っていますが、すぐ動かせないのと新たなバグを埋め込む可能性があるのでもどかしい時があります
- 素のpytorch以上にバージョンが変わったときの影響が大きい(引数名とか)
- ver 1.0 以降は
破壊的変更
は少なくなったが、まだ存在はする- 以前使えたあの引数名が変わった、等
- ver 1.0 以降は
4. PyTorch Lightning 2021(構成要素編)
現在PLを使って学習する場合、以下の要素を呼び出す(定義する)必要があります。
-
Lightning Module
- モデル + 各step(epoch や batch 単位)の挙動をまとめたクラス
- 関数名が指定してあるのでその関数の中を埋めていく
-
Data Module
- Dataset 周りを定義している
- DataLoaderを返す関数を持つクラス
- Datasetのダウンロードや train_valid_split などもここで行う
- これも関数名が決まっているので埋める
-
Trainer
- 学習や推論ループを回してくれる本体
- 使うときは引数を埋めるだけ
- epoch
- device 設定
- DP or DDP
- FP16(Mixed Precision Training) or not
- Gradient Accumulation
- 下記の Logger や ModelCheckpoint(必要であれば)
-
その他(なくてもよい)
- Logger
- CSV logger や Wandb logger が用意されている
- Wandb logger は2,3行追加するだけで使えてとても便利
- 外からでも楽に学習結果をモニタリングできる
- ついでにGPUリソース等の使用具合やコンソール情報も楽に見れる
- ModelCheckpoint
- 重みの保存方法を指定する
- latest epoch や best loss、その両方など色々指定(定義)できる
- Logger
以下のように埋めていくと、気づいたら学習が回っているイメージです。
from pytorch_lightning import LightningModule, LightningDataModule, Trainer
class MyLightningModule(LightningModule):
# 決めてある関数の中身を埋める
class MyDataModule(LightningDataModule):
# 決めてある関数の中身を埋める
trainer = Trainer()
lightning_module = MyLightningModule()
data_module = MyDataModule()
# 学習開始
trainer.fit(lightning_module, datamodule=data_module)
以前は LightningModule 内に DataModule の機能も入れる必要がありましたので、現在はもう少し Keras Like になったとも言えます。
5. PyTorch Lightning 2021 (実践編)
ここでは PL 実践編として、分類タスクを想定して自分の実装例を紹介します。
5.1 ファイル構成例
とりあえず下記のどこかで Lightning Module
、Data Module
、Trainer
が定義されていることだけ把握しておけば大丈夫です。
.
├── docker : 環境再現用
├── input : sample_submission.csv とかが入っている
└── src
├── config : 実験設定ファイル
│ ├── 001.yaml
│ ├── 002.yaml
│ ├── ...
├── dataset.py : Data Module 等を定義
├── factory.py : config ファイルから必要なモジュールを生成
├── lightning_module.py : Lightning Module を定義
├── loss.py : 自作 Loss 置き場
├── models : 自作 Model 置き場
│ ├── model1.py
│ ├── model2.py
│ ├── ...
├── test.py : submission 作成用
├── train.py : Training 実行スクリプト
└── util : 細かい処理関数置き場
├── data_augmentation.py
├── img_preprocess.py
└── mask_func.py
5.2 Lightning Module
Lightning Module はモデル自体の定義と各step(train, valid)の挙動をまとめたものです。
以下のように training_step
や validation_step
など予め決まっている関数を上書きして定義します。
また validation_epoch_end
内の self.log_dict(d, prog_bar=True)
のように、self.log_~~~
を呼び出すことでいろんなタイミングで各指標を書き出すことができます。
from collections import OrderedDict
import pytorch_lightning as pl
import torch
import numpy as np
from sklearn.metrics import roc_auc_score
from factory import get_loss, get_lr_scheduler, get_optimizer
from models import get_custom_timm_model
class MyLightningModule(pl.LightningModule):
def __init__(self, cfg):
super(MyLightningModule, self).__init__()
self.cfg = cfg
self.cfg_optim = cfg.Optimizer
self.loss = get_loss(cfg.Loss)
self.net = get_custom_timm_model(cfg.Model)
def forward(self, x):
logits = self.net(x).squeeze(1)
return logits
def training_step(self, batch, batch_nb):
imgs, targets = batch
# mixup とかしたい場合はここに差し込む
logits = self.forward(imgs)
loss = self.loss(logits, targets)
return loss
def validation_step(self, batch, batch_nb):
imgs, targets = batch
logits = self.forward(imgs)
loss = self.loss(logits, targets)
preds = logits.sigmoid()
output = OrderedDict({
"targets": targets.detach(), "preds": preds.detach(), "loss": loss.detach()
})
return output
def validation_epoch_end(self, outputs):
d = dict()
d["epoch"] = int(self.current_epoch)
d["v_loss"] = torch.stack([o["loss"] for o in outputs]).mean().item()
targets = torch.cat([o["targets"].view(-1) for o in outputs]).cpu().numpy()
preds = torch.cat([o["preds"].view(-1) for o in outputs]).cpu().numpy()
score = roc_auc_score(y_true=targets, y_score=preds)
d["v_score"] = score
self.log_dict(d, prog_bar=True)
def configure_optimizers(self):
conf_optim = self.cfg_optim
optimizer = get_optimizer(conf_optim)(self.parameters(), **conf_optim.optimizer.params)
scheduler = get_lr_scheduler(conf_optim)(optimizer, **conf_optim.lr_scheduler.params)
# ここでの返し方で batch ごとの scheduler.step() 実行もできる
return [optimizer], [scheduler]
5.3 DataModule
from pathlib import Path
import numpy as np
import pandas as pd
import pytorch_lightning as pl
import torch
from sklearn.model_selection import StratifiedKFold
from torch.utils.data import DataLoader, Dataset
from factory import get_transform
class TrainDataset(Dataset):
# 通常の PyTorch の Dataset を定義する
class MyDataModule(pl.LightningDataModule):
def __init__(self, cfg):
super().__init__()
# Config ファイルからの読み込み
self.cfg_dataset = cfg.Data.dataset
self.cfg_augmentation = cfg.Augmentation
self.cfg_dataloader = cfg.Data.dataloader
self.test_df = None
self.train_df = None
self.valid_df = None
def get_test_df(self):
return pd.read_csv(self.cfg_dataset.test_csv)
def split_train_valid_df(self):
df = pd.read_csv(self.cfg_dataset.train_csv)
# Split
skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
for n, (train_index, val_index) in enumerate(skf.split(df, df[self.cfg_dataset.target_col])):
df.loc[val_index, "fold"] = int(n)
df["fold"] = df["fold"].astype(int)
fold = int(self.cfg_dataset.fold)
train_df = df[df["fold"] != fold].reset_index(drop=True)
valid_df = df[df["fold"] == fold].reset_index(drop=True)
return train_df, valid_df
# 必ず呼び出される関数
def setup(self, stage):
self.test_df = self.get_test_df()
train_df, valid_df = self.split_train_valid_df()
self.train_df = train_df
self.valid_df = valid_df
def get_dataframe(self, phase):
assert phase in {"train", "valid", "test"}
if phase == "train":
return self.train_df
elif phase == "valid":
return self.valid_df
elif phase == "test":
return self.test_df
def get_ds(self, phase):
assert phase in {"train", "valid", "test"}
transform = get_transform(conf_augmentation=self.cfg_augmentation[phase])
ds = TrainDataset(
df=self.get_dataframe(phase=phase),
transform=transform,
test=(phase == "test"),
cfg_dataset=self.cfg_dataset
)
return ds
def get_loader(self, phase):
dataset = self.get_ds(phase=phase)
return DataLoader(
dataset,
batch_size=self.cfg_dataloader.batch_size,
shuffle=True if phase == "train" else False,
num_workers=self.cfg_dataloader.num_workers,
drop_last=True if phase == "train" else False,
)
# Trainer.fit() 時に呼び出される
def train_dataloader(self):
return self.get_loader(phase="train")
# Trainer.fit() 時に呼び出される
def val_dataloader(self):
return self.get_loader(phase="valid")
def test_dataloader(self):
return self.get_loader(phase="test")
5.4 Trainer
今回は Logger として PL に入っている wandbLogger を使います。
これを使うことで学習時に Weights & Biases – Developer tools for ML に各指標(Loss とか)が送られていくので、確認するのが楽になります。
ただし wandbLogger を使う場合、予めコンソール上 wandb 初回 login をする必要があるのでご注意ください。(一行叩くとできます)
import argparse
import os
import shutil
from pathlib import Path
from typing import Dict
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.plugins import DDPPlugin
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.loggers.csv_logs import CSVLogger
from dataset import MyDataModule
from factory import read_yaml
from lightning_module import MyLightningModule
from util import make_output_path, src_backup
def make_parse():
parser = argparse.ArgumentParser()
arg = parser.add_argument
arg("--debug", action="store_true", help="debug")
arg("--config", default=None, type=str, help="config path")
arg("--gpus", default="0", type=str, help="gpu numbers")
arg("--fold", default="0", type=str, help="fold number")
return parser
def train(cfg_name: str, cfg: Dict, output_path: Path) -> None:
seed_everything(cfg.General.seed)
debug = cfg.General.debug
fold = cfg.Data.dataset.fold
# logger は csv logger と wandb logger 両方使ってみる
logger = CSVLogger(save_dir=str(output_path), name=f"fold_{fold}")
wandb_logger = WandbLogger(name=f"{cfg_name}_{fold}", project=cfg.General.project, offline=debug)
# 学習済重みを保存するために必要
checkpoint_callback = ModelCheckpoint(
dirpath=str(output_path), filename=f"{cfg_name}_fold_{fold}",
save_weights_only=True,
save_top_k=None,
monitor=None
)
trainer = Trainer(
max_epochs=5 if debug else cfg.General.epoch,
gpus=cfg.General.gpus,
distributed_backend=cfg.General.multi_gpu_mode,
accumulate_grad_batches=cfg.General.grad_acc,
precision=16 if cfg.General.fp16 else 32,
amp_level=cfg.General.amp_level,
amp_backend='native',
deterministic=True,
auto_select_gpus=False,
benchmark=False,
default_root_dir=os.getcwd(),
limit_train_batches=0.02 if debug else 1.0,
limit_val_batches=0.05 if debug else 1.0,
callbacks=[checkpoint_callback],
logger=[logger, wandb_logger],
replace_sampler_ddp=not cfg.Data.dataloader.change_rate_sampler,
# For fast https://pytorch-lightning.readthedocs.io/en/1.3.3/benchmarking/performance.html#
plugins=DDPPlugin(find_unused_parameters=False)
)
# Lightning module and start training
model = MyLightningModule(cfg)
datamodule = MyDataModule(cfg)
trainer.fit(model, datamodule=datamodule)
def main():
args = make_parse().parse_args()
# Read config
cfg = read_yaml(fpath=args.config)
cfg.General.debug = args.debug
cfg.General.gpus = list(map(int, args.gpus.split(",")))
cfg.Data.dataset.fold = args.fold
# Make output path
output_path = Path("../output/model") / Path(args.config).stem
output_path = make_output_path(output_path, args.debug)
# Config and Source code backup
shutil.copy2(args.config, str(output_path / Path(args.config).name))
src_backup(input_dir=Path("./"), output_dir=output_path)
# Start train
train(cfg_name=Path(args.config).stem, cfg=cfg, output_path=output_path)
if __name__ == "__main__":
main()
5.5 学習実行と表示される結果
5.5.1 学習実行と表示される結果(コンソール上)
実行すると下記のような結果が勝手に出てきます。
$ cd src
$ python train.py --config config/015.yaml --fold 0 --debug
Global seed set to 1222
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
Using native 16bit precision.
Global seed set to 1222
initializing ddp: GLOBAL_RANK: 0, MEMBER: 1/1
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
wandb: W&B syncing is set to `offline` in this directory. Run `wandb online` or set WANDB_MODE=online to enable cloud syncing.
| Name | Type | Params
-------------------------------------------
0 | loss | BCEWithLogitsLoss | 0
1 | net | EfficientNet | 10.7 M
-------------------------------------------
10.7 M Trainable params
0 Non-trainable params
10.7 M Total params
42.788 Total estimated model params size (MB)
Validation sanity check: 0it [00:00, ?it/s]
Validation sanity check: 50%|█████████████████████████████████████████ | 1/2 [00:01<00:01, 1.08s/it]
Global seed set to 1222
Epoch 1: 40%|████████████████████ | 16/40 [00:07<00:11, 2.07it/s, loss=0.396, v_loss=0.462, v_score=0.466]
5.5.2 表示される結果(wandb上)
今回は wandbLogger を使っているので、Home – Weights & Biases を見に行くと下記のようなまとめが勝手に作られます(とても便利)。
6. その他tips
6.1 高速化手法について
Fast performance tips — PyTorch Lightning 1.3.3 documentation には PyTorch Lightning だけでなく PyTorch 全般の高速化について触れていたりして学びがあったので、お時間ある方は確認してみると良いかなと思いました。
6.2 おすすめ PL 練習場(MLコンペ)
- 現在(2021年6月18日)開催中のコンペですと SETI Breakthrough Listen - E.T. Signal Search | Kaggle が PL の練習場として良いと思います
- シンプルなタスク(分類)
- CSV を提出するだけでよい
- ローカルマシンの環境のみ気をつければよい
- Kaggle Notebook に移植しなくてよい