Chainer
MNIST

chainerの作法 その8

概要

chainerの作法を調べてみた。
学習済みモデルをsave_npzで保存して、load_npzで読み込んでみた。

サンプルコード

save_npzするやつ。

import chainer
import chainer.functions as F
import chainer.links as L
from chainer import training, datasets, iterators, Chain
from chainer.training import extensions

class MLP(Chain):
    """Three-layer fully connected network with ReLU hidden activations.

    Args:
        n_units: Output size of each of the two hidden layers.
        n_out: Output size of the final (un-activated) layer.
    """

    def __init__(self, n_units, n_out):
        super(MLP, self).__init__()
        with self.init_scope():
            # The input size is passed as None for every layer, so it is not
            # fixed at construction time.
            self.l1 = L.Linear(None, n_units)
            self.l2 = L.Linear(None, n_units)
            self.l3 = L.Linear(None, n_out)

    def __call__(self, x):
        """Apply l1 and l2 with ReLU, then return the raw output of l3."""
        hidden = x
        for layer in (self.l1, self.l2):
            hidden = F.relu(layer(hidden))
        return self.l3(hidden)

def main():
    """Train the MLP classifier on MNIST and save it to 'mnist3.model'.

    Runs 5 epochs with Adam and a batch size of 100, evaluates on the test
    split, writes trainer output under 'result', and finally serializes the
    trained classifier in NPZ format.
    """
    train_set, test_set = datasets.get_mnist()
    train_batches = iterators.SerialIterator(train_set, 100)
    # The evaluation iterator makes a single ordered pass over the test split.
    test_batches = iterators.SerialIterator(
        test_set, 100, repeat=False, shuffle=False)

    classifier = L.Classifier(MLP(1000, 10))
    adam = chainer.optimizers.Adam()
    adam.setup(classifier)

    # device=-1 is passed to both the updater and the evaluator.
    step = training.updaters.StandardUpdater(train_batches, adam, device=-1)
    trainer = training.Trainer(step, (5, 'epoch'), out='result')

    reported_keys = ['epoch', 'main/loss', 'main/accuracy', 'elapsed_time']
    for extension in (
            extensions.Evaluator(test_batches, classifier, device=-1),
            extensions.LogReport(),
            extensions.PrintReport(reported_keys),
    ):
        trainer.extend(extension)

    trainer.run()
    chainer.serializers.save_npz('mnist3.model', classifier)


# Call main() only when this file is executed as a script, not when imported.
if __name__ == '__main__':
    main()



サンプルコード

load_npzするやつ。

import chainer
import chainer.functions as F
import chainer.links as L
from chainer import training
from chainer.training import extensions

class MLP(chainer.Chain):
    """Multi-layer perceptron: two ReLU hidden layers and a linear output.

    Args:
        n_units: Output size of each hidden layer.
        n_out: Output size of the last layer.
    """

    def __init__(self, n_units, n_out):
        super(MLP, self).__init__()
        with self.init_scope():
            # All three layers are built with an input size of None.
            self.l1 = L.Linear(None, n_units)
            self.l2 = L.Linear(None, n_units)
            self.l3 = L.Linear(None, n_out)

    def __call__(self, x):
        """Return l3(relu(l2(relu(l1(x)))))."""
        first_hidden = F.relu(self.l1(x))
        second_hidden = F.relu(self.l2(first_hidden))
        output = self.l3(second_hidden)
        return output

def main():
    """Load the classifier saved in 'mnist3.model' and predict 10 test digits.

    For each of the first ten MNIST test samples, prints the ground-truth
    label followed by the label predicted by the restored model.
    """
    model = L.Classifier(MLP(1000, 10))
    # Only the test split is needed here; the training split is discarded.
    _, test = chainer.datasets.get_mnist()
    chainer.serializers.load_npz('mnist3.model', model)

    # Inference only: switch to test mode and skip building the computational
    # graph, which would otherwise be retained for a backward pass that never
    # happens.
    with chainer.using_config('train', False), chainer.no_backprop_mode():
        for i in range(10):
            x, t = test[i]
            print('label:', t)
            # Add a leading axis so the single sample is passed as a batch.
            x = x[None, ...]
            y = model.predictor(x)
            y = y.data
            print('predicted_label:', y.argmax(axis=1)[0])

# Call main() only when this file is executed as a script, not when imported.
if __name__ == '__main__':
    main()



以上。