目的
pytorch-lightningでvalidationのlossが小さいモデルを保存したいとき、ModelCheckpoint
を使います。ドキュメントにはmonitor
にlossの名前を渡すとありますが、validation_step
での値を渡しても、途中のあるバッチでlossが最小になったときに記録されるのか、全体の値が最小になったときに記録されるかよくわかりませんでした。
今回は実際に分かりやすいデータを流してみて挙動を確認したいと思います。
モデル
今回使用するモデルは以下です。
from copy import deepcopy
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
class ToyModel(pl.LightningModule):
def __init__(self, epoch_div=True):
super().__init__()
self.linear = torch.nn.Linear(1,1)
self.val_step_loss = []
self.val_epoch_loss = []
self.states = []
def configure_optimizers(self):
optimizer = torch.optim.Adam(self.linear.parameters(), lr=0.1)
return optimizer
def training_step(self, batch, batch_idx):
true, pred = batch
loss = F.l1_loss(true, pred)
dummy_loss = F.l1_loss(self.linear(true), pred)
self.log(f"train_loss_step", loss, prog_bar=True)
return dummy_loss
def validation_step(self, batch, batch_idx):
"""
[1,2,...,10] -> [11,12,...,20] -> [1000,0,0,...,0]
"""
true, pred = batch
if self.current_epoch < 2:
loss = F.l1_loss(true, pred) + self.current_epoch * 10
else:
loss = torch.tensor(0.0) if batch_idx != 0 else torch.tensor(1000.0)
self.log(f"val_loss_step", loss, prog_bar=True)
self.val_step_loss.append(loss)
return loss
def validation_epoch_end(self, outs) -> None:
"""
1 -> 0.5 -> 0.33
"""
loss = torch.tensor(1 / (self.current_epoch + 1))
self.log(f"val_loss_epoch", loss, prog_bar=True)
self.val_epoch_loss.append(loss)
self.states.append(deepcopy(self.state_dict()))
class ToyDataset(Dataset):
def __init__(self):
datas = []
for i in range(10):
datas.append((torch.tensor([float(i+1)]), torch.zeros(1)))
self.datas = datas
def __len__(self):
return len(self.datas)
def __getitem__(self, i):
return self.datas[i]
データセットは(torch.tensor(batch_idx), torch.tensor(0.0))
を返すようにしています。これでlossが決定的に決まり分かりやすくなります。
モデルではtraining_step
では雑にモデルを更新しています。今回はこの部分は関係ありません。
validation_step
ではlossがepoch-0では1,2,...,10、epoch-1では11,12,...,20、epoch-2では1000,0,0,...,0と決定的にしました。val_loss_step
の瞬間最小値はepoch-2、全体の最小値はepoch-0になります。
validation_epoch_end
ではlossがepoch-0では1.0、epoch-1では0.5、epoch-2では0.33と決定的にしました。
val_loss_epoch
の最小値はepoch-2になります。
実行
model = ToyModel()
dataset = ToyDataset()
loader = DataLoader(dataset, batch_size=1)
callbacks = [
pl.callbacks.ModelCheckpoint(
monitor="val_loss_step",
filename="ckpt_step_{epoch}-{step}-{val_loss_step:.2f}-{val_loss_epoch:.2f}"
),
pl.callbacks.ModelCheckpoint(
monitor="val_loss_epoch",
filename="ckpt_epoch_{epoch}-{step}-{val_loss_step:.2f}-{val_loss_epoch:.2f}"
),
]
trainer = pl.Trainer(callbacks=callbacks, max_epochs=3, num_sanity_val_steps=0, enable_progress_bar=True)
trainer.fit(model, train_dataloaders=loader, val_dataloaders=loader)
学習を行います。monitor
が異なる2種のModelCheckpoint
を作成しました。それぞれval_loss_step
、val_loss_epoch
をモニタリングします。名前にそれぞれのlossの値を加えるようにしました。
保存されたチェックポイントは以下になりました。
>>> ls lightning_logs/version_0/checkpoints/
ckpt_epoch_epoch=2-step=30-val_loss_step=100.00-val_loss_epoch=0.33.ckpt
ckpt_step_epoch=0-step=10-val_loss_step=5.50-val_loss_epoch=1.00.ckpt
val_loss_step
をモニタリングした場合はepoch-0が保存されています。これはval_loss_step
のepoch全体の値を平均化した値をモニタリングしているようです。実際val_loss_step=5.50となっており、epoch-0のval_loss_step
平均は
>>> torch.stack(model.val_step_loss[:10]).mean()
tensor(5.5000)
となっています。
一方で、val_loss_epoch
をモニタリングした場合はepoch-2が保存されています。このようにvalidation_epoch_end
で設定したlossも使用可能のようです。特別な統計処理を用いたい場合はvalidation_epoch_end
で処理を定義し、それをモニタリングすればよさそうです。
チェックポイントを読み込んでみると、モデルの重みもしっかり一致していました。
>>> model.states[0]
OrderedDict([('linear.weight', tensor([[0.0303]])),
('linear.bias', tensor([0.2593]))])
>>> ToyModel.load_from_checkpoint("lightning_logs/version_1/checkpoints/ckpt_epoch_epoch=2-step=30-val_loss_step=100.00-val_loss_epoch=0.33.ckpt").state_dict()
OrderedDict([('linear.weight', tensor([[0.0303]])),
('linear.bias', tensor([0.2593]))])
まとめ
ModelCheckpoint
のmonitor
にvalidation_step
内で定義されたlossを渡すと、平均化した値をモニタリングするようです。これに限らずpytorch-lightningにおいてvalidation_step
内でlog()
を呼ぶと、reduce_fx
でまとめられる挙動になっていそうです(progress barで出力される情報など)。
もう一つ注意点として、ModelCheckpoint
のevery_n_train_steps
とTrainer
のval_check_interval
はどちらもtrainingの1epoch中に複数回のvalidationを組み込むことができます。過剰なvalidation回数にならないようにどちらかのみ設定するのが良いと思います。