【追記】 predictor
の前にwith chainer.using_config('train', False):
を設定しないと誤差伝番が実行されてしまうので注意
Chainerで推論実行する時はchainer.serializers.load_npz()
でモデルを読み込む。chainer.serializers.load_npz
で生成されたモデルと違い、extensions.snapshot
で生成されたスナップショットからモデルを読み込む場合は注意が必要。
動作環境
- Chainer 3.2
- Python 3.5
モデルの読み込み
.model
ファイルはchainer.serializers.load_npz
によって保存されたファイル。.snapshot
はextensions.snapshot
で生成されたファイルであるとする。snapshot
のファイルはネットワークの情報以外にも色々と書き込まれており、推論実行に必要な部分を指定するにはpath='updater/model:main/'
の設定が必要。以下はargs.model
(=.model
または.snapshot
のパス)にどちらが格納されていても対応できるようにした一例。
import chainer
import chainer.links as L
model = L.Classifier(適当なネットワークのクラス)
name, ext = os.path.splitext(os.path.basename(args.model))
load_path = ''
if(ext == '.model'):
print('model read')
elif(ext == '.snapshot'):
print('snapshot read')
load_path = 'updater/model:main/'
else:
print('model read error')
exit()
chainer.serializers.load_npz(args.model, model, path=load_path)
モデルによる推論実行
モデルを使用する場合は以下のようにする。モデルの生成時にL.Classifier
を利用しているため、predictor()
で呼び出せば推論実行が可能になる。コンフィグ設定も忘れずに。
with chainer.using_config('train', False):
y = model.predictor(xなどの適当な入力データ)
以上。