疑問
Pytorch LightningのTrainerには色々なParametersが存在するが、その中に以下のようなものがある
- auto_scale_batch_size
If set to True, will initially run a batch size finder trying to find the largest batch size that fits into memory. The result will be stored in self.batch_size in the LightningModule. Additionally, can be set to either power that estimates the batch size through a power search or binsearch that estimates the batch size through a binary search. Default: False
これはTrainer.tune
で適切なbatch_sizeを自動で決定してくれるようにするオプションだが
The result will be stored in self.batch_size in the LightningModule.
という説明を見る限りではLightningModuleのself.batch_size
にしか適用されないように読めるため、DataLoader
をLightningDataModule
などに切り出しているときにどのように利用すれば良いのか分からなかった。
Trainer.tune
の定義は以下のようになっており、datamodule
を渡せるのでLightningDataModule
にも適用されそうな気はするがよく分からない。という感じだった。
Trainer.tune(model, train_dataloaders=None, val_dataloaders=None, datamodule=None, scale_batch_size_kwargs=None, lr_find_kwargs=None)
- Parameters
- model (LightningModule) – Model to tune.
- train_dataloaders (Union[DataLoader, Sequence[DataLoader], Sequence[Sequence[DataLoader]], Sequence[Dict[str, DataLoader]], Dict[str, DataLoader], Dict[str, Dict[str, DataLoader]], Dict[str, Sequence[DataLoader]], LightningDataModule, None]) – A collection of torch.utils.data.DataLoader or a LightningDataModule specifying training samples. In the case of multiple dataloaders, please see this section.
- val_dataloaders (Union[DataLoader, Sequence[DataLoader], None]) – A torch.utils.data.DataLoader or a sequence of them specifying validation samples.
- datamodule (Optional[LightningDataModule]) – An instance of LightningDataModule.
- scale_batch_size_kwargs (Optional[Dict[str, Any]]) – Arguments for scale_batch_size()
- lr_find_kwargs (Optional[Dict[str, Any]]) – Arguments for lr_find()
結論
親切な方からTwitterで、別のページにはLDMでもいけると書いてある、という情報を教えていただいた。
あまり使ったことないですが、おそらく LM でも LDM でも両方いけます。https://t.co/EyWBKAZhOz
— Akihiro Nitta (@aki_bayes) June 15, 2022
This feature expects that a batch_size field is either located as a model attribute i.e. model.batch_size or as a field in your hparams i.e. model.hparams.batch_size. Similarly it can work with datamodules too. The field should exist and will be updated by the results of this algorithm. Additionally, your train_dataloader() method should depend on this field for this feature to work i.e.
# using LightningModule class LitModel(LightningModule): def __init__(self, batch_size): super().__init__() self.save_hyperparameters() # or self.batch_size = batch_size def train_dataloader(self): return DataLoader(train_dataset, batch_size=self.batch_size | self.hparams.batch_size) trainer = Trainer(...) model = LitModel(batch_size=32) trainer.tune(model) # using LightningDataModule class LitDataModule(LightningDataModule): def __init__(self, batch_size): super().__init__() self.save_hyperparameters() # or self.batch_size = batch_size def train_dataloader(self): return DataLoader(train_dataset, batch_size=self.batch_size | self.hparams.batch_size) trainer = Trainer(...) model = MyModel() datamodule = LitDataModule(batch_size=32) trainer.tune(model, datamodule=datamodule)
Note that the train_dataloader can be either part of the LightningModule or LightningDataModule as shown above. If both the LightningModule and the LightningDataModule contain a train_dataloader, the LightningDataModule takes precedence.
実際に試してみたところ、たしかにLightningDataModule
のself.batch_size
も更新されているようだった。
オチ
LightningDataModule
でもTrainer.tune
が使えるのが分かったが、結局、
-
Trainer.tune
の実行に思っていたよりも時間がかかる - 実際に
Trainer.fit
時に使用するとOOMするbatch_sizeがtune時にはsucceeded
判定される
(実際はbatch_size=16
が限界なのに、tune時はなぜかbatch_size=128
でもsucceeded
になっていた)
ということがあり利用は見送ってしまった。
tuneの結果待つ時間よりも適当なbatch_sizeで実際にOOM出るか試す人力batch_size tuningのほうが全然速いので、 日々蓄積されていくデータを定期的に自動で学習し直す、など学習の実行に人手が介入しないケース以外ではあまり嬉しさは無さそうな気がする。