LoginSignup
5
10

More than 5 years have passed since last update.

Chainer 2.0 のMNISTサンプルで推論を動かす

Posted at

はじめに

Chainerのexampleコードを動かして、深層学習(ディープラーニング)を勉強中です。

本記事を起稿した時点(2017/6)でChainerの最新バージョンは2.0ですが、1.x と互換性がなく古いバージョンのコードが動かないことがあります。
参考:chainerのバージョンごとの違い(2016年1月19日現在)

本記事は、Chainer 2.0 のMNISTサンプルで、推論を動かすための実装メモです。

実装は、こちらの記事を参考にしました。
Chainer: ビギナー向けチュートリアル Vol.1

環境

Chainer 2.0
python 2.7.10
CPUで実行

コード

Chainer 2.0 のMNISTサンプル(オリジナル)
https://github.com/chainer/chainer/tree/v2.0.0/examples/mnist

1. train_mnist.pyに学習済みモデルを保存する処理(1行)を追加

train_mnist.py
    # Run the training
    trainer.run()

    chainer.serializers.save_npz('my_mnist.model', model) # Added

2. train_mnist.pyを実行し学習を開始する

$ python train_mnist.py --epoch 3
GPU: -1
# unit: 1000
# Minibatch-size: 100
# epoch: 3

epoch       main/loss   validation/main/loss  main/accuracy  validation/main/accuracy  elapsed_time
1           0.191836    0.0885223             0.942233       0.9718                    26.099        
2           0.0726428   0.0825069             0.9768         0.974                     53.4849       
3           0.0466335   0.0751425             0.984983       0.9747                    81.2683       
$ ls
my_mnist.model  result/         train_mnist.py*

※デフォルトのepoch=20だと学習に少し時間がかかるので、今回はepoch=3としています。
MacBook Pro(Mid2015)だと1分ちょいで学習が完了します。

3. 保存した学習済みモデルを読み込んで推論する

predict_mnist.py
#!/usr/bin/env python

from __future__ import print_function 

try:
    import matplotlib
    matplotlib.use('Agg')
except ImportError:
    pass

import argparse

import chainer
import chainer.functions as F
import chainer.links as L
from chainer import training
from chainer.training import extensions


# Network definition
class MLP(chainer.Chain):

    def __init__(self, n_units, n_out):
        super(MLP, self).__init__()
        with self.init_scope():
            # the size of the inputs to each layer will be inferred
            self.l1 = L.Linear(None, n_units)  # n_in -> n_units
            self.l2 = L.Linear(None, n_units)  # n_units -> n_units
            self.l3 = L.Linear(None, n_out)  # n_units -> n_out

    def __call__(self, x):
        h1 = F.relu(self.l1(x))
        h2 = F.relu(self.l2(h1))
        return self.l3(h2)


def main():
    parser = argparse.ArgumentParser(description='Chainer example: MNIST')
    parser.add_argument('--unit', '-u', type=int, default=1000,
                        help='Number of units')
    args = parser.parse_args()

    print('# unit: {}'.format(args.unit))
    print('')

    # Set up a neural network
    model = L.Classifier(MLP(args.unit, 10))

    # Load the MNIST dataset
    train, test = chainer.datasets.get_mnist()

    chainer.serializers.load_npz('my_mnist.model', model)

    x, t = test[0]
    print('label:', t)

    x = x[None, ...]
    y = model.predictor(x)
    y = y.data

    print('predicted_label:', y.argmax(axis=1)[0])

if __name__ == '__main__':
    main()

predict_mnist.py では、my_mnist.modelを読み込んで、テストデータに対するラベルの推論をしています。

$ python predict_mnist.py 
# unit: 1000

label: 7
predicted_label: 7

正解ラベルと同じラベルが得られました。

モデルのオブジェクトを作るときの注意点

train_mnist.py
    # iteration, which will be used by the PrintReport extension below.
    model = L.Classifier(MLP(args.unit, 10))

train_mnist.py で、 L.Classifierを使ってmodelを作りました。
推論時にモデルのオブジェクトを作るときも同様にL.Classifierを使う必要があります。

L.Classifier を通さずにモデルのオブジェクトを作ると、モデルをロードしたときにエラーが返ってきます。

predict_mnist.py
    # Set up a neural network
    model = MLP(args.unit, 10)

エラー
KeyError: 'l2/b is not a file in the archive'

参考 Chainerのモデルのセーブとロード

5
10
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
5
10