16
7

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

PyTorch Lightning v1.5 新機能の紹介

Last updated at Posted at 2021-12-14

はじめに

当記事では以下の PyTorch Lightning v1.5 のリリースノートとブログ投稿を基に、PyTorch Lightning v1.5 の新機能を簡略に紹介します。

対象読者

  • PyTorch Lightning を使ったことがある人
  • PyTorch を使ったことがある人 (端折って説明しているのでわからない部分があるかもしれません)

機能紹介一覧
以下、インパクトのある (と個人的に思う) 順番にしています。

  1. Batch-Level Fault-Tolerant Training
  2. BFloat16
  3. Trainer ハードウェア関係引数
  4. 無限 Epochs
  5. Lightning Lite
  6. LightningCLI V2
  7. Rich Progress Bar
  8. Loop Customization
  9. その他の機能

機能紹介

1. Batch-Level Fault-Tolerant Training

学習途中に予期しないエラーで中断しても、中断したバッチから学習を再開でき、環境変数を設定するだけで利用可能です。

PL_FAULT_TOLERANT_TRAINING=1 python script.py

参考

2. BFloat16

PyTorch 1.10 以降でサポートされる torch.bfloat16 (Brain Floating Point) を利用することで torch.float16 の Automatic Mixed Precision よりも安定した学習が可能になります。

from pytorch_lightning import Trainer

Trainer(precision="bf16")

参考

3. Trainer ハードウェア関係引数

accelerator="auto"devices="auto" で自動的に検出した利用可能なデバイスを指定できます。

  • accelerator はハードウェアの種類を指す: "cpu" "gpu" "tpu" "ipu" "auto"
  • strategy はハードウェアの利用方法を指す: "dp" "ddp" "ddp_spawn" "deepspeed"
  • devices はデバイスの数を指す: 整数もしくは "auto"
from pytorch_lightning import Trainer
from pytorch_lightning.plugins import DDPPlugin

# DDP with 4 GPUs
Trainer(accelerator="gpu", devices=4, strategy="ddp")
Trainer(accelerator="gpu", devices=4, strategy=DDPPlugin(...))

# DDP Spawn with detected accelerator (gpu, tpu, ...)
Trainer(accelerator="auto", devices=4, strategy="ddp_spawn")

# DDP with available GPUs
Trainer(accelerator="gpu", devices="auto", strategy="ddp")

参考

4. 無限 Epochs

-1 で無限 epochs/steps を設定できます。

from pytorch_lightning import Trainer

# infinite epochs
trainer = Trainer(max_epochs=-1)

# an endless epoch
trainer = Trainer(max_steps=-1)

5. Lightning Lite

LightningModule のサブクラスを実装しなくても、既存の PyTorch コードに最小限の変更を加えることで、各種アクセラレータを利用し訓練できます。下記のように定義した LightningLite のサブクラスのコンストラクタに好きな設定 (GPU, TPU, DDP, DeepSpeed, amp など) を引数として与えインスタンス化し、run() メソッドを実行します。

+from pytorch_lightning.lite import LightningLite

+class Lite(LightningLite):
-    def run(num_epochs):
+    def run(self, num_epochs, ...):

-        device = "cuda" if torch.cuda.is_available() else "cpu"

-        model = MyModel(...).to(device)  # .to(...) は不要
+        model = MyModel(...)
         optimizer = torch.optim.SGD(model.parameters(), ...)
+        model, optimizer = self.setup(model, optimizer)
         dataloader = DataLoader(MyDataset(...), ...)
+        dataloader = self.setup_dataloaders(dataloader)

         model.train()
         for epoch in range(num_epochs):
             for batch in dataloader:
-                batch = batch.to(device)  # .to(...) は不要
                 optimizer.zero_grad()
                 loss = model(batch)
-                loss.backward()
+                self.backward(loss)
                 optimizer.step()
Lite(devices=8, accelerator="gpu", precision="bf16", strategy="ddp").run(10, ...)
Lite(devices=8, accelerator="gpu", precision=16, strategy="deepspeed").run(10, ...)
Lite(devices="auto", accelerator="auto", precision=16).run(10, ...)

参考

6. LightningCLI V2

スクリプト内で Trainer.{fit,validate,test,predict,tune} を記述せずに、コマンドライン引数から指定して呼び出すことができます。

$ pip install "jsonargparse[signatures]"
$ python trainer.py fit
# trainer.py
from pytorch_lightning import LightningModule, LightningDataModule
from pytorch_lightning.utilities import LightningCLI

class MyModel(LightningModule):
    ...

class MyData(LightningDataModule):
    ...

cli = LightningCLI(MyModel, MyData)

その他にも、様々なパラメータを指定することができます。

  • LightningModule - 例: --model=MyModel --model.feat_dim=64
  • LightningDataModule - 例: --data=MyDataModule --data.batch_size=128
  • Optimizer - 例: --optimizer=Adam --optimizer.lr=0.01
  • lr_scheduler - 例: --lr_scheduler=ExponentialLR --lr_scheduler.gamma=0.1
  • Callback - 例: --trainer.callbacks=EarlyStopping --trainer.callbacks.patience=5
  • 自作クラスの登録も可能です。詳しくはドキュメントを参照ください。LightningCLI and Config Files - PyTorch Lightning documentation
$ python trainer.py fit \
   --trainer.callbacks=EarlyStopping \
   --trainer.callbacks.patience=5 \
   --trainer.callbacks=LearningRateMonitor \
   --trainer.callbacks.logging_interval=epoch \
   --optimizer=Adam \
   --optimizer.lr=0.01 \
   --lr_scheduler=ExponentialLR \
   --lr_scheduler.gamma=0.1

参考

7. Rich Progress Bar

プログレスバーの見た目を華やかにできます☆

rich-progress-bar.gif
画像は Super-Charged Progress Bars with Rich & Lightning - PyTorch Lightning Developer Blog より

$ pip install rich
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import RichProgressBar
from pytorch_lightning.callbacks.progress.rich_progress import RichProgressBarTheme

# デフォルトは紫色
progress_bar = RichProgressBar()

# 画像例
# 指定可能な色は rich のドキュメント参考
# https://rich.readthedocs.io/en/stable/appendix/colors.html
progress_bar = RichProgressBar(
    theme=RichProgressBarTheme(
        description="green_yellow",
        progress_bar="green1",
        progress_bar_finished="green1",
        batch_progress="green_yellow",
        time="grey82",
        processing_speed="grey82",
        metrics="grey82",
    )
)

trainer = Trainer(callbacks=[progress_bar])

参考

8. Loop Customization

trainer.fit() 等の内部のフローをユーザは変更する術はありませんでしたが、v1.5 以降では Loop クラスのサブクラスをユーザが定義し、デフォルトの Loop と入れ替えることで部分的にカスタマイズできます。公式ブログでは次の 3 つのカスタムループが例として紹介されています。

参考

9. その他の新機能

  • CheckpointIO Plugin
  • Gradient clipping customization with LightningModule.configure_gradient_clipping
  • Callback.on_exception hook
  • torch.use_deterministic_algorithms support by pl.seed_everything
  • [experimental] Meta module for large sharded models.
  • [experimental] Inter-batch parallelism
  • [experimental] Training step with DataLoader iterator

おわりに

自己紹介

日本の田舎で学生をしています、Akihiro Nitta です。

PyTorch Lightning Slack

PyTorch Lightning に関連する質問は PyTorch Lightning Slack の #questions チャンネル、もしくは GitHub Discussions で受け付けています。

Grid.ai

PyTorch Lightning を開発する Grid.ai は、コードを一行も修正することなく、手元のローカル環境からクラウドへと学習をスケールするサービスを提供しています。本サービスは、TensorFlow、Keras、PyTorch などの機械学習ライブラリをすべてサポートしていますが、任意のライブラリを使用することができます。Early Stopping、Integrated Logging、Automatic Checkpointing、CLI などの PyTorch Lightning の機能を活用することで、モデル学習後の従来の MLOps を見えなくすることができます。
Google もしくは GitHub アカウントでサインアップすることができ、アカウントが認証されれば 25 米ドル分のクレジットが無料で付与されます。Grid.ai から是非ご登録ください。
https://miro.medium.com/max/945/0*BSvoUp30zZ8dkwpv.png
画像は Announcing Lightning v1.5 - PyTorch Medium より

16
7
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
16
7

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?