はじめに
Chainerのexampleコードを動かして、深層学習(ディープラーニング)を勉強中です。
本記事を起稿した時点(2017/6)でChainerの最新バージョンは2.0ですが、1.x と互換性がなく古いバージョンのコードが動かないことがあります。
参考:chainerのバージョンごとの違い(2016年1月19日現在)
本記事は、Chainer 2.0 のMNISTサンプルで、推論を動かすための実装メモです。
実装は、こちらの記事を参考にしました。
Chainer: ビギナー向けチュートリアル Vol.1
環境
Chainer 2.0
python 2.7.10
CPUで実行
コード
Chainer 2.0 のMNISTサンプル(オリジナル)
https://github.com/chainer/chainer/tree/v2.0.0/examples/mnist
1. train_mnist.pyに学習済みモデルを保存する処理(1行)を追加
# Run the training
trainer.run()
chainer.serializers.save_npz('my_mnist.model', model) # Added
2. train_mnist.pyを実行し学習を開始する
$ python train_mnist.py --epoch 3
GPU: -1
# unit: 1000
# Minibatch-size: 100
# epoch: 3
epoch main/loss validation/main/loss main/accuracy validation/main/accuracy elapsed_time
1 0.191836 0.0885223 0.942233 0.9718 26.099
2 0.0726428 0.0825069 0.9768 0.974 53.4849
3 0.0466335 0.0751425 0.984983 0.9747 81.2683
$ ls
my_mnist.model result/ train_mnist.py*
※デフォルトのepoch=20だと学習に少し時間がかかるので、今回はepoch=3としています。
MacBook Pro(Mid2015)だと1分ちょいで学習が完了します。
3. 保存した学習済みモデルを読み込んで推論する
#!/usr/bin/env python
from __future__ import print_function
try:
import matplotlib
matplotlib.use('Agg')
except ImportError:
pass
import argparse
import chainer
import chainer.functions as F
import chainer.links as L
from chainer import training
from chainer.training import extensions
# Network definition
class MLP(chainer.Chain):
def __init__(self, n_units, n_out):
super(MLP, self).__init__()
with self.init_scope():
# the size of the inputs to each layer will be inferred
self.l1 = L.Linear(None, n_units) # n_in -> n_units
self.l2 = L.Linear(None, n_units) # n_units -> n_units
self.l3 = L.Linear(None, n_out) # n_units -> n_out
def __call__(self, x):
h1 = F.relu(self.l1(x))
h2 = F.relu(self.l2(h1))
return self.l3(h2)
def main():
parser = argparse.ArgumentParser(description='Chainer example: MNIST')
parser.add_argument('--unit', '-u', type=int, default=1000,
help='Number of units')
args = parser.parse_args()
print('# unit: {}'.format(args.unit))
print('')
# Set up a neural network
model = L.Classifier(MLP(args.unit, 10))
# Load the MNIST dataset
train, test = chainer.datasets.get_mnist()
chainer.serializers.load_npz('my_mnist.model', model)
x, t = test[0]
print('label:', t)
x = x[None, ...]
y = model.predictor(x)
y = y.data
print('predicted_label:', y.argmax(axis=1)[0])
if __name__ == '__main__':
main()
predict_mnist.py では、my_mnist.modelを読み込んで、テストデータに対するラベルの推論をしています。
$ python predict_mnist.py
# unit: 1000
label: 7
predicted_label: 7
正解ラベルと同じラベルが得られました。
モデルのオブジェクトを作るときの注意点
# iteration, which will be used by the PrintReport extension below.
model = L.Classifier(MLP(args.unit, 10))
train_mnist.py で、 L.Classifierを使ってmodelを作りました。
推論時にモデルのオブジェクトを作るときも同様にL.Classifierを使う必要があります。
L.Classifier を通さずにモデルのオブジェクトを作ると、モデルをロードしたときにエラーが返ってきます。
# Set up a neural network
model = MLP(args.unit, 10)
エラー
KeyError: 'l2/b is not a file in the archive'