目的
pytorch-lightningのprogressbarで、出力が何の値であるかいまいち分からなかったので調査しました。以下の2点が対象です。
- デフォルトで出力される
loss
の値 - torchmetricsの
Metric
をlog()
に渡した時の値
loss
の値
これは2つの候補がありそうです。
- training全体でのlossの平均値
- epoch内でのlossの平均値
どちらか判定するため、以下のコードを実行します。
import time
import numpy as np
import pytorch_lightning as pl
from pytorch_lightning.callbacks import TQDMProgressBar
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
import torchmetrics
class ToyModel(pl.LightningModule):
def __init__(self):
super().__init__()
self.linear = nn.Linear(10, 1)
self.all_loss = []
self.epoch_loss = []
self.train_metric = torchmetrics.MeanAbsoluteError()
self.val_metric = torchmetrics.MeanAbsoluteError()
def training_step(self, batch, batch_idx):
data, true = batch
pred = self.linear(data)
loss = self.train_metric(true, pred)
self.all_loss.append(loss.item())
self.epoch_loss.append(loss.item())
self.log("train_loss", self.train_metric)
time.sleep(1)
return loss
def training_epoch_start(self, outs):
self.epoch_loss = []
def validation_step(self, batch, batch_idx):
data, true = batch
pred = self.linear(data)
loss = self.val_metric(true, pred)
self.log("val_loss", self.val_metric)
def configure_optimizers(self):
return torch.optim.Adam(self.parameters(), 0.01)
class ToyDataset(Dataset):
def __init__(self):
self.datas = [(torch.rand(10), torch.rand(1)) for _ in range(4*10)]
def __getitem__(self, i):
return self.datas[i]
def __len__(self):
return len(self.datas)
class MyProgressBar(TQDMProgressBar):
def get_metrics(self, trainer, pl_module):
items = super().get_metrics(trainer, pl_module)
items["all_loss_mean"] = np.mean(pl_module.all_loss or float("nan"))
items["epoch_loss_mean"] = np.mean(pl_module.epoch_loss or float("nan"))
return items
def main():
model = ToyModel()
dataset = ToyDataset()
callbacks = [MyProgressBar()]
trainer = pl.Trainer(max_epochs=2, callbacks=callbacks, log_every_n_steps=1)
trainer.fit(
model,
train_dataloaders=DataLoader(dataset, batch_size=4),
val_dataloaders=DataLoader(dataset, batch_size=4)
)
if __name__ == "__main__":
main()
LightningModule
にtraining全体のlossを保持しておくリスト(all_loss
)とepoch内のlossを保持しておくリスト(epoch_loss
)を設定しました。それらをprogressbarに表示する、MyProgressBar
のcallbackを作成しています。
time.sleep(1)
は確認用です。
これを実行します。epoch-0とepoch-1でのprogressbarの出力を切り出したものが以下になります。
Epoch 0: 40%|████ | 8/20 [00:08<00:12, 1.01s/it, loss=0.803, v_num=3, all_loss_mean=0.803, epoch_loss_mean=0.803]
Epoch 1: 40%|████ | 8/20 [00:08<00:12, 1.00s/it, loss=0.524, v_num=1, all_loss_mean=0.524, epoch_loss_mean=0.316]
epoch-0では全て同じ値ですが、epoch-1ではloss
はepoch_loss_mean
とは異なっています。
つまり、loss
が表示しているものはepoch内のloss平均ではなく、trainnig全体でのlossの平均になります。
ソースコードを見ても
running_train_loss = trainer.fit_loop.running_loss.mean()
となっており、全体の平均であることが伺えます。
torchmetricsのMetric
をlog()
に渡した時の値
torchmetricsのMetric
ではmetric(true, pred)
でlossの値が計算され、クラス内に保存されます。何度かlossを計算したのち、compute
を呼ぶとそれまでの平均などを計算することができます。
Metric
はpytorchlightningのTrainer.log
に渡すことが可能です。このとき、どの値が出力されるのでしょうか?以下の関数内で呼んだ時の値を確認します。
- training_step
- training_epoch_end
- validation_step
- validation_epoch_end
確認用のコードは以下になります。(変更点のみ記載)
class ToyModel(pl.LightningModule):
def __init__(self):
super().__init__()
self.linear = nn.Linear(10, 1)
self.train_metric = torchmetrics.MeanAbsoluteError()
self.val_metric = torchmetrics.MeanAbsoluteError()
def _get_loss_dict(self, metric, loss, prefix):
metrics = {
f"{prefix}_m": metric,
f"{prefix}_c": metric.compute(),
f"{prefix}_l": loss
}
return metrics
def training_step(self, batch, batch_idx):
data, true = batch
pred = self.linear(data)
loss = self.train_metric(true, pred)
metrics = self._get_loss_dict(self.train_metric, loss, "ts")
self.log_dict(metrics, prog_bar=True)
time.sleep(1)
return loss
def training_epoch_end(self, outs):
loss = torch.stack([i["loss"] for i in outs]).mean()
metrics = self._get_loss_dict(self.train_metric, loss, "te")
self.log_dict(metrics, prog_bar=True)
def validation_step(self, batch, batch_idx):
data, true = batch
pred = self.linear(data)
loss = self.val_metric(true, pred)
metrics = self._get_loss_dict(self.val_metric, loss, "vs")
self.log_dict(metrics, prog_bar=True)
time.sleep(1)
return loss
def validation_epoch_end(self, outs):
loss = torch.stack(outs).mean()
metrics = self._get_loss_dict(self.val_metric, loss, "ve")
self.log_dict(metrics, prog_bar=True)
def configure_optimizers(self):
return torch.optim.Adam(self.parameters(), 0.01)
class MyProgressBar(TQDMProgressBar):
def get_metrics(self, trainer, pl_module):
items = super().get_metrics(trainer, pl_module)
for prop in ["v_num", "loss"]:
items.pop(prop)
return items
これを実行すると以下になります。
Epoch 1: 20%|███▍ | 4/20 [00:04<00:16, 1.01s/it, ts_m=0.184, ts_c=0.255, ts_l=0.184, vs_m=0.263, vs_c=0.283, vs_l=0.263, ve_m=0.263, ve_c=0.263, ve_l=0.263, te_m=0.351, te_c=0.351, te_l=0.351]
今度はやや分かりにくいのでtensorboadの出力も見てみます。コードは省略しますが、各lossは以下になります。
=== ts_m ===
[0.5485677719116211, 0.47600966691970825, 0.33111822605133057, 0.5658555030822754, 0.4193665385246277,
0.1254565715789795, 0.33993926644325256, 0.3271455764770508, 0.18989351391792297, 0.18316617608070374,
0.32229477167129517, 0.21195293962955475, 0.30307191610336304, 0.1840662956237793, 0.44732925295829773,
0.2892676293849945, 0.3235018849372864, 0.28755465149879456, 0.1626671552658081, 0.2397872805595398]
=== ts_c ===
[0.5485677719116211, 0.5122886896133423, 0.4518985450267792, 0.48038777709007263, 0.4681835174560547,
0.4110623598098755, 0.4009019434452057, 0.39168238639831543, 0.36926138401031494, 0.35065189003944397,
0.32229477167129517, 0.26712384819984436, 0.2791065275669098, 0.25534647703170776, 0.2937430441379547,
0.2929971516132355, 0.29735496640205383, 0.2961299419403076, 0.28130075335502625, 0.27714940905570984]
=== ts_l ===
[0.5485677719116211, 0.47600966691970825, 0.33111822605133057, 0.5658555030822754, 0.4193665385246277,
0.1254565715789795, 0.33993926644325256, 0.3271455764770508, 0.18989351391792297, 0.18316617608070374,
0.32229477167129517, 0.21195293962955475, 0.30307191610336304, 0.1840662956237793, 0.44732925295829773,
0.2892676293849945, 0.3235018849372864, 0.28755465149879456, 0.1626671552658081, 0.2397872805595398]
=== epoch ===
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
=== vs_m ===
[0.26260727643966675, 0.2663346230983734]
=== vs_c ===
[0.282810240983963, 0.27834028005599976]
=== vs_l ===
[0.26260727643966675, 0.2663346230983734]
=== ve_m ===
[0.26260727643966675, 0.2663346230983734]
=== ve_c ===
[0.26260727643966675, 0.2663346230983734]
=== ve_l ===
[0.26260727643966675, 0.2663346230983734]
=== te_m ===
[0.35065189003944397, 0.27714940905570984]
=== te_c ===
[0.35065189003944397, 0.27714940905570984]
=== te_l ===
[0.35065191984176636, 0.27714937925338745]
trainでの挙動
***_step
ではmetric(true, pred)
で計算されたlossの値、***_epoch_end
ではcompute
で計算された値が記録されるようです。
***_epoch_end
でのcompute
の出力値(te_c
)はepochのloss全体の平均値(te_l
)と一致していることも確認できます。
validationでの挙動
stepでlogをしてもepoch終わりしか出力しないようです(引数でタイミングを変えることが可能だと思います)。また、stepの出力値はlogに渡した値をプロパティごとに平均した値になっています。
そのため、vs_m
は[loss_batch1, loss_batch2, ...]
を平均した値、ve_m
はmetric.compute()
となり、実質的に同じになります。
vs_c
は[metric.compute(), metric_compute(), ...]
の平均となり、それぞれ別のタイミングでcompute
を呼んでいるため、ve_c
の値(最終時点でcompute
を呼んだ値)と異なります。
これらより、log
にMetric
クラスを渡すと、ケースに応じて自動的に正しい値(step内ではそのバッチのloss、epoch終わりではlossの平均、validation時ではepoch終わりに平均をとる)を出力してくれるようです。
なお、分散学習時には***_step
での値はrank0のlossのみ出力し、データ転送を抑えているようです。***_epoch_end
では各プロセスのlossを平均した値が出力されます。
まとめ
pytorch-lightningのprogressbarが出力する内容について調べました。結果は以下です。
-
loss
の値はtrainingの全てのlossの平均である -
Metric
を渡すと、ケースに応じて自動的に正しい値(step内ではそのバッチのloss、epoch終わりではlossの平均、validation時ではepoch終わりに平均をとる)を出力する