12
10

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

Chainer の extension ~精度更新時のモデルの保存・学習率の変化のさせ方~

Last updated at Posted at 2018-11-29

はじめに

よく使いそうなのに意外とサンプルが見つからない Chainer (Trainer) の extension の使い方を2つ載せておきます(要素要素はすべてドキュメントに書いてありますが、まとまった形のサンプルが意外と見つからなかったので...)。

  • Test-accuracy が更新されたときにモデルを保存する
  • 学習率を任意に変化させる

Chainer v4.0.0/v5.0.0 で確認しています。

Test-accuracy が更新されたときにモデルを保存する

MaxValueTrigger を使って、accuracy が更新されたときに snapshot_object extension を呼び出す trigger を作ってあげます。

下の例では、1エポックごとに、現時点で最も accuracy の高いモデルを best_model というファイル名で保存します。

from chainer.training import extensions, triggers

trigger = triggers.MaxValueTrigger('validation/main/accuracy', trigger=(1, 'epoch'))
trainer.extend(extensions.snapshot_object(model, filename='best_model'), trigger=trigger)

学習率を任意に変化させる

下の例では、1エポックごとにcosine関数を使って学習率を変化させていますが、__calc_cosine_lr()を適当に置き換えれば学習率を任意に制御できます。

from chainer.training import extension

class LrSceduler(extension.Extension):

    trigger = (1, 'epoch')

    def __init__(self, base_lr, epochs, optimizer_name='main', lr_name='lr'):
        self._base_lr = base_lr
        self._epochs = epochs
        self._optimizer_name = optimizer_name
        self._lr_name = lr_name

    def __call__(self, trainer):
        optimizer = trainer.updater.get_optimizer(self._optimizer_name)
        lr = self.__calc_cosine_lr(trainer)
        setattr(optimizer, self._lr_name, lr)

    def __calc_cosine_lr(self, trainer):
        import math
        e = trainer.updater.epoch
        #iter = trainer.updater.iteration # If you want to use current total iterations to compute lr.
        lr = 0.5 * self._base_lr * (math.cos(math.pi * e / self._epochs) + 1.)
	
        return lr

trainer.extend(LrSceduler(base_lr=0.01, epochs=100))

なお、定番の、一定エポック学習した後にステップ状に学習率を変化させるなら、
ExponentialShift extension を使うのが良いです。下は、30エポック目で学習率を0.1倍に、60エポック学習した後にさらに0.1倍にする例です。

from chainer.training import extensions, triggers

trigger = triggers.ManualScheduleTrigger([30, 60], 'epoch')
trainer.extend(extensions.ExponentialShift('lr', 0.1), trigger=trigger)

おわりに

また見つかったら追記します。

12
10
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
12
10

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?