LoginSignup
0
2

More than 5 years have passed since last update.

Chainerの作法 その8

Posted at

概要

Chainerの作法を調べてみた。
学習したモデルを save_npz で保存し、load_npz で読み込んでみた。

サンプルコード

save_npz で保存するコード。MNIST で MLP を 5 エポック学習し、学習済みモデルを mnist3.model に保存する。

import chainer
import chainer.functions as F
import chainer.links as L
from chainer import training, datasets, iterators, Chain
from chainer.training import extensions

class MLP(Chain):
    """Multi-layer perceptron: two ReLU hidden layers and a linear output layer.

    Args:
        n_units: Width of each of the two hidden layers.
        n_out: Number of output units (10 for the MNIST classes used below).
    """

    def __init__(self, n_units, n_out):
        super(MLP, self).__init__()
        with self.init_scope():
            # in_size=None: Chainer infers the input width on the first forward pass.
            self.l1 = L.Linear(None, n_units)
            self.l2 = L.Linear(None, n_units)
            self.l3 = L.Linear(None, n_out)

    def __call__(self, x):
        """Return the raw (un-activated) output scores of the last layer for batch ``x``."""
        hidden = x
        for layer in (self.l1, self.l2):
            hidden = F.relu(layer(hidden))
        # No activation on the last layer; L.Classifier applies the loss itself.
        return self.l3(hidden)

def main():
    """Train the MLP on MNIST for 5 epochs and save its weights to ``mnist3.model``.

    Runs on the CPU (device -1). Trainer output (the LogReport log) goes to the
    ``result`` directory, and PrintReport prints epoch, loss, accuracy and
    elapsed time.
    """
    train_data, test_data = datasets.get_mnist()
    batch_size = 100
    train_iter = iterators.SerialIterator(train_data, batch_size)
    # Evaluation must visit every test example exactly once, in a fixed order.
    test_iter = iterators.SerialIterator(
        test_data, batch_size, repeat=False, shuffle=False)

    model = L.Classifier(MLP(1000, 10))
    optimizer = chainer.optimizers.Adam()
    optimizer.setup(model)

    cpu = -1
    updater = training.updaters.StandardUpdater(train_iter, optimizer, device=cpu)
    trainer = training.Trainer(updater, (5, 'epoch'), out='result')
    reported = ['epoch', 'main/loss', 'main/accuracy', 'elapsed_time']
    for extension in (extensions.Evaluator(test_iter, model, device=cpu),
                      extensions.LogReport(),
                      extensions.PrintReport(reported)):
        trainer.extend(extension)
    trainer.run()

    chainer.serializers.save_npz('mnist3.model', model)


if __name__ == '__main__':
    main()



サンプルコード

load_npz で読み込むコード。保存した mnist3.model を読み込み、テストデータ先頭 10 件について正解ラベルと予測ラベルを表示する。

import chainer
import chainer.functions as F
import chainer.links as L
from chainer import training
from chainer.training import extensions

class MLP(chainer.Chain):
    """Multi-layer perceptron matching the architecture used when the model was saved.

    Layer names (l1, l2, l3) must match the saved file so load_npz can
    restore the weights.

    Args:
        n_units: Width of each of the two hidden layers.
        n_out: Number of output units.
    """

    def __init__(self, n_units, n_out):
        super(MLP, self).__init__()
        with self.init_scope():
            # Input widths are left as None and inferred from the data.
            self.l1 = L.Linear(None, n_units)
            self.l2 = L.Linear(None, n_units)
            self.l3 = L.Linear(None, n_out)

    def __call__(self, x):
        """Apply l1 -> ReLU -> l2 -> ReLU -> l3 and return the final layer's output."""
        return self.l3(F.relu(self.l2(F.relu(self.l1(x)))))

def main():
    """Load the trained MLP from ``mnist3.model`` and print a few predictions.

    For each of the first 10 MNIST test images, prints the true label and
    then the index of the highest output score (the predicted label).
    """
    model = L.Classifier(MLP(1000, 10))
    _, test = chainer.datasets.get_mnist()
    chainer.serializers.load_npz('mnist3.model', model)
    # Prediction only: run in test mode and skip recording the backprop graph,
    # since backward() is never called here.
    with chainer.using_config('train', False), chainer.no_backprop_mode():
        for i in range(10):
            x, t = test[i]
            print('label:', t)
            # Add a leading batch axis of size 1 before the forward pass.
            y = model.predictor(x[None, ...])
            print('predicted_label:', y.data.argmax(axis=1)[0])


if __name__ == '__main__':
    main()



以上。

0
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
2