tl; dr
-
extensions.MinValueTrigger
とスナップショットを組み合わせると、chainerで簡単にearly stoppingが実現できる。 - 実装はgistにて公開中
Edit note
2019/10/30追記
はじめに
Deep learningにおいてearly stoppingは簡単ながら強力で、ほぼどんな問題設定においても威力を発揮します。かのHintonも、Early stoppingを
Beautiful FREE LUNCH
Hinton, Bengio and LeCun, in NIPS 2015 Tutorial
と述べています。
そんなEarly stoppingですが、Chainerでは3.0.0時点で公式の実装が存在しません。現在、プルリクで実装が議論されています。 実装されています。このプルリクでは、early stoppingをtriggerとして実装しており、patientで指定した回数だけまって、それでも値が改善にtrainerを終了させます。素直な実装ですが、この実装には2つ欠点があります。
- patientを上げ過ぎると、patientで指定した分だけ過学習がすすんでしまう。
- patientの指定が難しい(もっとまてば良い値が得られたかもしれない)
(Contributorの方を批判する意図はありません。後ほど述べますが、ここに記載の実装は、このプルリクと合わせるといいとこどりになります。念の為。)
そこで、この記事では上記2つの問題を迂回しつつ、最小のコードですむearly stoppingについて記載します。
アプローチと実装
基本的な考え方は、監視している値が最良となるたびにモデルを保存しておき、指定のエポック分学習したあとに最良だったときのモデルを読みだすという方法です。こうすることで、指定したエポックの中では必ず監視した値が最良だったときのモデルを得ることができます。
実装はgistにて公開中です。この実装を使うと、
# accuracyなどを使う場合はMinValueTriggerのかわりにMaxValueTriggerを使う
trainer.extend(
SaveRestore(),
trigger=chainer.training.triggers.MinValueTrigger('validation/main/loss'))
と指定するだけ、 上記にあげた1, 2の問題を迂回して、early stoppingを実現できます。
欠点と解決方法
このアプローチにも、結局最後まで学習を行うので学習時間を短くする効果はない、という弱点があります。ただし、この方法とこのプルリクを組み合わせ、かつpatientの値を緩めにすることによって、最初にあげた1と2の問題を避ける、いいとこどりができます。
2019/10/30追記:
しばらくの間上記実装を使い続けていましたが、近年は
trainer.extend(
chainer.training.extensions.snapshot(filename='best.npz'),
trigger=chainer.training.triggers.MinValueTrigger('validation/main/loss'))
trainer.run()
serializers.load_npz('best.npz', trainer)
のような実装を使う事が多いです。
SaveRestore()
は自動的に不要となったtrainerを消してくれるメリットがありますが、trainer.run()
が1エポック目にこけたときにpdbなどでエラーを追いにくくなるデメリットがあるためです。