LoginSignup
0
0

More than 5 years have passed since last update.

Chainerの作法 その11

Posted at

概要

Chainerの作法を調べてみた。
今回はDeconvolution2Dを使ってみた。
Convolution2Dで符号化し、Deconvolution2Dで復元するオートエンコーダ（autoencoder）をMNISTで書いてみた。

実行結果（テスト画像先頭25枚の再構成。赤字は正解ラベル）

auto3.png

サンプルコード

import numpy as np
from chainer import datasets, iterators
from chainer import optimizers
from chainer import Chain
from chainer import training
from chainer.training import extensions
import chainer.functions as F
import chainer.links as L
import chainer
import matplotlib.pyplot as plt


class Autoencoder(Chain):
    """Single-layer convolutional autoencoder.

    The encoder is a 5x5, stride-1 convolution producing 16 feature maps
    (input channel count inferred on first call). The decoder is a 5x5,
    stride-1 deconvolution mapping the 16 maps back to a single channel.
    ReLU follows both stages.
    """

    def __init__(self):
        # Build encoder before decoder, matching the original construction order.
        enc = L.Convolution2D(None, 16, 5, 1)
        dec = L.Deconvolution2D(16, 1, 5, 1)
        super(Autoencoder, self).__init__(encoder=enc, decoder=dec)

    def __call__(self, x):
        """Encode ``x`` and return the ReLU-activated reconstruction."""
        encoded = F.relu(self.encoder(x))
        decoded = self.decoder(encoded)
        return F.relu(decoded)

def main():
    """Train the convolutional autoencoder on MNIST and plot reconstructions.

    Steps:
    - Train on the first 1000 MNIST training images (the input is also the
      target) for 80 epochs on the CPU (``device=-1``), logging to ``result``.
    - Reconstruct the first 25 test images.
    - Draw them in a 5x5 grid, each titled with its digit label in red.
    - Save the figure to ``auto3.png`` and show it.
    """
    train, test = datasets.get_mnist()

    def transform(data):
        # The conv layers need (channel, height, width) images; the reshape
        # assumes each sample is a 784-element image vector.
        img, label = data
        return img.reshape((1, 28, 28)), label

    train = datasets.TransformDataset(train, transform)
    test = datasets.TransformDataset(test, transform)

    # For an autoencoder the target is the input itself, so pair each
    # training image with itself and drop the digit labels.
    train_images = [img for img, _ in train[0:1000]]
    train = datasets.TupleDataset(train_images, train_images)
    train_iter = iterators.SerialIterator(train, 100)
    test = test[0:25]

    model = L.Classifier(Autoencoder(), lossfun=F.mean_squared_error)
    # Accuracy is meaningless for a reconstruction loss.
    model.compute_accuracy = False
    optimizer = optimizers.Adam()
    optimizer.setup(model)
    updater = training.StandardUpdater(train_iter, optimizer, device=-1)
    trainer = training.Trainer(updater, (80, 'epoch'), out="result")
    trainer.extend(extensions.LogReport())
    trainer.extend(extensions.PrintReport(['epoch', 'main/loss']))
    trainer.run()

    # Reconstruct all test images in a single batch. No gradients are
    # needed here, so run in test mode and skip building the backprop graph.
    images = np.stack([img for img, _ in test]).astype(np.float32)
    labels = [label for _, label in test]
    with chainer.using_config('train', False), chainer.no_backprop_mode():
        reconstructed = model.predictor(images).data

    for index, (img, label) in enumerate(zip(reconstructed, labels)):
        plt.subplot(5, 5, index + 1)
        plt.axis('off')
        plt.imshow(img.reshape(28, 28), cmap=plt.cm.gray_r, interpolation='nearest')
        plt.title(str(int(label)), color='red')
    plt.savefig("auto3.png")
    plt.show()

if __name__ == '__main__':
    main()


以上。

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0