LoginSignup
0
0

More than 1 year has passed since last update.

Pytorch LightningのTrainer.tuneはLightningDataModuleにも適用される

Posted at

疑問

Pytorch LightningのTrainerには色々なParametersが存在するが、その中に以下のようなものがある

  • auto_scale_batch_size
    If set to True, will initially run a batch size finder trying to find the largest batch size that fits into memory. The result will be stored in self.batch_size in the LightningModule. Additionally, can be set to either power that estimates the batch size through a power search or binsearch that estimates the batch size through a binary search. Default: False

これはTrainer.tuneで適切なbatch_sizeを自動で決定してくれるようにするオプションだが

The result will be stored in self.batch_size in the LightningModule.

という説明を見る限りではLightningModuleのself.batch_sizeにしか適用されないように読めるため、DataLoaderLightningDataModuleなどに切り出しているときにどのように利用すれば良いのか分からなかった。

Trainer.tuneの定義は以下のようになっており、datamoduleを渡せるのでLightningDataModuleにも適用されそうな気はするがよく分からない。という感じだった。

Trainer.tune(model, train_dataloaders=None, val_dataloaders=None, datamodule=None, scale_batch_size_kwargs=None, lr_find_kwargs=None)
 

  • Parameters
    • model (LightningModule) – Model to tune.
    • train_dataloaders (Union[DataLoader, Sequence[DataLoader], Sequence[Sequence[DataLoader]], Sequence[Dict[str, DataLoader]], Dict[str, DataLoader], Dict[str, Dict[str, DataLoader]], Dict[str, Sequence[DataLoader]], LightningDataModule, None]) – A collection of torch.utils.data.DataLoader or a LightningDataModule specifying training samples. In the case of multiple dataloaders, please see this section.
    • val_dataloaders (Union[DataLoader, Sequence[DataLoader], None]) – A torch.utils.data.DataLoader or a sequence of them specifying validation samples.
    • datamodule (Optional[LightningDataModule]) – An instance of LightningDataModule.
    • scale_batch_size_kwargs (Optional[Dict[str, Any]]) – Arguments for scale_batch_size()
    • lr_find_kwargs (Optional[Dict[str, Any]]) – Arguments for lr_find()

結論

親切な方からTwitterで、別のページにはLDMでもいけると書いてある、という情報を教えていただいた。

This feature expects that a batch_size field is either located as a model attribute i.e. model.batch_size or as a field in your hparams i.e. model.hparams.batch_size. Similarly it can work with datamodules too. The field should exist and will be updated by the results of this algorithm. Additionally, your train_dataloader() method should depend on this field for this feature to work i.e.

# using LightningModule
class LitModel(LightningModule):
    def __init__(self, batch_size):
        super().__init__()
        self.save_hyperparameters()
        # or
        self.batch_size = batch_size

    def train_dataloader(self):
        return DataLoader(train_dataset, batch_size=self.batch_size | self.hparams.batch_size)


trainer = Trainer(...)
model = LitModel(batch_size=32)
trainer.tune(model)

# using LightningDataModule
class LitDataModule(LightningDataModule):
    def __init__(self, batch_size):
        super().__init__()
        self.save_hyperparameters()
        # or
        self.batch_size = batch_size

    def train_dataloader(self):
        return DataLoader(train_dataset, batch_size=self.batch_size | self.hparams.batch_size)


trainer = Trainer(...)
model = MyModel()
datamodule = LitDataModule(batch_size=32)
trainer.tune(model, datamodule=datamodule)

Note that the train_dataloader can be either part of the LightningModule or LightningDataModule as shown above. If both the LightningModule and the LightningDataModule contain a train_dataloader, the LightningDataModule takes precedence.

実際に試してみたところ、たしかにLightningDataModuleself.batch_sizeも更新されているようだった。

オチ

LightningDataModuleでもTrainer.tuneが使えるのが分かったが、結局、

  • Trainer.tuneの実行に思っていたよりも時間がかかる
  • 実際にTrainer.fit時に使用するとOOMするbatch_sizeがtune時にはsucceeded判定される
    (実際はbatch_size=16が限界なのに、tune時はなぜかbatch_size=128でもsucceededになっていた)

ということがあり利用は見送ってしまった。

tuneの結果待つ時間よりも適当なbatch_sizeで実際にOOM出るか試す人力batch_size tuningのほうが全然速いので、 日々蓄積されていくデータを定期的に自動で学習し直す、など学習の実行に人手が介入しないケース以外ではあまり嬉しさは無さそうな気がする。

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0