はじめに
このteratailにあるような現象が発生したため、chainer.serializers.save_npz
を使いtrainerを保存するように変更したという話です。
問題事項
※実行環境はChainer v5.4でした。 現在のバージョンでは解決されているかもしれません
以下のようにextensions.snapshot
を呼び出してtrainerを保存すると、なぜかchainer.serializers.load_npz
で保存したtrainerをロードできないときがありました。
trainer.extend(extensions.snapshot(filename='snapshot_latest'), trigger=(args.snapshot_interval, 'iteration'))
毎回ロードできないのならまだ分かるのですが、場合によって呼び出せたり呼び出せなかったりするという現象が起きました。
で、原因はextensions.snapshot
かもしれないとのことだったので、解決策になるかどうかはわかりませんが、自前でchainer.serializers.save_npz
を呼び出してtrainerのsnapshotをとることにしました。
実装
ポイントはtrainer
が各epochやiterationに合わせた様々値を持っている点でしょうか。
せっかくなので呼び出して有効活用してみます。
import chainer
class SnapshotTrainer(chainer.training.Extension):
def __init__(self, out, fo_name='snapshot', save_type='latest_only'):
self.out = out.rstrip('/') + '/' + fo_name
self.save_type = save_type
def __call__(self, trainer):
if self.save_type == 'epoch':
fo = '{}_epoch_{}'.format(self.out, trainer.updater.epoch)
elif self.save_type == 'iteration':
fo = '{}_iter_{}'.format(self.out, trainer.updater.iteration)
else:
fo = '{}_latest'.format(self.out)
chainer.serializers.save_npz(fo, trainer)
↑のクラスをtrainer.extendで呼び出します。
# Take a snapshot (original)
snapshot_type = 'epoch' if args.snapshot_dvide else 'latest'
trainer.extend(SnapshotTrainer(out=args.out, fo_name='snapshot', save_type=snapshot_type), trigger=(args.snapshot_interval, 'iteration'))
これで解決できている・・・のかなぁ?(今のところ問題なしですが、果たして・・・
まとめ
今回はsnapshotをchainer.serializers.save_npz
で保存するExtensionを作りました。
これならsnapshot以外にも色々なExtensionを作れそうです。