学習率のスケジューリング
Deep Learningで学習を進めるためには、学習率の減衰(learning rate decay)が必要です。Chainer(v4)ではTrainer
のExtension
で学習率の変更ができます。Extension
についてはChainer v4 ビギナー向けチュートリアルに詳しく説明されています。
Optimizerのパラメータを変更するExtensions
学習率lr
に限らず、Optimizerのパラメータを学習途中で変更することができます。ExponentialShift
とLinearShift
があり、ExponentialShift
はパラメータを指数関数的に変化させ、LinearShift
はパラメータを線形に変化させることができます。今回はこのうちのExponentialShift
を使ってみます。
triggerについて
Chainer v4 ビギナー向けチュートリアルでは学習率を変更する方法として、
trainer.extend(extensions.ExponentialShift('lr', 0.1), trigger=(30, 'epoch'))
という指定をしており、これは『30epochごとに学習率を0.1倍する』という指定です。ここでポイントとなるのはtrigger
オブジェクトです。trigger
はExtension
の実行タイミングを指定するものです。デフォルトでは、trigger
をタプルで指定すると、IntervalTrigger
というtrigger
オブジェクトにタプルが渡され、一定期間ごとに呼び出されます。trigger
の一覧は以下のドキュメントの一番下に書いてあります。
ManualScheduleTriggerを使ってみる
IntervalTrigger
でも良いときは多いのですが、例えば学習の50%, 75%, 90%のタイミングで学習率を変更するなどの指定はできません。このような場合はManualScheduleTrigger
を用います。例えば以下のようにExtension
を書いてみます。
trainer.extend(extensions.ExponentialShift('lr', 0.1),
trigger=triggers.ManualScheduleTrigger([1,3,6],'epoch'))
このように書くと、1,3,6 epoch目で学習率が0.1倍されます。適当に学習をさせてみると次のようになりました。
epoch main/loss validation/main/loss main/accuracy validation/main/accuracy lr elapsed_time
0 2.02627 0.262109 0.01 5.88272
1 0.988403 1.34453 0.655859 0.544899 0.001 89.8289
2 0.780562 0.840948 0.726953 0.704239 0.001 178.381
3 0.795094 0.833529 0.723437 0.707273 0.0001 266.596
4 0.730944 0.792991 0.74375 0.718957 0.0001 362.734
5 0.708854 0.786186 0.761328 0.722208 0.0001 455.003
6 0.708609 0.784013 0.748438 0.72531 1e-05 552.303
ちゃんと指定したepochで学習率が下がっています。'epoch'
は'iteration'
とすることもできます。詳しくはDocumentを読んでください。
ちなみにlr
の変化を見るには
trainer.extend(extensions.observe_lr(), trigger=(1, 'iteration'))
trainer.extend(extensions.PrintReport(['epoch', 'main/loss', 'validation/main/loss',
'main/accuracy', 'validation/main/accuracy',
'lr', 'elapsed_time']), trigger=(1, 'iteration'))
のようにextensions.observe_lr()
を実行し、extensions.PrintReport
の引数にlr
を付け加えると良いです。
というわけで学習の50%, 75%, 90%のタイミングで学習率を変更するには
epoch = 100
points = [epoch*0.5, epoch*0.75, epoch*0.9]
trainer.extend(extensions.ExponentialShift('lr', 0.1),
trigger=triggers.ManualScheduleTrigger(points, 'epoch'))
とします。points
はfloat
も受け取れます。例えば、
points = 37.5
37 epoch目の50%の段階でExponentialShift
が実行されます。