15
20

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

Pytorch Lightning 使い方入門 ~ 自作モデルを整形して tensorboard に出力するまで ~

Posted at

この記事でやること

「我流 DNN モデル作ったけどコード汚い」「事務作業(保存、ログ、DNN共通のコード)だるい」人向け

  • AI 開発爆速ライブラリ Pytorch Lightning で
  • きれいなコード管理&学習& tensorboard の可視化まで全部やる

Pytorch Lightning とは?

  • 深層学習モデルのお決まり作業自動化 (モデルの保存、損失関数のログetc)!
  • 可読性高い&コード共有も楽々に!

してくれるpythonライブラリ。
他を抑えてトップの github star 数&流行中のディープラーニングフレームワークである。

使い方

1. まずはinstall

console
$ pip install pytorch-lightning

2. 深層学習モデルを pytorch_lightning に従って書いていく

pytorch_lightning.LightningModule を継承して、

  • ネットワーク
  • forward(self, x)、training_step(self, batch, batch_idx)、configure_optimizers(self)の3メソッド

の二つを定義すれば早速使える。ただし、関数名と引数の組は変えられないので注意!
(e.g. batch_idx いらなくてもtraining_step(self, batch)みたいに定義するとバグったりする)

MyModel.py
import torch
from torch.nn import functional as F
from torch import nn
from pytorch_lightning.core.lightning import LightningModule

class LitMyModel(LightningModule):

  def __init__(self):
    super().__init__()

    # mnist images are (1, 28, 28) (channels, width, height)
    self.layer_1 = torch.nn.Linear(28 * 28, 128)
    self.layer_2 = torch.nn.Linear(128, 10)

  def forward(self, x):
    batch_size, channels, width, height = x.size()

    # (b, 1, 28, 28) -> (b, 1*28*28)
    x = x.view(batch_size, -1)
    x = self.layer_1(x)
    x = F.relu(x)
    x = self.layer_2(x)

    x = F.log_softmax(x, dim=1)
    return x

  def training_step(self, batch, batch_idx):
    x, y = batch
    logits = self(x)
    loss = F.nll_loss(logits, y)
    return loss

三つの関数はそれぞれ
「return ネットワークの出力」「1 loop 中の作業 & return 損失関数」「return オプティマイザ」
であればどんな処理でもOK

**長いけど見たい人向けにVAEの例 (Click)**
# MNIST を学習するFCの例
import pytorch_lightning as pl

class LitMyModel(pl.LightningModule):
    def __init__(self)
        # layers
        self.fc1 = nn.Linear(self.out_size, 400)
        self.fc4 = nn.Linear(400, self.out_size)

    def forward(self, x):
        mu, logvar = self.encode(x.view(-1, self.out_size))
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar

    def training_step(self, batch, batch_idx):
        recon_batch, mu, logvar = self.forward(batch)
        loss = self.loss_function(
            recon_batch, batch, mu, logvar, out_size=self.out_size)
        return loss

    def configure_optimizers(self):
        optimizer = optim.Adam(model.parameters(), lr=1e-3)
        return optimizer

もちろん、既にモデルがある人はコードを移動するだけでOK
あとはデータローダーとモデルをpl.Trainer()fit()に入れればもう学習スタート!!

実行時
dataloader = #Your own dataloader or datamodule

model = LitMyModel()
trainer = pl.Trainer()
trainer.fit(model, dataloader)

lightning 簡単、シュゴい。


3. 他の作業もこのクラスのメソッドに追加していく

上までで学習はできるようになったので、今度は** test ・validation・その他オプション**のメソッドをクラスに追加していく。

test

クラスメソッドにtest_step(self, batch, batch_idx)を追加する。だけ。実行は

test実行時
trainer.test()

validation

これもval_step()メソッド、val_dataloader()メソッドを追加すれば完成〜

dataloader

これもクラスメソッドにまとめて良いが、データセット&データローダーは別クラスのpytorch_lightning.LightningDataModuleを継承してMyDataModuleclassを定義するのが推奨

**長いけど見たい人向けにMNISTの例 (Click)**

class MyDataModule(LightningDataModule):
def init(self):
super().init()
self.train_dims = None
self.vocab_size = 0

def prepare_data(self):
    # called only on 1 GPU
    download_dataset()
    tokenize()
    build_vocab()

def setup(self):
    # called on every GPU
    vocab = load_vocab()
    self.vocab_size = len(vocab)

    self.train, self.val, self.test = load_datasets()
    self.train_dims = self.train.next_batch.size()

def train_dataloader(self):
    transforms = ...
    return DataLoader(self.train, batch_size=64)

def val_dataloader(self):
    transforms = ...
    return DataLoader(self.val, batch_size=64)

def test_dataloader(self):
    transforms = ...
    return DataLoader(self.test, batch_size=64)
これを学習&テスト時に```.fit()```に噛ませればdata_loaderを渡さなくても勝手に解釈してくれる。
実行時
datamodule = MyDataModule()

model = LitMyModel()
trainer = pl.Trainer()
trainer.fit(model, datamodule)

callback

「trainの初めだけやる処理」「エポック終わりにやる処理」のようなものも
https://pytorch-lightning.readthedocs.io/en/latest/introduction_guide.html#callbacks
あたりにいっぱい情報が載ってる。処理したいタイミング用の関数を定義してあげればOK

from pytorch_lightning.callbacks import Callback

class MyPrintingCallback(Callback):
    def on_init_start(self, trainer):
        print('Starting to init trainer!')

    def on_init_end(self, trainer):
        print('Trainer is init now')

    def on_train_end(self, trainer, pl_module):
        print('do something when training ends')

trainer = Trainer(callbacks=[MyPrintingCallback()])

みたく別クラスに定義すれば簡潔に書けますね〜

4. tensorboard と連携させる&記録設定の追加

さて、ここからメインの記録の保存関係です。tensorboardに数値(lossやaccuracyなど)、画像、音声などを表示するためには、

tensorflowの例
with tf.name_scope('summary'):
  tf.summary.scalar('loss', loss)
  merged = tf.summary.merge_all()
  writer = tf.summary.FileWriter('./logs', sess.graph)

見たいなコードを途中でぶっ刺したりして汚コードを作りがちでしたが、 pytorch_lightning は簡潔に書けて、

MyModel.py
def training_step(self, batch, batch_idx):
  # ...
  loss = ...
  self.logger.summary.scalar('loss', loss, step=self.global_step)

  # equivalent
  result = TrainResult()
  result.log('loss', loss)

  return result

のように記録する際のメソッド内でlogger.summaryに追加、もしくはreturn lossの部分をpytorch_lightning.LightningModule.TrainResult()クラスにいったん噛ませるだけで、自動的に保存ディレクトリ先に保存してくれます!

loggerはTrainer()クラスのコンストラクタに追加すればOKで、保存ディレクトリもここで決定します。

from pytorch_lightning import loggers as pl_loggers

tb_logger = pl_loggers.TensorBoardLogger('logs/')
trainer = Trainer(logger=tb_logger)

また、テキストや画像などのデータに関してもlogger.experimentオブジェクトの.add_hogehoge()を使って保存することができます!

MyModel.py
def training_step(...):
  ...
  # the logger you used (in this case tensorboard)
  tensorboard = self.logger.experiment
  tensorboard.add_histogram(...)
  tensorboard.add_figure(...)

Callbackのタイミングなんかもおすすめだよ、って公式も言ってますね。

シュゴい...(大事なことなので2回言いまs(ry

終わりに

使ってみた所感として Pytorch Lightning は ~~(ignite が処理を差し込みまくって可読性悪いのとかに比べると)~~ルールが分かりやすいし、クラス設計もドキュメント整備もちゃんとしていたので、最初に使ってみるのにおすすめなディープラーニングフレームワークであるなと感じました〜

15
20
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
15
20

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?