LoginSignup
10
7

More than 5 years have passed since last update.

Chainerで学習率のスケジューリングをする方法

Last updated at Posted at 2018-08-02

学習率のスケジューリング

Deep Learningで学習を進めるためには、学習率の減衰(learning rate decay)が必要です。Chainer(v4)ではTrainerExtensionで学習率の変更ができます。ExtensionについてはChainer v4 ビギナー向けチュートリアルに詳しく説明されています。

Optimizerのパラメータを変更するExtensions

学習率lrに限らず、Optimizerのパラメータを学習途中で変更することができます。ExponentialShiftLinearShiftがあり、ExponentialShiftはパラメータを指数関数的に変化させ、LinearShiftはパラメータを線形に変化させることができます。今回はこのうちのExponentialShiftを使ってみます。

triggerについて

Chainer v4 ビギナー向けチュートリアルでは学習率を変更する方法として、


trainer.extend(extensions.ExponentialShift('lr', 0.1), trigger=(30, 'epoch'))

という指定をしており、これは『30epochごとに学習率を0.1倍する』という指定です。ここでポイントとなるのはtriggerオブジェクトです。triggerExtensionの実行タイミングを指定するものです。デフォルトでは、triggerをタプルで指定すると、IntervalTriggerというtriggerオブジェクトにタプルが渡され、一定期間ごとに呼び出されます。triggerの一覧は以下のドキュメントの一番下に書いてあります。

ManualScheduleTriggerを使ってみる

IntervalTriggerでも良いときは多いのですが、例えば学習の50%, 75%, 90%のタイミングで学習率を変更するなどの指定はできません。このような場合はManualScheduleTriggerを用います。例えば以下のようにExtensionを書いてみます。


trainer.extend(extensions.ExponentialShift('lr', 0.1),   
               trigger=triggers.ManualScheduleTrigger([1,3,6],'epoch'))

このように書くと、1,3,6 epoch目で学習率が0.1倍されます。適当に学習をさせてみると次のようになりました。

epoch       main/loss   validation/main/loss  main/accuracy  validation/main/accuracy  lr          elapsed_time
0           2.02627                           0.262109                                 0.01        5.88272       
1           0.988403    1.34453               0.655859       0.544899                  0.001       89.8289       
2           0.780562    0.840948              0.726953       0.704239                  0.001       178.381       
3           0.795094    0.833529              0.723437       0.707273                  0.0001      266.596       
4           0.730944    0.792991              0.74375        0.718957                  0.0001      362.734       
5           0.708854    0.786186              0.761328       0.722208                  0.0001      455.003       
6           0.708609    0.784013              0.748438       0.72531                   1e-05       552.303       

ちゃんと指定したepochで学習率が下がっています。'epoch''iteration'とすることもできます。詳しくはDocumentを読んでください。

ちなみにlrの変化を見るには


trainer.extend(extensions.observe_lr(), trigger=(1, 'iteration'))
trainer.extend(extensions.PrintReport(['epoch', 'main/loss', 'validation/main/loss',
                                       'main/accuracy', 'validation/main/accuracy',
                                       'lr', 'elapsed_time']), trigger=(1, 'iteration'))

のようにextensions.observe_lr()を実行し、extensions.PrintReportの引数にlrを付け加えると良いです。

というわけで学習の50%, 75%, 90%のタイミングで学習率を変更するには


epoch = 100
points = [epoch*0.5, epoch*0.75, epoch*0.9]
trainer.extend(extensions.ExponentialShift('lr', 0.1), 
               trigger=triggers.ManualScheduleTrigger(points, 'epoch'))

とします。pointsfloatも受け取れます。例えば、


points = 37.5

37 epoch目の50%の段階でExponentialShiftが実行されます。

参考

10
7
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
10
7