LoginSignup
11
12

More than 5 years have passed since last update.

Chainerで推論実行するためにsnapshotをload_npzする時はpathの設定が必要

Last updated at Posted at 2018-01-27

【追記】 predictorの前にwith chainer.using_config('train', False):を設定しないと誤差伝番が実行されてしまうので注意

Chainerで推論実行する時はchainer.serializers.load_npz()でモデルを読み込む。chainer.serializers.load_npzで生成されたモデルと違い、extensions.snapshotで生成されたスナップショットからモデルを読み込む場合は注意が必要。

動作環境

  • Chainer 3.2
  • Python 3.5

モデルの読み込み

.modelファイルはchainer.serializers.load_npzによって保存されたファイル。.snapshotextensions.snapshotで生成されたファイルであるとする。snapshotのファイルはネットワークの情報以外にも色々と書き込まれており、推論実行に必要な部分を指定するにはpath='updater/model:main/'の設定が必要。以下はargs.model(=.modelまたは.snapshotのパス)にどちらが格納されていても対応できるようにした一例。

    import chainer
    import chainer.links as L

    model = L.Classifier(適当なネットワークのクラス)

    name, ext = os.path.splitext(os.path.basename(args.model))
    load_path = ''
    if(ext == '.model'):
        print('model read')
    elif(ext == '.snapshot'):
        print('snapshot read')
        load_path = 'updater/model:main/'
    else:
        print('model read error')
        exit()

    chainer.serializers.load_npz(args.model, model, path=load_path)

モデルによる推論実行

モデルを使用する場合は以下のようにする。モデルの生成時にL.Classifierを利用しているため、predictor()で呼び出せば推論実行が可能になる。コンフィグ設定も忘れずに。

    with chainer.using_config('train', False):
        y = model.predictor(xなどの適当な入力データ)

以上。

11
12
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
11
12