pytorch-lighitningのLighitningModuleのライフサイクルで,WarmupCosineLRSchedulerを使う方法を紹介.
前提知識
schedulerについてはこちら
LighitningModuleとは
LighgningModuleでは, trainステップ, valステップを作成したり,学習におけるサイクルのhooksに任意の処理を追加するAPIが提供されています.
基本的な使い方はこちらの記事をどうぞ
実装
パッケージのインストール
timm
のCosineLRScheduler
を使います.
pytorch image modelsという画像認識用のライブラリです
pip install timm
timmのWarmup付き,CosineLRSchedulerクラス自体の使い方は次のサイトがわかりやすいです.
※lighitning-boltパッケージにも,WarmupCosineLRReducerがありますがアップデートが全くされておらず,最新のlighitningに対応していないので非推奨です.
LightningModuleに適用
step0: lightningにおけるoptimizer, schedulerの前提知識
LightningModule.configure_optimizer
関数で初期化できます.
# 例
def configure_optimizers(self):
optimizer = torch.optim.AdamW(self.models, lr)
scheduler = Scheduler(optimizer)
return {
"optimizer": optimizer,
"lr_scheduler": {
"scheduler": scheduler,
"interval": "epoch", # ステップ単位で更新する場合は,"step"
"frequency": 1,
},
}
schedulerの更新はLighitningModule.lr_scheduler_step
関数が設定した"interval"の間隔で呼ばれて行われます.デフォルトではschedulerのstep
関数が呼び出される仕組みになっています.
step1: Schedulerの継承
まず, CosineLRSchedulerを継承して,step_
関数を作成します.
CosineLRScheduler
は, step
関数で学習率の更新をしますが,引数として,ステップ数を入れる必要があります.
なので,内部でステップ数を管理して自動で加算するstep_
関数を作成します.
from timm.scheduler import CosineLRScheduler
class WarmupCosineLRScheduler(CosineLRScheduler):
def __init__(
self,
optimizer: torch.optim.Optimizer,
t_initial: int,
lr_min: float = 0.0,
cycle_mul: float = 1.0,
cycle_decay: float = 1.0,
cycle_limit: int = 1,
warmup_t=0,
warmup_lr_init=0,
warmup_prefix=False,
t_in_epochs=True,
noise_range_t=None,
noise_pct=0.67,
noise_std=1.0,
noise_seed=42,
k_decay=1.0,
initialize=True,
) -> None:
super().__init__(
optimizer,
t_initial,
lr_min,
cycle_mul,
cycle_decay,
cycle_limit,
warmup_t,
warmup_lr_init,
warmup_prefix,
t_in_epochs,
noise_range_t,
noise_pct,
noise_std,
noise_seed,
k_decay,
initialize,
)
self.current_step = 0
def step_(self):
self.current_step += 1
self.step(self.current_step)
Step2: 初期化と更新
下が初期化と,更新の処理の実装です.
configure_optimizers
で初期化して,lr_scheduler_step
でstep_
を呼び出すようにしたら完了です.
class MyLighitningModule(pl.LightningModule):
def __init__(self, model, args, exp_name) -> None:
...
def configure_optimizers(self):
optimizer = torch.optim.AdamW(
self.models, base_lr=1e-4
)
scheduler = WarmupCosineLRScheduler(
optimizer, t_initial=self.args.num_epochs, warmup_t=10, warmup_prefix=True
)
return {
"optimizer": optimizer,
"lr_scheduler": {
"scheduler": scheduler,
"interval": "epoch",
"frequency": 1,
},
}
def lr_scheduler_step(self, scheduler, metric) -> None:
scheduler.step_()
...