この記事でやること
「我流 DNN モデル作ったけどコード汚い」「事務作業(保存、ログ、DNN共通のコード)だるい」人向け
- AI 開発爆速ライブラリ Pytorch Lightning で
- きれいなコード管理&学習& tensorboard の可視化まで全部やる
Pytorch Lightning とは?
- 深層学習モデルのお決まり作業自動化 (モデルの保存、損失関数のログetc)!
- 可読性高い&コード共有も楽々に!
してくれるpythonライブラリ。
他を抑えてトップの github star 数&流行中のディープラーニングフレームワークである。
使い方
1. まずはinstall
$ pip install pytorch-lightning
2. 深層学習モデルを pytorch_lightning に従って書いていく
pytorch_lightning.LightningModule
を継承して、
- ネットワーク
- forward(self, x)、training_step(self, batch, batch_idx)、configure_optimizers(self)の3メソッド
の二つを定義すれば早速使える。ただし、関数名と引数の組は変えられないので注意!
(e.g. batch_idx いらなくてもtraining_step(self, batch)
みたいに定義するとバグったりする)
import torch
from torch.nn import functional as F
from torch import nn
from pytorch_lightning.core.lightning import LightningModule
class LitMyModel(LightningModule):
def __init__(self):
super().__init__()
# mnist images are (1, 28, 28) (channels, width, height)
self.layer_1 = torch.nn.Linear(28 * 28, 128)
self.layer_2 = torch.nn.Linear(128, 10)
def forward(self, x):
batch_size, channels, width, height = x.size()
# (b, 1, 28, 28) -> (b, 1*28*28)
x = x.view(batch_size, -1)
x = self.layer_1(x)
x = F.relu(x)
x = self.layer_2(x)
x = F.log_softmax(x, dim=1)
return x
def training_step(self, batch, batch_idx):
x, y = batch
logits = self(x)
loss = F.nll_loss(logits, y)
return loss
三つの関数はそれぞれ
「return ネットワークの出力」「1 loop 中の作業 & return 損失関数」「return オプティマイザ」
であればどんな処理でもOK
**長いけど見たい人向けにVAEの例 (Click)**
# MNIST を学習するFCの例
import pytorch_lightning as pl
class LitMyModel(pl.LightningModule):
def __init__(self):
# layers
self.fc1 = nn.Linear(self.out_size, 400)
self.fc4 = nn.Linear(400, self.out_size)
def forward(self, x):
mu, logvar = self.encode(x.view(-1, self.out_size))
z = self.reparameterize(mu, logvar)
return self.decode(z), mu, logvar
def training_step(self, batch, batch_idx):
recon_batch, mu, logvar = self.forward(batch)
loss = self.loss_function(
recon_batch, batch, mu, logvar, out_size=self.out_size)
return loss
def configure_optimizers(self):
optimizer = optim.Adam(model.parameters(), lr=1e-3)
return optimizer
もちろん、既にモデルがある人はコードを移動するだけでOK
あとはデータローダーとモデルをpl.Trainer()
のfit()
に入れればもう学習スタート!!
dataloader = #Your own dataloader or datamodule
model = LitMyModel()
trainer = pl.Trainer()
trainer.fit(model, dataloader)
lightning 簡単、シュゴい。
3. 他の作業もこのクラスのメソッドに追加していく
上までで学習はできるようになったので、今度は** test ・validation・その他オプション**のメソッドをクラスに追加していく。
test
クラスメソッドにtest_step(self, batch, batch_idx)
を追加する。だけ。実行は
trainer.test()
validation
これもval_step()
メソッド、val_dataloader()
メソッドを追加すれば完成〜
dataloader
これもクラスメソッドにまとめて良いが、データセット&データローダーは別クラスのpytorch_lightning.LightningDataModule
を継承してMyDataModule
classを定義するのが推奨。
**長いけど見たい人向けにMNISTの例 (Click)**
class MyDataModule(LightningDataModule):
def init(self):
super().init()
self.train_dims = None
self.vocab_size = 0
def prepare_data(self):
# called only on 1 GPU
download_dataset()
tokenize()
build_vocab()
def setup(self):
# called on every GPU
vocab = load_vocab()
self.vocab_size = len(vocab)
self.train, self.val, self.test = load_datasets()
self.train_dims = self.train.next_batch.size()
def train_dataloader(self):
transforms = ...
return DataLoader(self.train, batch_size=64)
def val_dataloader(self):
transforms = ...
return DataLoader(self.val, batch_size=64)
def test_dataloader(self):
transforms = ...
return DataLoader(self.test, batch_size=64)
datamodule = MyDataModule()
model = LitMyModel()
trainer = pl.Trainer()
trainer.fit(model, datamodule)
callback
「trainの初めだけやる処理」「エポック終わりにやる処理」のようなものも
https://pytorch-lightning.readthedocs.io/en/latest/introduction_guide.html#callbacks
あたりにいっぱい情報が載ってる。処理したいタイミング用の関数を定義してあげればOK
from pytorch_lightning.callbacks import Callback
class MyPrintingCallback(Callback):
def on_init_start(self, trainer):
print('Starting to init trainer!')
def on_init_end(self, trainer):
print('Trainer is init now')
def on_train_end(self, trainer, pl_module):
print('do something when training ends')
trainer = Trainer(callbacks=[MyPrintingCallback()])
みたく別クラスに定義すれば簡潔に書けますね〜
4. tensorboard と連携させる&記録設定の追加
さて、ここからメインの記録の保存関係です。tensorboardに数値(lossやaccuracyなど)、画像、音声などを表示するためには、
with tf.name_scope('summary'):
tf.summary.scalar('loss', loss)
merged = tf.summary.merge_all()
writer = tf.summary.FileWriter('./logs', sess.graph)
見たいなコードを途中でぶっ刺したりして汚コードを作りがちでしたが、 pytorch_lightning は簡潔に書けて、
def training_step(self, batch, batch_idx):
# ...
loss = ...
self.logger.summary.scalar('loss', loss, step=self.global_step)
# equivalent
result = TrainResult()
result.log('loss', loss)
return result
のように記録する際のメソッド内でlogger.summary
に追加、もしくはreturn loss
の部分をpytorch_lightning.LightningModule.TrainResult()
クラスにいったん噛ませるだけで、自動的に保存ディレクトリ先に保存してくれます!
loggerはTrainer()
クラスのコンストラクタに追加すればOKで、保存ディレクトリもここで決定します。
from pytorch_lightning import loggers as pl_loggers
tb_logger = pl_loggers.TensorBoardLogger('logs/')
trainer = Trainer(logger=tb_logger)
また、テキストや画像などのデータに関してもlogger.experiment
オブジェクトの.add_hogehoge()
を使って保存することができます!
def training_step(...):
...
# the logger you used (in this case tensorboard)
tensorboard = self.logger.experiment
tensorboard.add_histogram(...)
tensorboard.add_figure(...)
Callbackのタイミングなんかもおすすめだよ、って公式も言ってますね。
シュゴい...(大事なことなので2回言いまs(ry
終わりに
使ってみた所感として Pytorch Lightning は ~~(ignite が処理を差し込みまくって可読性悪いのとかに比べると)~~ルールが分かりやすいし、クラス設計もドキュメント整備もちゃんとしていたので、最初に使ってみるのにおすすめなディープラーニングフレームワークであるなと感じました〜