はじめに
これは、PyTorchは少し知っているけど、PyTorch Lightningは全く知らない人が、PyTorch Lightningを勉強してみましたという話です。
公式のサンプルを元にPyTorchの初歩的な部分を解説しています。
この記事はHEROZ社内での勉強会の内容をリライトしたものです。
Pytorch Lightningについて
Pytorch Lightningについて簡単に概要を触れておくと、Pytorch LightningはPytorchのラッパーで、
学習ループなどの定型文(boilerplate)をラッピングし学習周りのコードを簡潔にわかりやすく書けるようにするライブラリです。
類似のフレームワークとしては、以下のものがあります。
- Ignite
- Catalyst
最近は、PyTorch Lightningがだいぶ優勢で、利用例や解説記事も多く見つかります。
Pytorch Lightningの基本
以下の3つの部分がPytorch Lightningの基本です。
- データに関する部分(LightningDataModule)
- 学習に関する部分(LightningModule)
- 実際の学習を行う部分(Trainer)
LightningDataModule
データ周りの処理を集約したクラスです。
以下の公式のサンプルのように、train/validation/testのそれぞれで使用するPyTorchのDataLoaderを返すメソッドを作成します。
class MNISTDataModule(pl.LightningDataModule):
def __init__(self, data_dir: str = "path/to/dir", batch_size: int = 32):
super().__init__()
self.data_dir = data_dir
self.batch_size = batch_size
def setup(self, stage: str):
self.mnist_test = MNIST(self.data_dir, train=False)
self.mnist_predict = MNIST(self.data_dir, train=False)
mnist_full = MNIST(self.data_dir, train=True)
self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])
def train_dataloader(self):
return DataLoader(self.mnist_train, batch_size=self.batch_size)
def val_dataloader(self):
return DataLoader(self.mnist_val, batch_size=self.batch_size)
def test_dataloader(self):
return DataLoader(self.mnist_test, batch_size=self.batch_size)
def predict_dataloader(self):
return DataLoader(self.mnist_predict, batch_size=self.batch_size)
def teardown(self, stage: str):
# Used to clean-up when the run is finished
...
公式の例だとMNISTを使っているのでシンプルに見えますが、実際には普通にPyTorchのDatasetを書く必要があるので、普通にPyTorchで書くのと大差はありません。
LigntningModule
学習で行うことを集約したクラスです。
以下の公式のサンプルのように、モデルの定義、forwardの定義、train/val/testの各stepで実行すること、Optimizerの定義を記載します。
import pytorch_lightning as pl
import torch.nn as nn
import torch.nn.functional as F
class LitModel(pl.LightningModule):
def __init__(self):
super().__init__()
self.l1 = nn.Linear(28 * 28, 10)
def forward(self, x):
return torch.relu(self.l1(x.view(x.size(0), -1)))
def training_step(self, batch, batch_idx):
x, y = batch
y_hat = self(x)
loss = F.cross_entropy(y_hat, y)
return loss
def configure_optimizers(self):
return torch.optim.Adam(self.parameters(), lr=0.02)
こちらも、普通にPyTorchで書くのと大差はなく、コードが集約されることが最大の利点です。
Trainer
最後に、作成したLightningDataModuleとLightningModuleをTrainerに渡し、fitを呼べば学習が開始します。
https://pytorch-lightning.readthedocs.io/en/latest/starter/introduction.html#train-the-model
PyTorchで学習するコードを書くには、このあたりいろいろと書かなくてはいけないですが、Trainer作ってfitを呼べばとりあえず学習は始まります。
また、Trainerには、以下のよいところがあります。
- ModelCheckPointやEarlyStoppingなどの基本的な機能は実装されており、callbacksに指定するのみ
- TensorBoardやMLFlow等の様々な実験管理ライブラリに対応しており、loggerに指定するのみ
- デバイス依存のコード(CPU/GPU/TPU)も楽にかける
このあたりはいろいろと用意されていて、モデルの開発に注力できるというところがPyTorch Lightningの利点なのかなというように思います。
コンピュータ将棋界での使用
コンピュータ将棋に関連したところでもちらほらPyTorch Lightningを使用しているケースがありますので、ちょっと紹介します。
チェスソフトのStockfishの学習ルーチンであるnnue-pytorchにPytorch Lightningが使用されています。
https://github.com/glinscott/nnue-pytorch
こちらは、nodchipさんの手によって将棋の学習もできるようになっています。
https://github.com/nodchip/nnue-pytorch
また、プロ将棋棋士でもある谷合さんがヨビノリたくみさんとのYoutube対戦のために作成された、Bertの将棋ソフトでも学習にPytorch Lightningが使用されています。
https://github.com/nyoki-mtl/bert-mcts-youtube
まとめ
PyTorch Lightningを使用する利点としては、
- 利用例がたくさんあり、調べればたくさんの情報がある
- PyTorchだけで書くより楽にコードが書ける
- コードの書き方がある程度統一される
また、欠点は、
- PyTorch Lightningの作法を覚えないといけない
- PyTorch Lightningで用意されていないものを使おうとすると結構大変
このあたりは、フレームワークを使う上では避けられない点で、この点が気にならなければ良いフレームワークかなという気がしました。
参考文献
最後に、この記事を書くにあたり参考にしたWeb記事のリンクを張っておきます。どれも、Pytorch Lightningを知るのにとても勉強になりました。
- PyTorch Lightningについて紹介されてる資料
- PyTorch Lightningで書かれたKaggle Notebook
おまけ
HEROZでは、2022年7月に将棋解析サービス、棋神アナリティクスをリリースしました。
https://kishin-analytics.heroz.jp/
最新・最強の将棋ソフトでお手軽に将棋の解析ができるサービスです。
さらに、2022年12月に月額1100円から使用できるプランもリリースしました。
アカウント作成から30日以内の間、無料で1時間までお試しで解析できるので、ぜひ一度試していただければうれしいです。