MNIST
Autoencoder
jsdo
convnetjs

概要

convnetjsを使って、MNISTの手書き数字画像(28×28)を32ユニットの中間層で圧縮・復元するautoencoderを試してみた。

写真

image.png

サンプルコード

// Image path for every batch. The list is written in batch order and then
// turned into a lookup object keyed by the batch number as a string
// ('0' through '20'), which is how load_data_batch indexes it.
var url = [
    '/assets/A/j/W/3/AjW3t.png',
    '/assets/e/l/w/h/elwhh.png',
    '/assets/Y/0/k/e/Y0kel.png',
    '/assets/8/O/x/F/8OxFx.png',
    '/assets/q/P/x/6/qPx60.png',
    '/assets/6/j/8/t/6j8tg.png',
    '/assets/S/H/y/u/SHyuZ.png',
    '/assets/q/Z/G/o/qZGoq.png',
    '/assets/o/D/k/E/oDkE9.png',
    '/assets/2/N/L/G/2NLGe.png',
    '/assets/Y/b/w/I/YbwIP.png',
    '/assets/w/e/S/9/weS9c.png',
    '/assets/K/9/6/V/K96Ve.png',
    '/assets/G/2/Y/x/G2Yxd.png',
    '/assets/K/X/8/d/KX8dt.png',
    '/assets/e/8/E/6/e8E6f.png',
    '/assets/S/N/1/Z/SN1Z2.png',
    '/assets/C/5/G/G/C5GGt.png',
    '/assets/s/8/n/V/s8nVW.png',
    '/assets/2/E/a/d/2EadT.png',
    '/assets/K/c/m/W/KcmWd.png'
].reduce(function(table, path, batch) {
    table[String(batch)] = path;
    return table;
}, {});
// Autoencoder definition: a 28x28x1 input, one fully connected hidden layer
// of 32 neurons, and a regression output with one value per input pixel.
var layer_defs, 
    net, 
    trainer;
layer_defs = [];
layer_defs.push({
    type: 'input', 
    out_sx: 28, 
    out_sy: 28, 
    out_depth: 1
});
layer_defs.push({
    type: 'fc', 
    num_neurons: 32,
    // Was misspelled 'tahn', which is not an activation name ConvNetJS
    // knows, so the hidden layer ended up with no non-linearity at all.
    activation: 'tanh'
});
layer_defs.push({
    type: 'regression', 
    num_neurons: 28 * 28
});
net = new convnetjs.Net();
net.makeLayers(layer_defs);
// Adagrad trainer with both L1 and L2 weight decay enabled.
trainer = new convnetjs.SGDTrainer(net, {
    learning_rate: 0.01, 
    method: 'adagrad',
    batch_size: 20, 
    l2_decay: 0.001,
    l1_decay: 0.001
});
// Number of batch images; matches the number of entries in `url`.
var num_batches = 21;
// Per-batch slots, indexed by batch number and empty until a batch is
// requested/loaded: the Image element, its decoded pixel data, and a
// "finished loading" flag.
var data_img_elts = Array(num_batches),
    img_data = Array(num_batches),
    loaded = Array(num_batches);
// Batch numbers that have finished loading and may be sampled for training.
var loaded_train_batches = [];
// The next three are not read anywhere in the visible sample code.
var paused = false;
var embed_samples = [],
    embed_imgs = [];
// Step counter consulted by sample_training_instance to decide when to
// request another batch.
var step_num = 0;
/**
 * Starts loading the image for one batch and, when it arrives, caches its
 * RGBA pixel data in img_data[batch_num] and marks loaded[batch_num].
 *
 * Every batch except the last one (index num_batches - 1) is also added to
 * loaded_train_batches so it can be sampled for training. Completion is
 * announced with alert(batch_num); a failed download is reported through
 * console.error.
 *
 * Calling this again for a batch that is already requested or loaded does
 * nothing; after a failed download the batch may be requested again.
 *
 * @param {number} batch_num  key into `url` (0 .. num_batches - 1).
 */
var load_data_batch = function(batch_num) {
    // A second request would make onload fire twice and register the batch
    // in loaded_train_batches twice, skewing the random batch choice.
    if (data_img_elts[batch_num])
    {
        return;
    }
    var data_img_elt = new Image();
    data_img_elts[batch_num] = data_img_elt;
    data_img_elt.onload = function() { 
        // Draw the image onto an off-screen canvas so its pixels can be read.
        var data_canvas = document.createElement('canvas');
        data_canvas.width = data_img_elt.width;
        data_canvas.height = data_img_elt.height;
        var data_ctx = data_canvas.getContext("2d");
        data_ctx.drawImage(data_img_elt, 0, 0); 
        img_data[batch_num] = data_ctx.getImageData(0, 0, data_canvas.width, data_canvas.height);
        loaded[batch_num] = true;
        alert(batch_num);
        // The last batch is kept out of the training pool.
        if (batch_num < num_batches - 1)
        {
            loaded_train_batches.push(batch_num);
        }
    };
    data_img_elt.onerror = function() {
        // Forget the failed request so a later call can try again.
        data_img_elts[batch_num] = undefined;
        console.error('load_data_batch: failed to load batch ' + batch_num + ' from ' + url[batch_num]);
    };
    data_img_elt.src = url[batch_num];
}
/**
 * Picks one random image from a randomly chosen loaded training batch.
 *
 * A batch's pixel data is treated as consecutive 28*28-pixel images; image
 * number k (0 .. 2999) starts at pixel 784 * k. Only the first (red)
 * channel of each RGBA pixel is read, divided by 255.
 *
 * @returns {{x: convnetjs.Vol}} the image as a 28x28x1 volume.
 * @throws {Error} when no training batch has finished loading yet.
 */
var sample_training_instance = function() {
    // Without this check an empty pool ends in an opaque TypeError on
    // img_data[undefined].data further down.
    if (loaded_train_batches.length === 0)
    {
        throw new Error('sample_training_instance: no training batch has finished loading yet');
    }
    var bi = Math.floor(Math.random() * loaded_train_batches.length);
    var b = loaded_train_batches[bi];
    var k = Math.floor(Math.random() * 3000);
    // Every 5000 steps, request the first batch that has not loaded yet.
    // NOTE(review): nothing in the visible code ever advances step_num, so
    // this branch looks inert as written - confirm where it should be bumped.
    if (step_num % 5000 === 0 && step_num > 0) 
    {
        for (var i = 0; i < num_batches; i++)
        {
            if (!loaded[i])
            {
                load_data_batch(i);
                break;
            }
        }
    }
    var p = img_data[b].data;
    var x = new convnetjs.Vol(28, 28, 1, 0.0);
    var W = 28 * 28;
    for (var j = 0; j < W; j++)
    {
        // 4 bytes per pixel (RGBA); take the first channel only.
        var ix = ((W * k) + j) * 4;
        x.w[j] = p[ix] / 255.0;
    }
    return {
        x: x
    };
}
// Flag every batch slot as "not loaded" before any download starts.
var k = 0;
while (k < loaded.length)
{
    loaded[k] = false;
    k += 1;
}
// Request batch 0 and batch 20 up front.
load_data_batch(0);
load_data_batch(20);
// Shared 2D context of the page's 'canvas' element; draw2 paints onto it.
var canvas = document.getElementById('canvas');
var ctx = canvas.getContext('2d');

/**
 * Paints one 28x28 grayscale tile onto the shared `ctx` canvas context.
 *
 * @param {ArrayLike<number>} w  784 intensities; each is multiplied by 255
 *     and written to the red, green and blue channels of one pixel.
 * @param {number} x  tile column; the tile's left edge is at x * 30 pixels.
 * @param {number} y  tile row; the tile's top edge is at y * 50 pixels.
 */
function draw2(w, x, y) {
    var side = 28;
    var pixel_count = side * side;
    // Allocate the pixel buffer from the target context itself instead of
    // building a throwaway canvas element on every call.
    var g = ctx.createImageData(side, side);
    for (var j = 0; j < pixel_count; j++)
    {
        var pp = j * 4;
        var d = w[j] * 255;
        g.data[pp] = d;
        g.data[pp + 1] = d;
        g.data[pp + 2] = d;
        // Fully opaque.
        g.data[pp + 3] = 255;
    }
    ctx.putImageData(g, x * 30, y * 50);
}
/**
 * Draws ten sampled images with the network's output for each one below:
 * the sampled input goes to tile row 1 and the result of net.forward to
 * tile row 3, both in the same tile column (0 through 9).
 */
function test() {
    var column = 0;
    while (column < 10)
    {
        var picked = sample_training_instance();
        var input = picked.x;
        draw2(input.w, column, 1);
        var output = net.forward(input);
        draw2(output.w, column, 3);
        column += 1;
    }
}
// After run() returns: `loss` is the mean training loss of that call and
// `lossi` is the number of samples the mean was taken over.
var loss = 0;
var lossi = 0;
/**
 * Trains the network on 20000 randomly sampled images, shows the mean
 * loss of this call in an alert, then draws sample reconstructions via
 * test().
 */
function run() {
    // Start from zero on every call. Previously `loss` still held the last
    // call's mean and `lossi` kept growing, so a second run() added new
    // losses to an old mean and divided by the cumulative count.
    loss = 0;
    lossi = 0;
    for (var i = 0; i < 20000; i++)
    {
        var sample = sample_training_instance();
        // Autoencoder target: the network is trained to output its own input.
        var stats = trainer.train(sample.x, sample.x.w);
        loss += stats.loss;
        lossi += 1;
    }
    loss /= lossi;
    alert(loss);
    test();
}

成果物

http://jsdo.it/ohisama1/GKKp

以上。