LoginSignup
3
2

More than 1 year has passed since last update.

pytorch-lightningのprogressbarが出力する内容を調査する

Last updated at Posted at 2022-11-15

目的

pytorch-lightningのprogressbarで、出力が何の値であるかいまいち分からなかったので調査しました。以下の2点が対象です。

  • デフォルトで出力されるlossの値
  • torchmetricsのMetriclog()に渡した時の値

lossの値

これは2つの候補がありそうです。

  • training全体でのlossの平均値
  • epoch内でのlossの平均値

どちらか判定するため、以下のコードを実行します。

import time

import numpy as np
import pytorch_lightning as pl
from pytorch_lightning.callbacks import TQDMProgressBar
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
import torchmetrics

class ToyModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 1)
        self.all_loss = []
        self.epoch_loss = []
        self.train_metric = torchmetrics.MeanAbsoluteError()
        self.val_metric = torchmetrics.MeanAbsoluteError()

    def training_step(self, batch, batch_idx):
        data, true = batch
        pred = self.linear(data)
        loss = self.train_metric(true, pred)
        self.all_loss.append(loss.item())
        self.epoch_loss.append(loss.item())
        self.log("train_loss", self.train_metric)
        time.sleep(1)
        return loss

    def training_epoch_start(self, outs):
        self.epoch_loss = []

    def validation_step(self, batch, batch_idx):
        data, true = batch
        pred = self.linear(data)
        loss = self.val_metric(true, pred)
        self.log("val_loss", self.val_metric)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), 0.01)

class ToyDataset(Dataset):
    def __init__(self):
        self.datas = [(torch.rand(10), torch.rand(1)) for _ in range(4*10)]

    def __getitem__(self, i):
        return self.datas[i]

    def __len__(self):
        return len(self.datas)

class MyProgressBar(TQDMProgressBar):
    def get_metrics(self, trainer, pl_module):
        items = super().get_metrics(trainer, pl_module)
        items["all_loss_mean"] = np.mean(pl_module.all_loss or float("nan"))
        items["epoch_loss_mean"] = np.mean(pl_module.epoch_loss or float("nan"))
        return items

def main():
    model = ToyModel()
    dataset = ToyDataset()
    callbacks = [MyProgressBar()]

    trainer = pl.Trainer(max_epochs=2, callbacks=callbacks, log_every_n_steps=1)
    trainer.fit(
        model,
        train_dataloaders=DataLoader(dataset, batch_size=4),
        val_dataloaders=DataLoader(dataset, batch_size=4)
    )

if __name__ == "__main__":
    main()

LightningModuleにtraining全体のlossを保持しておくリスト(all_loss)とepoch内のlossを保持しておくリスト(epoch_loss)を設定しました。それらをprogressbarに表示する、MyProgressBarのcallbackを作成しています。
time.sleep(1)は確認用です。

これを実行します。epoch-0とepoch-1でのprogressbarの出力を切り出したものが以下になります。

Epoch 0:  40%|████      | 8/20 [00:08<00:12,  1.01s/it, loss=0.803, v_num=3, all_loss_mean=0.803, epoch_loss_mean=0.803]
Epoch 1:  40%|████      | 8/20 [00:08<00:12,  1.00s/it, loss=0.524, v_num=1, all_loss_mean=0.524, epoch_loss_mean=0.316]

epoch-0では全て同じ値ですが、epoch-1ではlossepoch_loss_meanとは異なっています。
つまり、lossが表示しているものはepoch内のloss平均ではなく、trainnig全体でのlossの平均になります。
ソースコードを見ても

running_train_loss = trainer.fit_loop.running_loss.mean()

となっており、全体の平均であることが伺えます。

torchmetricsのMetriclog()に渡した時の値

torchmetricsのMetricではmetric(true, pred)でlossの値が計算され、クラス内に保存されます。何度かlossを計算したのち、computeを呼ぶとそれまでの平均などを計算することができます。

MetricはpytorchlightningのTrainer.logに渡すことが可能です。このとき、どの値が出力されるのでしょうか?以下の関数内で呼んだ時の値を確認します。

  • training_step
  • training_epoch_end
  • validation_step
  • validation_epoch_end

確認用のコードは以下になります。(変更点のみ記載)

class ToyModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 1)
        self.train_metric = torchmetrics.MeanAbsoluteError()
        self.val_metric = torchmetrics.MeanAbsoluteError()

    def _get_loss_dict(self, metric, loss, prefix):
        metrics = {
            f"{prefix}_m": metric,
            f"{prefix}_c": metric.compute(),
            f"{prefix}_l": loss
        }
        return metrics

    def training_step(self, batch, batch_idx):
        data, true = batch
        pred = self.linear(data)
        loss = self.train_metric(true, pred)
        metrics = self._get_loss_dict(self.train_metric, loss, "ts")
        self.log_dict(metrics, prog_bar=True)
        time.sleep(1)
        return loss

    def training_epoch_end(self, outs):
        loss = torch.stack([i["loss"] for i in outs]).mean()
        metrics = self._get_loss_dict(self.train_metric, loss, "te")
        self.log_dict(metrics, prog_bar=True)

    def validation_step(self, batch, batch_idx):
        data, true = batch
        pred = self.linear(data)
        loss = self.val_metric(true, pred)
        metrics = self._get_loss_dict(self.val_metric, loss, "vs")
        self.log_dict(metrics, prog_bar=True)
        time.sleep(1)
        return loss

    def validation_epoch_end(self, outs):
        loss = torch.stack(outs).mean()
        metrics = self._get_loss_dict(self.val_metric, loss, "ve")
        self.log_dict(metrics, prog_bar=True)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), 0.01)

class MyProgressBar(TQDMProgressBar):
    def get_metrics(self, trainer, pl_module):
        items = super().get_metrics(trainer, pl_module)
        for prop in ["v_num", "loss"]:
            items.pop(prop)
        return items

これを実行すると以下になります。

Epoch 1:  20%|███▍             | 4/20 [00:04<00:16,  1.01s/it, ts_m=0.184, ts_c=0.255, ts_l=0.184, vs_m=0.263, vs_c=0.283, vs_l=0.263, ve_m=0.263, ve_c=0.263, ve_l=0.263, te_m=0.351, te_c=0.351, te_l=0.351]

今度はやや分かりにくいのでtensorboadの出力も見てみます。コードは省略しますが、各lossは以下になります。

=== ts_m ===
[0.5485677719116211, 0.47600966691970825, 0.33111822605133057, 0.5658555030822754, 0.4193665385246277, 
0.1254565715789795, 0.33993926644325256, 0.3271455764770508, 0.18989351391792297, 0.18316617608070374, 
0.32229477167129517, 0.21195293962955475, 0.30307191610336304, 0.1840662956237793, 0.44732925295829773, 
0.2892676293849945, 0.3235018849372864, 0.28755465149879456, 0.1626671552658081, 0.2397872805595398]
=== ts_c ===
[0.5485677719116211, 0.5122886896133423, 0.4518985450267792, 0.48038777709007263, 0.4681835174560547, 
0.4110623598098755, 0.4009019434452057, 0.39168238639831543, 0.36926138401031494, 0.35065189003944397, 
0.32229477167129517, 0.26712384819984436, 0.2791065275669098, 0.25534647703170776, 0.2937430441379547, 
0.2929971516132355, 0.29735496640205383, 0.2961299419403076, 0.28130075335502625, 0.27714940905570984]
=== ts_l ===
[0.5485677719116211, 0.47600966691970825, 0.33111822605133057, 0.5658555030822754, 0.4193665385246277, 
0.1254565715789795, 0.33993926644325256, 0.3271455764770508, 0.18989351391792297, 0.18316617608070374, 
0.32229477167129517, 0.21195293962955475, 0.30307191610336304, 0.1840662956237793, 0.44732925295829773, 
0.2892676293849945, 0.3235018849372864, 0.28755465149879456, 0.1626671552658081, 0.2397872805595398]
=== epoch ===
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
=== vs_m ===
[0.26260727643966675, 0.2663346230983734]
=== vs_c ===
[0.282810240983963, 0.27834028005599976]
=== vs_l ===
[0.26260727643966675, 0.2663346230983734]
=== ve_m ===
[0.26260727643966675, 0.2663346230983734]
=== ve_c ===
[0.26260727643966675, 0.2663346230983734]
=== ve_l ===
[0.26260727643966675, 0.2663346230983734]
=== te_m ===
[0.35065189003944397, 0.27714940905570984]
=== te_c ===
[0.35065189003944397, 0.27714940905570984]
=== te_l ===
[0.35065191984176636, 0.27714937925338745]

trainでの挙動

***_stepではmetric(true, pred)で計算されたlossの値、***_epoch_endではcomputeで計算された値が記録されるようです。
***_epoch_endでのcomputeの出力値(te_c)はepochのloss全体の平均値(te_l)と一致していることも確認できます。

validationでの挙動

stepでlogをしてもepoch終わりしか出力しないようです(引数でタイミングを変えることが可能だと思います)。また、stepの出力値はlogに渡した値をプロパティごとに平均した値になっています。
そのため、vs_m[loss_batch1, loss_batch2, ...]を平均した値、ve_mmetric.compute()となり、実質的に同じになります。
vs_c[metric.compute(), metric_compute(), ...]の平均となり、それぞれ別のタイミングでcomputeを呼んでいるため、ve_cの値(最終時点でcomputeを呼んだ値)と異なります。

これらより、logMetricクラスを渡すと、ケースに応じて自動的に正しい値(step内ではそのバッチのloss、epoch終わりではlossの平均、validation時ではepoch終わりに平均をとる)を出力してくれるようです。

なお、分散学習時には***_stepでの値はrank0のlossのみ出力し、データ転送を抑えているようです。***_epoch_endでは各プロセスのlossを平均した値が出力されます。

まとめ

pytorch-lightningのprogressbarが出力する内容について調べました。結果は以下です。

  • lossの値はtrainingの全てのlossの平均である
  • Metricを渡すと、ケースに応じて自動的に正しい値(step内ではそのバッチのloss、epoch終わりではlossの平均、validation時ではepoch終わりに平均をとる)を出力する
3
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
3
2