LoginSignup
2
2

More than 1 year has passed since last update.

pytorch-lightningのModelCheckpointの挙動を確認する

Posted at

目的

pytorch-lightningでvalidationのlossが小さいモデルを保存したいとき、ModelCheckpointを使います。ドキュメントにはmonitorにlossの名前を渡すとありますが、validation_stepでの値を渡しても、途中のあるバッチでlossが最小になったときに記録されるのか、全体の値が最小になったときに記録されるかよくわかりませんでした。

今回は実際に分かりやすいデータを流してみて挙動を確認したいと思います。

モデル

今回使用するモデルは以下です。

from copy import deepcopy

import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
        
class ToyModel(pl.LightningModule):
    def __init__(self, epoch_div=True):
        super().__init__()
        self.linear = torch.nn.Linear(1,1)
        self.val_step_loss = []
        self.val_epoch_loss = []
        self.states = []

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.linear.parameters(), lr=0.1)
        return optimizer
    
    def training_step(self, batch, batch_idx):
        true, pred = batch
        loss = F.l1_loss(true, pred)
        dummy_loss = F.l1_loss(self.linear(true), pred)
        self.log(f"train_loss_step", loss, prog_bar=True)
        return dummy_loss
    
    def validation_step(self, batch, batch_idx):
        """
        [1,2,...,10] -> [11,12,...,20] -> [1000,0,0,...,0]
        """
        true, pred = batch
        if self.current_epoch < 2:
            loss = F.l1_loss(true, pred) + self.current_epoch * 10
        else:
            loss = torch.tensor(0.0) if batch_idx != 0 else torch.tensor(1000.0)
        self.log(f"val_loss_step", loss, prog_bar=True)
        self.val_step_loss.append(loss)
        return loss
    
    def validation_epoch_end(self, outs) -> None:
        """
        1 -> 0.5 -> 0.33
        """
        loss = torch.tensor(1 / (self.current_epoch + 1))
        self.log(f"val_loss_epoch", loss, prog_bar=True)
        self.val_epoch_loss.append(loss)
        self.states.append(deepcopy(self.state_dict()))

        
class ToyDataset(Dataset):
    def __init__(self):
        datas = []
        for i in range(10):
            datas.append((torch.tensor([float(i+1)]), torch.zeros(1)))
        self.datas = datas
    
    def __len__(self):
        return len(self.datas)
    
    def __getitem__(self, i):
        return self.datas[i]

データセットは(torch.tensor(batch_idx), torch.tensor(0.0))を返すようにしています。これでlossが決定的に決まり分かりやすくなります。

モデルではtraining_stepでは雑にモデルを更新しています。今回はこの部分は関係ありません。

validation_stepではlossがepoch-0では1,2,...,10、epoch-1では11,12,...,20、epoch-2では1000,0,0,...,0と決定的にしました。val_loss_stepの瞬間最小値はepoch-2、全体の最小値はepoch-0になります。

validation_epoch_endではlossがepoch-0では1.0、epoch-1では0.5、epoch-2では0.33と決定的にしました。
val_loss_epochの最小値はepoch-2になります。

実行

model = ToyModel()
dataset = ToyDataset()
loader = DataLoader(dataset, batch_size=1)
callbacks = [
    pl.callbacks.ModelCheckpoint(
        monitor="val_loss_step",
        filename="ckpt_step_{epoch}-{step}-{val_loss_step:.2f}-{val_loss_epoch:.2f}"
    ),
    pl.callbacks.ModelCheckpoint(
        monitor="val_loss_epoch",
        filename="ckpt_epoch_{epoch}-{step}-{val_loss_step:.2f}-{val_loss_epoch:.2f}"
    ),
    
]
trainer = pl.Trainer(callbacks=callbacks, max_epochs=3, num_sanity_val_steps=0, enable_progress_bar=True)
trainer.fit(model, train_dataloaders=loader, val_dataloaders=loader)

学習を行います。monitorが異なる2種のModelCheckpointを作成しました。それぞれval_loss_stepval_loss_epochをモニタリングします。名前にそれぞれのlossの値を加えるようにしました。
保存されたチェックポイントは以下になりました。

>>> ls lightning_logs/version_0/checkpoints/
ckpt_epoch_epoch=2-step=30-val_loss_step=100.00-val_loss_epoch=0.33.ckpt
ckpt_step_epoch=0-step=10-val_loss_step=5.50-val_loss_epoch=1.00.ckpt

val_loss_stepをモニタリングした場合はepoch-0が保存されています。これはval_loss_stepのepoch全体の値を平均化した値をモニタリングしているようです。実際val_loss_step=5.50となっており、epoch-0のval_loss_step平均は

>>> torch.stack(model.val_step_loss[:10]).mean()
tensor(5.5000)

となっています。

一方で、val_loss_epochをモニタリングした場合はepoch-2が保存されています。このようにvalidation_epoch_endで設定したlossも使用可能のようです。特別な統計処理を用いたい場合はvalidation_epoch_endで処理を定義し、それをモニタリングすればよさそうです。

チェックポイントを読み込んでみると、モデルの重みもしっかり一致していました。

>>> model.states[0]
OrderedDict([('linear.weight', tensor([[0.0303]])),
             ('linear.bias', tensor([0.2593]))])
>>> ToyModel.load_from_checkpoint("lightning_logs/version_1/checkpoints/ckpt_epoch_epoch=2-step=30-val_loss_step=100.00-val_loss_epoch=0.33.ckpt").state_dict()
OrderedDict([('linear.weight', tensor([[0.0303]])),
             ('linear.bias', tensor([0.2593]))])

まとめ

ModelCheckpointmonitorvalidation_step内で定義されたlossを渡すと、平均化した値をモニタリングするようです。これに限らずpytorch-lightningにおいてvalidation_step内でlog()を呼ぶと、reduce_fxでまとめられる挙動になっていそうです(progress barで出力される情報など)。

もう一つ注意点として、ModelCheckpointevery_n_train_stepsTrainerval_check_intervalはどちらもtrainingの1epoch中に複数回のvalidationを組み込むことができます。過剰なvalidation回数にならないようにどちらかのみ設定するのが良いと思います。

2
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
2