Chainer
MNIST
Autoencoder
Deconvolution2D

概要

Chainerの作法を調べてみた。
Deconvolution2Dを使ってみた。
Autoencoderを書いてみた。

写真

![auto3.png](auto3.png)

学習後のモデルでテストデータ先頭25枚を再構成した結果（赤字は正解ラベル）。

サンプルコード

import numpy as np
from chainer import datasets, iterators
from chainer import optimizers
from chainer import Chain
from chainer import training
from chainer.training import extensions
import chainer.functions as F
import chainer.links as L
import chainer
import matplotlib.pyplot as plt


class Autoencoder(Chain):
    """Single-layer convolutional autoencoder.

    The encoder is a 5x5, stride-1 convolution producing 16 feature maps.
    The decoder is a 5x5, stride-1 transposed convolution mapping those 16
    maps back to one channel. With no padding, a 28x28 input becomes 24x24
    feature maps, and the decoder restores a 28x28 output.
    """

    def __init__(self):
        super(Autoencoder, self).__init__()
        # Child links are registered inside init_scope(). Passing links as
        # keyword arguments to Chain.__init__ is deprecated since Chainer v2.
        with self.init_scope():
            # in_channels=None: inferred from the first input batch.
            self.encoder = L.Convolution2D(None, 16, 5, 1)
            self.decoder = L.Deconvolution2D(16, 1, 5, 1)

    def __call__(self, x):
        """Encode then decode ``x`` and return the reconstruction."""
        h = F.relu(self.encoder(x))
        # ReLU on the output keeps the reconstruction non-negative.
        return F.relu(self.decoder(h))

def _to_image(data):
    """Reshape a flat (image, label) sample into a (1, 28, 28) image."""
    img, label = data
    return img.reshape((1, 28, 28)), label


def _load_data(n_train=1000, n_test=25):
    """Return ``(train_dataset, test_samples)`` built from MNIST.

    ``train_dataset`` pairs each of the first ``n_train`` training images
    with itself (input == target). ``test_samples`` is a list of the first
    ``n_test`` ``(image, label)`` tuples from the test split.
    """
    train, test = datasets.get_mnist()
    train = datasets.TransformDataset(train, _to_image)
    test = datasets.TransformDataset(test, _to_image)
    train_images = [img for img, _ in train[0:n_train]]
    return datasets.TupleDataset(train_images, train_images), test[0:n_test]


def _train(model, train, batch_size=100, epochs=80, out="result"):
    """Train ``model`` on ``train`` with Adam on the CPU, logging loss per epoch."""
    train_iter = iterators.SerialIterator(train, batch_size)
    optimizer = optimizers.Adam()
    optimizer.setup(model)
    updater = training.StandardUpdater(train_iter, optimizer, device=-1)
    trainer = training.Trainer(updater, (epochs, 'epoch'), out=out)
    trainer.extend(extensions.LogReport())
    trainer.extend(extensions.PrintReport(['epoch', 'main/loss']))
    trainer.run()


def _plot_reconstructions(predictor, samples, path="auto3.png"):
    """Reconstruct ``samples`` and save/show them in a 5x5 grid.

    Each subplot is titled (in red) with the sample's ground-truth label.
    At most 25 samples fit the grid.
    """
    images = np.array([img for img, _ in samples], dtype=np.float32)
    # Inference only: disable train-mode behavior and skip building the
    # backprop graph. All samples go through the network in one batch.
    with chainer.using_config('train', False), chainer.no_backprop_mode():
        recon = predictor(images).data
    for index, (img, (_, label)) in enumerate(zip(recon, samples)):
        plt.subplot(5, 5, index + 1)
        plt.axis('off')
        plt.imshow(img.reshape(28, 28), cmap=plt.cm.gray_r,
                   interpolation='nearest')
        plt.title(str(int(label)), color='red')
    plt.savefig(path)
    plt.show()


def main():
    """Train the autoencoder on 1000 MNIST images and plot 25 reconstructions."""
    train, test = _load_data()
    # L.Classifier is used here only as a loss wrapper around the predictor.
    # lossfun is MSE, and the dataset's target equals its input.
    model = L.Classifier(Autoencoder(), lossfun=F.mean_squared_error)
    model.compute_accuracy = False  # accuracy has no meaning for reconstruction
    _train(model, train)
    _plot_reconstructions(model.predictor, test)

if __name__ == '__main__':
    main()


以上。