はじめに
当記事では以下の PyTorch Lightning v1.5 のリリースノートとブログ投稿を基に、PyTorch Lightning v1.5 の新機能を簡略に紹介します。
対象読者
- PyTorch Lightning を使ったことがある人
- PyTorch を使ったことがある人 (端折って説明しているのでわからない部分があるかもしれません)
機能紹介一覧
以下、インパクトのある (と個人的に思う) 順番にしています。
- Batch-Level Fault-Tolerant Training
- BFloat16
- Trainer ハードウェア関係引数
- 無限 Epochs
- Lightning Lite
- LightningCLI V2
- Rich Progress Bar
- Loop Customization
- その他の機能
機能紹介
1. Batch-Level Fault-Tolerant Training
学習途中に予期しないエラーで中断しても、中断したバッチから学習を再開でき、環境変数を設定するだけで利用可能です。
PL_FAULT_TOLERANT_TRAINING=1 python script.py
参考
2. BFloat16
PyTorch 1.10 以降でサポートされる torch.bfloat16
(Brain Floating Point) を利用することで torch.float16
の Automatic Mixed Precision よりも安定した学習が可能になります。
from pytorch_lightning import Trainer
Trainer(precision="bf16")
参考
- Introducing Faster Training with Lightning and Brain Float16 - PyTorch Lightning Developer Blog
- BFloat16: The secret to high performance on Cloud TPUs - Google Cloud Blog
3. Trainer ハードウェア関係引数
accelerator="auto"
や devices="auto"
で自動的に検出した利用可能なデバイスを指定できます。
-
accelerator
はハードウェアの種類を指す:"cpu"
"gpu"
"tpu"
"ipu"
"auto"
-
strategy
はハードウェアの利用方法を指す:"dp"
"ddp"
"ddp_spawn"
"deepspeed"
等 -
devices
はデバイスの数を指す: 整数もしくは"auto"
from pytorch_lightning import Trainer
from pytorch_lightning.plugins import DDPPlugin
# DDP with 4 GPUs
Trainer(accelerator="gpu", devices=4, strategy="ddp")
Trainer(accelerator="gpu", devices=4, strategy=DDPPlugin(...))
# DDP Spawn with detected accelerator (gpu, tpu, ...)
Trainer(accelerator="auto", devices=4, strategy="ddp_spawn")
# DDP with available GPUs
Trainer(accelerator="gpu", devices="auto", strategy="ddp")
参考
- Announcing the new Lightning Trainer Strategy API - PyTorch Lightning Developer Blog
- Trainer Flags - PyTorch Lightning documentation
4. 無限 Epochs
-1
で無限 epochs/steps を設定できます。
from pytorch_lightning import Trainer
# infinite epochs
trainer = Trainer(max_epochs=-1)
# an endless epoch
trainer = Trainer(max_steps=-1)
5. Lightning Lite
LightningModule
のサブクラスを実装しなくても、既存の PyTorch コードに最小限の変更を加えることで、各種アクセラレータを利用し訓練できます。下記のように定義した LightningLite
のサブクラスのコンストラクタに好きな設定 (GPU, TPU, DDP, DeepSpeed, amp など) を引数として与えインスタンス化し、run()
メソッドを実行します。
+from pytorch_lightning.lite import LightningLite
+class Lite(LightningLite):
- def run(num_epochs):
+ def run(self, num_epochs, ...):
- device = "cuda" if torch.cuda.is_available() else "cpu"
- model = MyModel(...).to(device) # .to(...) は不要
+ model = MyModel(...)
optimizer = torch.optim.SGD(model.parameters(), ...)
+ model, optimizer = self.setup(model, optimizer)
dataloader = DataLoader(MyDataset(...), ...)
+ dataloader = self.setup_dataloaders(dataloader)
model.train()
for epoch in range(num_epochs):
for batch in dataloader:
- batch = batch.to(device) # .to(...) は不要
optimizer.zero_grad()
loss = model(batch)
- loss.backward()
+ self.backward(loss)
optimizer.step()
Lite(devices=8, accelerator="gpu", precision="bf16", strategy="ddp").run(10, ...)
Lite(devices=8, accelerator="gpu", precision=16, strategy="deepspeed").run(10, ...)
Lite(devices="auto", accelerator="auto", precision=16).run(10, ...)
参考
- Scale your PyTorch code with LightningLite - PyTorch Lightning Developer Blog
- LightningLite - Stepping Stone to Lightning - PyTorch Lightning documentation
6. LightningCLI V2
スクリプト内で Trainer.{fit,validate,test,predict,tune}
を記述せずに、コマンドライン引数から指定して呼び出すことができます。
$ pip install "jsonargparse[signatures]"
$ python trainer.py fit
# trainer.py
from pytorch_lightning import LightningModule, LightningDataModule
from pytorch_lightning.utilities import LightningCLI
class MyModel(LightningModule):
...
class MyData(LightningDataModule):
...
cli = LightningCLI(MyModel, MyData)
その他にも、様々なパラメータを指定することができます。
-
LightningModule
- 例:--model=MyModel --model.feat_dim=64
-
LightningDataModule
- 例:--data=MyDataModule --data.batch_size=128
-
Optimizer
- 例:--optimizer=Adam --optimizer.lr=0.01
-
lr_scheduler
- 例:--lr_scheduler=ExponentialLR --lr_scheduler.gamma=0.1
-
Callback
- 例:--trainer.callbacks=EarlyStopping --trainer.callbacks.patience=5
- 自作クラスの登録も可能です。詳しくはドキュメントを参照ください。LightningCLI and Config Files - PyTorch Lightning documentation
$ python trainer.py fit \
--trainer.callbacks=EarlyStopping \
--trainer.callbacks.patience=5 \
--trainer.callbacks=LearningRateMonitor \
--trainer.callbacks.logging_interval=epoch \
--optimizer=Adam \
--optimizer.lr=0.01 \
--lr_scheduler=ExponentialLR \
--lr_scheduler.gamma=0.1
参考
- Introducing LightningCLI V2 - PyTorch Lightning Developer Blog
- LightningCLI and Config Files - PyTorch Lightning documentation
7. Rich Progress Bar
プログレスバーの見た目を華やかにできます☆
画像は Super-Charged Progress Bars with Rich & Lightning - PyTorch Lightning Developer Blog より
$ pip install rich
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import RichProgressBar
from pytorch_lightning.callbacks.progress.rich_progress import RichProgressBarTheme
# デフォルトは紫色
progress_bar = RichProgressBar()
# 画像例
# 指定可能な色は rich のドキュメント参考
# https://rich.readthedocs.io/en/stable/appendix/colors.html
progress_bar = RichProgressBar(
theme=RichProgressBarTheme(
description="green_yellow",
progress_bar="green1",
progress_bar_finished="green1",
batch_progress="green_yellow",
time="grey82",
processing_speed="grey82",
metrics="grey82",
)
)
trainer = Trainer(callbacks=[progress_bar])
参考
- Super-Charged Progress Bars with Rich & Lightning - PyTorch Lightning Developer Blog
- RichProgressBar - PyTorch Lightning documentation
- 色見本: Standard Colors - Rich documentation
8. Loop Customization
trainer.fit()
等の内部のフローをユーザは変更する術はありませんでしたが、v1.5 以降では Loop
クラスのサブクラスをユーザが定義し、デフォルトの Loop
と入れ替えることで部分的にカスタマイズできます。公式ブログでは次の 3 つのカスタムループが例として紹介されています。
参考
- Train anything with Lightning custom Loops - PyTorch Lightning Developer Blog
- Loops - PyTorch Lightning documentation
9. その他の新機能
- CheckpointIO Plugin
- Gradient clipping customization with
LightningModule.configure_gradient_clipping
-
Callback.on_exception
hook -
torch.use_deterministic_algorithms
support bypl.seed_everything
- [experimental] Meta module for large sharded models.
- [experimental] Inter-batch parallelism
- [experimental] Training step with
DataLoader
iterator
おわりに
自己紹介
日本の田舎で学生をしています、Akihiro Nitta です。
PyTorch Lightning Slack
PyTorch Lightning に関連する質問は PyTorch Lightning Slack の #questions
チャンネル、もしくは GitHub Discussions で受け付けています。
- PyTorch Lightning Slack 登録リンク: https://join.slack.com/t/pytorch-lightning/shared_invite/zt-pw5v393p-qRaDgEk24~EjiZNBpSQFgQ
- GitHub Discussions: https://github.com/PyTorchLightning/pytorch-lightning/discussions
Grid.ai
PyTorch Lightning を開発する Grid.ai は、コードを一行も修正することなく、手元のローカル環境からクラウドへと学習をスケールするサービスを提供しています。本サービスは、TensorFlow、Keras、PyTorch などの機械学習ライブラリをすべてサポートしていますが、任意のライブラリを使用することができます。Early Stopping、Integrated Logging、Automatic Checkpointing、CLI などの PyTorch Lightning の機能を活用することで、モデル学習後の従来の MLOps を見えなくすることができます。
Google もしくは GitHub アカウントでサインアップすることができ、アカウントが認証されれば 25 米ドル分のクレジットが無料で付与されます。Grid.ai から是非ご登録ください。
画像は Announcing Lightning v1.5 - PyTorch Medium より