search
LoginSignup
117

More than 1 year has passed since last update.

posted at

updated at

PyTorch Lightning 2021 (for MLコンペ)

こちらの記事は 2021年6月18日に開催された 第2回分析コンペLT会 - connpass で発表に用いた資料です。

前回の発表他の類似ライブラリとの比較記事 の投稿からある程度時間が経ち、PyTorch Lightning については色々と書き方も変わったのであらためてまとめました。

0. この記事について

  • 対象
    • 素の PyTorch を使ったことがある人
      • 使ったことがない人は必ず Tutorial やってください
    • MLコンペ(特に画像系)に何回か参加する予定がある人
      • MLコンペは主に Kaggle を想定しています
      • 一度しか参加しないなら素の PyTorch 書き下したほうがはやいです
  • 注意事項
    • この記事は PyTorch Lightning を使う「べき」という主張ではありません
      • あくまでライブラリの一つの紹介です
      • 「まあそこまで言うなら触ってみよう」というきっかけになることを目指したものです
    • pytorch-lightning==1.3.1 を前提としています

1. 自己紹介

  • ふぁむたろうです
  • コンペ(Kaggle)経歴
    • 主に画像系のコンペに参加しています
    • 気づいたらどのメダルも PyTorch Lightning と共に獲得してました
      • ですので一応使っていると言っても嘘ではないかと思います

2. PyTorch Lightning(PL) とは

名前が長いので以下 PL とします。

2.1 概要

本家ドキュメントでは以下のように紹介されています。

Organizing your code with PyTorch Lightning makes your code:

  • Keep all the flexibility (this is all pure PyTorch), but removes a ton of boilerplate
    • PyTorch の柔軟性を維持しつつ大量の定型文(boilerplate)を削減する
  • More readable by decoupling the research code from the engineering
    • 研究コードをエンジニアリングから切り離すことで読みやすくする
    • device 設定や学習用のforループ、学習途中の結果の表示等は、本来やりたいことではない
  • Easier to reproduce
    • 再現しやすくする
      • seed や device 周りの吸収
      • step ごとに関数を分割することによる整理
  • Less error-prone by automating most of the training loop and tricky engineering
    • 学習ループやトリッキーなエンジニアリングを自動化することで実装ミスを減らす
    • 例えば以下の全てを盛り込んでを素の PyTorch で実装しようとすると事故りやすかったりする
      • Mixed Precision Training(AMP)
      • Gradient Accumulation
        • 自分で書くと、ループ末端での optimizer.step() を抜かしがち
      • Data Parallel(DP) か Distributed Data Parallel(DDP)
  • Scalable to any hardware without changing your model
    • コードを変えずにいろんな hardware を切り替えられる
    • CPU or GPU or TPU
      • .cuda() とかを埋め込むやつは許さない
    • 複数GPU
    • そもそもマシン自体を分散させたいときとか

雑に言うと PyTorch の学習ループ周りとかの定型文をラッピングしてくれる ライブラリという認識で良いです。

2.2 類似ライブラリについて

PL 以外では下記のライブラリも似たようなことができます。

細かい違いは以前 PyTorch 三国志(Ignite・Catalyst・Lightning) - Qiita にまとめましたので興味がある方はご参照ください。

これらの違いは雑に言えば 抽象化具合(どこをどこまでライブラリがラッピングするか) だと思っています。
好みの問題もありますので興味がある方は各リポジトリをご確認ください。

2.2.1 Star数比較(2021年6月18日時点)

  • 図は https://star-history.t9t.io/ より引用
  • いずれも順調に伸びてます
  • Lightning が頭一つ抜けてるかなという印象です

3. PL を使うメリット・デメリット

3.1 メリット

  • 自分でループを書かなくてすむ
    • ループをstepごとに強制分割させられるのも大きい
  • 以下とかの定型文とおさらばできる
    • seed 設定
    • GPU周りの処理
      • multi GPU の処理
      • .cuda(0).to(device) とおさらばできる
    • Mixed Precision Training
      • Nvidia-Apex と PyTorch-Native-AMP の切り替えも含む
    • Gradient Accumulation
      • 地味に実装を間違えやすい
    • deterministic 設定
  • 程よいコーディング制約がかかる
    • 素の PyTorch は自由すぎる
      • 真っ白なキャンパスに描き始めるのは辛い
    • Catalyst は色々巻き取りすぎて取り回しが難しい
    • Ignite はあまり巻き取ってないので自由度の高いコードになりやすい
    • PL → 自作 PyTorch ループは比較的楽に移行できる
  • 結果の表示が楽
    • tqdm とか自分で書かなくてすむ
    • csv とか tensorboard に出力して自分で確認とかしなくてよい
  • 他人と共有しやすい(お互い PL 知っていれば)
    • どこで何をしているのかわかるため
    • 「あの処理はどこでやってるんですか…?」みたいな問いが無くせる
  • 新しくコンペを始めるときに手が動きやすい
  • 上記について多くの人のレビューを経て洗練されたものを使うことができる
    • 正直知らない人が書いた seed_everything とか信じられない
    • 他人が書いたオレオレループとか読みたくない
    • 自分がバグを埋め込んでしまう心配がない
  • 誰でもかんたんにパイプライン作った気持ちになれる

3.2 デメリット

  • 最初の移植時は手間の割にコード量が劇的に減るわけではない
    • ループ周りの処理は書かなくてすむけど、PLが必要とする関数とか引数埋めとかが発生する
  • 細かすぎる操作ができない or 難しい
    • PL のソースコードまで確認する必要があったりする
      • 素の PyTorch なら20分で実装できそうなことに2,3時間かかる場合がある
    • 追加するにもうまく monkey-hack しないといけない場合がある
    • 個人的な主張
      • PLが対応できないほど細かいことは本質的ではない
      • 自分の業務の場合これ以上細かいことは重要ではない場合が多い
  • 最新版の pytorch が使えるようになるまで少しタイムラグがある
  • コードコンペだとローカルとバージョン合わせる手間が発生する
  • Notebook 内容を流用したい場合、中身を噛み砕いてPLに組み込まないといけない
    • 自分で Notebook の内容を把握してループ内の各stepを分割する必要がある
    • 自分は学ぶ上ではメリットだと思っていますが、すぐ動かせないのと新たなバグを埋め込む可能性があるのでもどかしい時があります
  • 素のpytorch以上にバージョンが変わったときの影響が大きい(引数名とか)
    • ver 1.0 以降は 破壊的変更 は少なくなったが、まだ存在はする
      • 以前使えたあの引数名が変わった、等

4. PyTorch Lightning 2021(構成要素編)

現在PLを使って学習する場合、以下の要素を呼び出す(定義する)必要があります。

  • Lightning Module
    • モデル + 各step(epoch や batch 単位)の挙動をまとめたクラス
    • 関数名が指定してあるのでその関数の中を埋めていく
  • Data Module
    • Dataset 周りを定義している
    • DataLoaderを返す関数を持つクラス
    • Datasetのダウンロードや train_valid_split などもここで行う
    • これも関数名が決まっているので埋める
  • Trainer
    • 学習や推論ループを回してくれる本体
    • 使うときは引数を埋めるだけ
      • epoch
      • device 設定
      • DP or DDP
      • FP16(Mixed Precision Training) or not
      • Gradient Accumulation
      • 下記の Logger や ModelCheckpoint(必要であれば)
  • その他(なくてもよい)
    • Logger
      • CSV logger や Wandb logger が用意されている
      • Wandb logger は2,3行追加するだけで使えてとても便利
        • 外からでも楽に学習結果をモニタリングできる
        • ついでにGPUリソース等の使用具合やコンソール情報も楽に見れる
    • ModelCheckpoint
      • 重みの保存方法を指定する
      • latest epoch や best loss、その両方など色々指定(定義)できる

以下のように埋めていくと、気づいたら学習が回っているイメージです。

from pytorch_lightning import LightningModule, LightningDataModule, Trainer

class MyLightningModule(LightningModule):
    # 決めてある関数の中身を埋める

class MyDataModule(LightningDataModule):
    # 決めてある関数の中身を埋める

trainer = Trainer()
lightning_module = MyLightningModule()
data_module = MyDataModule()

# 学習開始
trainer.fit(lightning_module, datamodule=data_module)

以前は LightningModule 内に DataModule の機能も入れる必要がありましたので、現在はもう少し Keras Like になったとも言えます。

5. PyTorch Lightning 2021 (実践編)

ここでは PL 実践編として、分類タスクを想定して自分の実装例を紹介します。

5.1 ファイル構成例

とりあえず下記のどこかで Lightning ModuleData ModuleTrainer が定義されていることだけ把握しておけば大丈夫です。

.
├── docker                        : 環境再現用
├── input                         : sample_submission.csv とかが入っている
└── src
    ├── config                    : 実験設定ファイル
    │   ├── 001.yaml
    │   ├── 002.yaml
    │   ├── ...
    ├── dataset.py                : Data Module 等を定義
    ├── factory.py                : config ファイルから必要なモジュールを生成
    ├── lightning_module.py       : Lightning Module を定義
    ├── loss.py                   : 自作 Loss 置き場
    ├── models                    : 自作 Model 置き場
    │   ├── model1.py
    │   ├── model2.py
    │   ├── ...
    ├── test.py                   : submission 作成用
    ├── train.py                  : Training 実行スクリプト
    └── util                      : 細かい処理関数置き場
        ├── data_augmentation.py
        ├── img_preprocess.py
        └── mask_func.py

5.2 Lightning Module

Lightning Module はモデル自体の定義と各step(train, valid)の挙動をまとめたものです。

以下のように training_stepvalidation_step など予め決まっている関数を上書きして定義します。
また validation_epoch_end 内の self.log_dict(d, prog_bar=True) のように、self.log_~~~ を呼び出すことでいろんなタイミングで各指標を書き出すことができます。

lightning_module.py
from collections import OrderedDict

import pytorch_lightning as pl
import torch
import numpy as np
from sklearn.metrics import roc_auc_score

from factory import get_loss, get_lr_scheduler, get_optimizer
from models import get_custom_timm_model


class MyLightningModule(pl.LightningModule):
    def __init__(self, cfg):
        super(MyLightningModule, self).__init__()
        self.cfg = cfg
        self.cfg_optim = cfg.Optimizer
        self.loss = get_loss(cfg.Loss)
        self.net = get_custom_timm_model(cfg.Model)

    def forward(self, x):
        logits = self.net(x).squeeze(1)
        return logits

    def training_step(self, batch, batch_nb):
        imgs, targets = batch
        # mixup とかしたい場合はここに差し込む
        logits = self.forward(imgs)
        loss = self.loss(logits, targets)
        return loss

    def validation_step(self, batch, batch_nb):
        imgs, targets = batch
        logits = self.forward(imgs)
        loss = self.loss(logits, targets)
        preds = logits.sigmoid()
        output = OrderedDict({
            "targets": targets.detach(), "preds": preds.detach(), "loss": loss.detach()
        })
        return output

    def validation_epoch_end(self, outputs):
        d = dict()
        d["epoch"] = int(self.current_epoch)
        d["v_loss"] = torch.stack([o["loss"] for o in outputs]).mean().item()

        targets = torch.cat([o["targets"].view(-1) for o in outputs]).cpu().numpy()
        preds = torch.cat([o["preds"].view(-1) for o in outputs]).cpu().numpy()

        score = roc_auc_score(y_true=targets, y_score=preds)
        d["v_score"] = score
        self.log_dict(d, prog_bar=True)

    def configure_optimizers(self):
        conf_optim = self.cfg_optim
        optimizer = get_optimizer(conf_optim)(self.parameters(), **conf_optim.optimizer.params)
        scheduler = get_lr_scheduler(conf_optim)(optimizer, **conf_optim.lr_scheduler.params)
        # ここでの返し方で batch ごとの scheduler.step() 実行もできる
        return [optimizer], [scheduler]

5.3 DataModule

dataset.py
from pathlib import Path

import numpy as np
import pandas as pd
import pytorch_lightning as pl
import torch
from sklearn.model_selection import StratifiedKFold
from torch.utils.data import DataLoader, Dataset

from factory import get_transform


class TrainDataset(Dataset):
    # 通常の PyTorch の Dataset を定義する

class MyDataModule(pl.LightningDataModule):
    def __init__(self, cfg):
        super().__init__()

        # Config ファイルからの読み込み
        self.cfg_dataset = cfg.Data.dataset
        self.cfg_augmentation = cfg.Augmentation
        self.cfg_dataloader = cfg.Data.dataloader

        self.test_df = None
        self.train_df = None
        self.valid_df = None

    def get_test_df(self):
        return pd.read_csv(self.cfg_dataset.test_csv)

    def split_train_valid_df(self):
        df = pd.read_csv(self.cfg_dataset.train_csv)

        # Split
        skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
        for n, (train_index, val_index) in enumerate(skf.split(df, df[self.cfg_dataset.target_col])):
            df.loc[val_index, "fold"] = int(n)
        df["fold"] = df["fold"].astype(int)

        fold = int(self.cfg_dataset.fold)
        train_df = df[df["fold"] != fold].reset_index(drop=True)
        valid_df = df[df["fold"] == fold].reset_index(drop=True)
        return train_df, valid_df

    # 必ず呼び出される関数
    def setup(self, stage):
        self.test_df = self.get_test_df()
        train_df, valid_df = self.split_train_valid_df()
        self.train_df = train_df
        self.valid_df = valid_df

    def get_dataframe(self, phase):
        assert phase in {"train", "valid", "test"}
        if phase == "train":
            return self.train_df
        elif phase == "valid":
            return self.valid_df
        elif phase == "test":
            return self.test_df

    def get_ds(self, phase):
        assert phase in {"train", "valid", "test"}
        transform = get_transform(conf_augmentation=self.cfg_augmentation[phase])
        ds = TrainDataset(
            df=self.get_dataframe(phase=phase),
            transform=transform,
            test=(phase == "test"),
            cfg_dataset=self.cfg_dataset
        )
        return ds

    def get_loader(self, phase):
        dataset = self.get_ds(phase=phase)
        return DataLoader(
            dataset,
            batch_size=self.cfg_dataloader.batch_size,
            shuffle=True if phase == "train" else False,
            num_workers=self.cfg_dataloader.num_workers,
            drop_last=True if phase == "train" else False,
        )

    # Trainer.fit() 時に呼び出される
    def train_dataloader(self):
        return self.get_loader(phase="train")

    # Trainer.fit() 時に呼び出される
    def val_dataloader(self):
        return self.get_loader(phase="valid")

    def test_dataloader(self):
        return self.get_loader(phase="test")

5.4 Trainer

今回は Logger として PL に入っている wandbLogger を使います。
これを使うことで学習時に Weights & Biases – Developer tools for ML に各指標(Loss とか)が送られていくので、確認するのが楽になります。
ただし wandbLogger を使う場合、予めコンソール上 wandb 初回 login をする必要があるのでご注意ください。(一行叩くとできます)

train.py
import argparse
import os
import shutil
from pathlib import Path
from typing import Dict

from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.plugins import DDPPlugin
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.loggers.csv_logs import CSVLogger

from dataset import MyDataModule
from factory import read_yaml
from lightning_module import MyLightningModule
from util import make_output_path, src_backup


def make_parse():
    parser = argparse.ArgumentParser()
    arg = parser.add_argument
    arg("--debug", action="store_true", help="debug")
    arg("--config", default=None, type=str, help="config path")
    arg("--gpus", default="0", type=str, help="gpu numbers")
    arg("--fold", default="0", type=str, help="fold number")
    return parser


def train(cfg_name: str, cfg: Dict, output_path: Path) -> None:
    seed_everything(cfg.General.seed)
    debug = cfg.General.debug
    fold = cfg.Data.dataset.fold

    # logger は csv logger と wandb logger 両方使ってみる
    logger = CSVLogger(save_dir=str(output_path), name=f"fold_{fold}")
    wandb_logger = WandbLogger(name=f"{cfg_name}_{fold}", project=cfg.General.project, offline=debug)

    # 学習済重みを保存するために必要
    checkpoint_callback = ModelCheckpoint(
        dirpath=str(output_path), filename=f"{cfg_name}_fold_{fold}",
        save_weights_only=True,
        save_top_k=None,
        monitor=None
    )
    trainer = Trainer(
        max_epochs=5 if debug else cfg.General.epoch,
        gpus=cfg.General.gpus,
        distributed_backend=cfg.General.multi_gpu_mode,
        accumulate_grad_batches=cfg.General.grad_acc,
        precision=16 if cfg.General.fp16 else 32,
        amp_level=cfg.General.amp_level,
        amp_backend='native',
        deterministic=True,
        auto_select_gpus=False,
        benchmark=False,
        default_root_dir=os.getcwd(),
        limit_train_batches=0.02 if debug else 1.0,
        limit_val_batches=0.05 if debug else 1.0,
        callbacks=[checkpoint_callback],
        logger=[logger, wandb_logger],
        replace_sampler_ddp=not cfg.Data.dataloader.change_rate_sampler,

        # For fast https://pytorch-lightning.readthedocs.io/en/1.3.3/benchmarking/performance.html#
        plugins=DDPPlugin(find_unused_parameters=False)
    )

    # Lightning module and start training
    model = MyLightningModule(cfg)
    datamodule = MyDataModule(cfg)
    trainer.fit(model, datamodule=datamodule)


def main():
    args = make_parse().parse_args()

    # Read config
    cfg = read_yaml(fpath=args.config)
    cfg.General.debug = args.debug
    cfg.General.gpus = list(map(int, args.gpus.split(",")))
    cfg.Data.dataset.fold = args.fold

    # Make output path
    output_path = Path("../output/model") / Path(args.config).stem
    output_path = make_output_path(output_path, args.debug)

    # Config and Source code backup
    shutil.copy2(args.config, str(output_path / Path(args.config).name))
    src_backup(input_dir=Path("./"), output_dir=output_path)

    # Start train
    train(cfg_name=Path(args.config).stem, cfg=cfg, output_path=output_path)


if __name__ == "__main__":
    main()

5.5 学習実行と表示される結果

5.5.1 学習実行と表示される結果(コンソール上)

実行すると下記のような結果が勝手に出てきます。

$ cd src
$ python train.py --config config/015.yaml --fold 0 --debug
Global seed set to 1222
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
Using native 16bit precision.
Global seed set to 1222
initializing ddp: GLOBAL_RANK: 0, MEMBER: 1/1
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
wandb: W&B syncing is set to `offline` in this directory.  Run `wandb online` or set WANDB_MODE=online to enable cloud syncing.

  | Name | Type              | Params
-------------------------------------------
0 | loss | BCEWithLogitsLoss | 0
1 | net  | EfficientNet      | 10.7 M
-------------------------------------------
10.7 M    Trainable params
0         Non-trainable params
10.7 M    Total params
42.788    Total estimated model params size (MB)
Validation sanity check: 0it [00:00, ?it/s]
Validation sanity check:  50%|█████████████████████████████████████████                                    | 1/2 [00:01<00:01,  1.08s/it]
Global seed set to 1222
Epoch 1:  40%|████████████████████                             | 16/40 [00:07<00:11,  2.07it/s, loss=0.396, v_loss=0.462, v_score=0.466]

5.5.2 表示される結果(wandb上)

今回は wandbLogger を使っているので、Home – Weights & Biases を見に行くと下記のようなまとめが勝手に作られます(とても便利)。

6. その他tips

6.1 高速化手法について

Fast performance tips — PyTorch Lightning 1.3.3 documentation には PyTorch Lightning だけでなく PyTorch 全般の高速化について触れていたりして学びがあったので、お時間ある方は確認してみると良いかなと思いました。

6.2 おすすめ PL 練習場(MLコンペ)

  • 現在(2021年6月18日)開催中のコンペですと SETI Breakthrough Listen - E.T. Signal Search | Kaggle が PL の練習場として良いと思います
    • シンプルなタスク(分類)
    • CSV を提出するだけでよい
      • ローカルマシンの環境のみ気をつければよい
      • Kaggle Notebook に移植しなくてよい

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
What you can do with signing up
117