学習率スケジューラごとの学習率の推移をプロットするだけのスクリプトです。
参考文献: https://docs.pytorch.org/docs/2.11/optim.html#how-to-adjust-learning-rate
plot.py
import torch
import matplotlib.pyplot as plt
def get_lrs(lr, n, sch_cls, sch_params):
    """Simulate ``n`` epochs and return the learning rate seen in each one.

    A throwaway SGD optimizer over a tiny linear layer is created only so the
    scheduler has something to attach to; no gradients are ever computed.

    Args:
        lr: Initial learning rate handed to the optimizer.
        n: Number of epochs to record.
        sch_cls: Scheduler class, called as ``sch_cls(optimizer, **sch_params)``.
        sch_params: Keyword arguments forwarded to ``sch_cls``.

    Returns:
        A list of ``n`` learning rates (taken from the first parameter group),
        starting with the rate reported right after the scheduler is built.
        The list is empty when ``n`` is not positive.
    """
    if n <= 0:
        # Keep the result length equal to n so it always lines up with range(n).
        return []

    opt = torch.optim.SGD(torch.nn.Linear(3, 1).parameters(), lr=lr)
    sch = sch_cls(opt, **sch_params)

    # Read the starting value back from the scheduler instead of assuming it
    # equals `lr`: a scheduler may already rescale the rate on construction
    # (e.g. LambdaLR whose lambda does not return 1 for epoch 0).
    lrs = [sch.get_last_lr()[0]]
    for _ in range(n - 1):
        # The optimizer is stepped before the scheduler, which is the call
        # order the PyTorch documentation asks for.
        opt.step()
        sch.step()
        lrs.append(sch.get_last_lr()[0])
    return lrs
def _scheduler_label(sch_cls, sch_params):
    """Build the legend text ``ClassName(key=value, ...)`` for one scheduler."""
    # Callables (e.g. lambdas) are shown as a fixed placeholder instead of
    # their repr.
    rendered = ', '.join(
        f"{k}={'<function>' if callable(v) else str(v)}"
        for k, v in sch_params.items()
    )
    return f'{sch_cls.__name__}({rendered})'


def plot_lrs():
    """Plot the learning-rate curve of several schedulers and save ``plot.png``.

    Every scheduler is simulated for 50 epochs from an initial learning rate
    of 1.0 and drawn as one line. The legend is anchored outside the axes, to
    the right, and the figure is written to ``plot.png`` in the current
    working directory.
    """
    lr = 1.0
    n = 50
    schedulers = [
        (torch.optim.lr_scheduler.CosineAnnealingLR, {'T_max': 20}),
        (torch.optim.lr_scheduler.ExponentialLR, {'gamma': 0.95}),
        (torch.optim.lr_scheduler.MultiStepLR, {'milestones': [10, 40], 'gamma': 0.25}),
        (torch.optim.lr_scheduler.StepLR, {'step_size': 5, 'gamma': 0.8}),
        (torch.optim.lr_scheduler.LambdaLR, {'lr_lambda': lambda epoch: 0.9 ** epoch}),
    ]

    fig, ax = plt.subplots()
    for sch_cls, sch_params in schedulers:
        lrs = get_lrs(lr, n, sch_cls, sch_params)
        # Attach the label to the line itself rather than collecting handles
        # in a dict keyed by label text: two configurations that render to the
        # same text would otherwise collapse into a single legend entry.
        ax.plot(range(n), lrs, linewidth=3, label=_scheduler_label(sch_cls, sch_params))

    ax.set_xlabel('epoch')
    ax.set_ylabel('lr')
    ax.grid(axis='both', linestyle='dotted')
    # Anchor the legend just past the right edge so it does not cover the curves.
    ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
    fig.tight_layout()
    fig.savefig('plot.png')
if __name__ == '__main__':
    # Figure size and font size are overridden only inside this context;
    # the global matplotlib rc settings are left untouched afterwards.
    rc_overrides = {'figure.figsize': (10, 3), 'font.size': 13}
    with plt.rc_context(rc_overrides):
        plot_lrs()
