LoginSignup
0
2

More than 3 years have passed since last update.

Chainerでオリジナルのsnapshotを作る

Posted at

はじめに

このteratailにあるような現象が発生したため、chainer.serializers.save_npzを使いtrainerを保存するように変更したという話です。

問題事項

※実行環境はChainer v5.4でした。 現在のバージョンでは解決されているかもしれません

以下のようにextensions.snapshotを呼び出してtrainerを保存すると、なぜかchainer.serializers.load_npzで保存したtrainerをロードできないときがありました。

trainer.extend(extensions.snapshot(filename='snapshot_latest'), trigger=(args.snapshot_interval, 'iteration'))

毎回ロードできないのならまだ分かるのですが、場合によって呼び出せたり呼び出せなかったりするという現象が起きました。

で、原因はextensions.snapshotかもしれないとのことだったので、解決策になるかどうかはわかりませんが、自前でchainer.serializers.save_npzを呼び出してtrainerのsnapshotをとることにしました。

実装

ポイントはtrainerが各epochやiterationに合わせた様々値を持っている点でしょうか。
せっかくなので呼び出して有効活用してみます。


import chainer

class SnapshotTrainer(chainer.training.Extension):
    def __init__(self, out, fo_name='snapshot', save_type='latest_only'):
        self.out = out.rstrip('/') + '/' + fo_name
        self.save_type = save_type

    def __call__(self, trainer):
        if self.save_type == 'epoch':
            fo = '{}_epoch_{}'.format(self.out, trainer.updater.epoch)
        elif self.save_type == 'iteration':
            fo = '{}_iter_{}'.format(self.out, trainer.updater.iteration)
        else:
            fo = '{}_latest'.format(self.out)

        chainer.serializers.save_npz(fo, trainer)

↑のクラスをtrainer.extendで呼び出します。


# Take a snapshot (original)
snapshot_type = 'epoch' if args.snapshot_dvide else 'latest'
trainer.extend(SnapshotTrainer(out=args.out,  fo_name='snapshot', save_type=snapshot_type), trigger=(args.snapshot_interval, 'iteration'))

これで解決できている・・・のかなぁ?(今のところ問題なしですが、果たして・・・

まとめ

今回はsnapshotをchainer.serializers.save_npzで保存するExtensionを作りました。

これならsnapshot以外にも色々なExtensionを作れそうです。

参考

0
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
2