1
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

pl.LightningModuleでWarmup Cosine Decay Schedulerをステップ単位で更新

Last updated at Posted at 2024-06-20

pytorch-lightningの学習サイクルで,WarmupCosineLRSchedulerを使った学習率設定を紹介.
ここでは,エポック単位ではなくてステップ単位の更新を説明します.

準備

timmのインストール

timmCosineLRSchedulerを使います.
pytorch image modelsという画像認識用のライブラリです

pip install timm

timmのWarmup付き,CosineLRSchedulerクラス自体の使い方は次のサイトがわかりやすいです.

※lightning-boltパッケージにも,WarmupCosineLRReducerがありますがアップデートが全くされておらず,最新のlightningに対応していないので非推奨です.

実装: pl.LightningModuleで使用

Scheduler インスタンスの生成とステップ単位の更新

pl.configure_optimizersメソッドでschedulerのインスタンスを生成して,pl.lr_scheduler_stepstepメソッドを呼び出すようにしたら完了です.

from timm.scheduler import CosineLRScheduler

class MyLightningModule(pl.LightningModule):
    def __init__(self, model, num_epoch, exp_name) -> None:
        ...
        
        num_dataloader = len(dataloader) # 1epochあたりのバッチ数
        self.total_steps = num_epoch * num_dataloader # 合計step数
        self.current_step = 0

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(
            self.models, base_lr=1e-4
        )

        warmup_rate = 0.1
        warmup_t = int(self.total_steps * warmup_rate) # warmup step数
        scheduler = CosineLRScheduler(
            optimizer, t_initial=self.total_steps, warmup_t=warmup_t, warmup_prefix=True
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "interval": "step",
                "frequency": 1,
            }
        }

    def lr_scheduler_step(self, scheduler, metric) -> None:
        self.current_step += 1
        scheduler.step(self.current_step)

    ...

設定しているCosineLRSchedulerの引数

  • t_initial: 総ステップ数 (本来の意味は最初のサイクルのステップ数)
  • warmup_t: warmupに使うstep数
  • warmup_prefix: Trueなら,warmup終了時の学習率がOptimizerに設定した値になります

詳細はこちらから

また,返り値のdictではschedulerの更新タイミングや頻度を設定することができます.

"lr_scheduler": {
    "scheduler": scheduler,
    "interval": "step",
    "frequency": 1,
}

今回はstep単位で学習率を更新したいため,"interval": "step"にします.デフォルトは"epoch"なので,この指定は必ず必要です.

最後に,stepメソッドの呼び出しは,LightningModule.lr_scheduler_stepをオーバーライドして実装します.(PytorchのLRSchedulerlightning-bolt以外のSchedulerを使う場合はオーバーライドしてstepの処理を書く必要があります.)
この関数はschedulerの更新のタイミングで呼ばれるようになっています.

関連

schedulerについてはこちら

LightningModuleとは

LighgningModuleでは, trainステップ, valステップを作成したり,学習におけるサイクルのhooksに任意の処理を追加するAPIが提供されています.
基本的な使い方はこちらの記事をどうぞ

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?