0
0

pytorch-lighitningのLighitningModuleのライフサイクルで,WarmupCosineLRSchedulerを使う方法を紹介.

前提知識

schedulerについてはこちら

LighitningModuleとは

LighgningModuleでは, trainステップ, valステップを作成したり,学習におけるサイクルのhooksに任意の処理を追加するAPIが提供されています.
基本的な使い方はこちらの記事をどうぞ

実装

パッケージのインストール

timmCosineLRSchedulerを使います.
pytorch image modelsという画像認識用のライブラリです

pip install timm

timmのWarmup付き,CosineLRSchedulerクラス自体の使い方は次のサイトがわかりやすいです.

※lighitning-boltパッケージにも,WarmupCosineLRReducerがありますがアップデートが全くされておらず,最新のlighitningに対応していないので非推奨です.

LightningModuleに適用

step0: lightningにおけるoptimizer, schedulerの前提知識

LightningModule.configure_optimizer関数で初期化できます.

# 例
def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.models, lr)
        scheduler = Scheduler(optimizer)
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "interval": "epoch", # ステップ単位で更新する場合は,"step"
                "frequency": 1,
            },
        }

schedulerの更新はLighitningModule.lr_scheduler_step関数が設定した"interval"の間隔で呼ばれて行われます.デフォルトではschedulerのstep関数が呼び出される仕組みになっています.

step1: Schedulerの継承

まず, CosineLRSchedulerを継承して,step_関数を作成します.
CosineLRSchedulerは, step関数で学習率の更新をしますが,引数として,ステップ数を入れる必要があります.
なので,内部でステップ数を管理して自動で加算するstep_関数を作成します.

from timm.scheduler import CosineLRScheduler

class WarmupCosineLRScheduler(CosineLRScheduler):
    def __init__(
        self,
        optimizer: torch.optim.Optimizer,
        t_initial: int,
        lr_min: float = 0.0,
        cycle_mul: float = 1.0,
        cycle_decay: float = 1.0,
        cycle_limit: int = 1,
        warmup_t=0,
        warmup_lr_init=0,
        warmup_prefix=False,
        t_in_epochs=True,
        noise_range_t=None,
        noise_pct=0.67,
        noise_std=1.0,
        noise_seed=42,
        k_decay=1.0,
        initialize=True,
    ) -> None:
        super().__init__(
            optimizer,
            t_initial,
            lr_min,
            cycle_mul,
            cycle_decay,
            cycle_limit,
            warmup_t,
            warmup_lr_init,
            warmup_prefix,
            t_in_epochs,
            noise_range_t,
            noise_pct,
            noise_std,
            noise_seed,
            k_decay,
            initialize,
        )
        self.current_step = 0

    def step_(self):
        self.current_step += 1
        self.step(self.current_step)

Step2: 初期化と更新

下が初期化と,更新の処理の実装です.
configure_optimizersで初期化して,lr_scheduler_stepstep_を呼び出すようにしたら完了です.

class MyLighitningModule(pl.LightningModule):
    def __init__(self, model, args, exp_name) -> None:
        ...

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(
            self.models, base_lr=1e-4
        )
        scheduler = WarmupCosineLRScheduler(
            optimizer, t_initial=self.args.num_epochs, warmup_t=10, warmup_prefix=True
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "interval": "epoch",
                "frequency": 1,
            },
        }

    def lr_scheduler_step(self, scheduler, metric) -> None:
        scheduler.step_()

    ...

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0