Login | Sign up
0
0

More than 5 years have passed since last update.

convnetjsでautoencoder

Posted at

概要

convnetjsでautoencoderやってみた。

写真

image.png

サンプルコード

// Relative URL of each batch image, keyed by batch number ('0'..'20').
// Built from an ordered list so that a path's position is its batch number;
// load_data_batch() looks these up as url[batch_num].
var url = (function () {
    var batchPaths = [
        '/assets/A/j/W/3/AjW3t.png', // 0
        '/assets/e/l/w/h/elwhh.png', // 1
        '/assets/Y/0/k/e/Y0kel.png', // 2
        '/assets/8/O/x/F/8OxFx.png', // 3
        '/assets/q/P/x/6/qPx60.png', // 4
        '/assets/6/j/8/t/6j8tg.png', // 5
        '/assets/S/H/y/u/SHyuZ.png', // 6
        '/assets/q/Z/G/o/qZGoq.png', // 7
        '/assets/o/D/k/E/oDkE9.png', // 8
        '/assets/2/N/L/G/2NLGe.png', // 9
        '/assets/Y/b/w/I/YbwIP.png', // 10
        '/assets/w/e/S/9/weS9c.png', // 11
        '/assets/K/9/6/V/K96Ve.png', // 12
        '/assets/G/2/Y/x/G2Yxd.png', // 13
        '/assets/K/X/8/d/KX8dt.png', // 14
        '/assets/e/8/E/6/e8E6f.png', // 15
        '/assets/S/N/1/Z/SN1Z2.png', // 16
        '/assets/C/5/G/G/C5GGt.png', // 17
        '/assets/s/8/n/V/s8nVW.png', // 18
        '/assets/2/E/a/d/2EadT.png', // 19
        '/assets/K/c/m/W/KcmWd.png'  // 20
    ];
    var byBatch = {};
    for (var n = 0; n < batchPaths.length; n++)
    {
        byBatch[String(n)] = batchPaths[n];
    }
    return byBatch;
})();
// Autoencoder network: 28x28x1 input -> 32-unit tanh code layer ->
// regression head with 28*28 outputs (one reconstructed value per input pixel).
var layer_defs, 
    net, 
    trainer;
layer_defs = [];
layer_defs.push({
    type: 'input', 
    out_sx: 28, 
    out_sy: 28, 
    out_depth: 1
});
layer_defs.push({
    type: 'fc', 
    num_neurons: 32,
    // Was misspelled 'tahn'. convnetjs's makeLayers only recognises 'relu',
    // 'sigmoid', 'tanh' and 'maxout'; any other string is just logged as an
    // unsupported activation and no activation layer is added, which left
    // this bottleneck layer purely linear.
    activation: 'tanh'
});
layer_defs.push({
    type: 'regression', 
    num_neurons: 28 * 28
});
net = new convnetjs.Net();
net.makeLayers(layer_defs);
// Adagrad SGD over mini-batches of 20 samples, with both L1 and L2 weight decay.
trainer = new convnetjs.SGDTrainer(net, {
    learning_rate: 0.01, 
    method: 'adagrad',
    batch_size: 20, 
    l2_decay: 0.001,
    l1_decay: 0.001
});
// Batch bookkeeping. There are 21 batch images (see `url`); each per-batch
// array below is indexed by batch number and starts out empty (sparse).
var num_batches = 21;
var data_img_elts = Array(num_batches); // Image element created per batch
var img_data = Array(num_batches);      // decoded ImageData per batch
var loaded = Array(num_batches);        // true once a batch's onload has run
// Batch numbers that sample_training_instance() may draw from.
var loaded_train_batches = [];
// NOTE(review): `paused`, `embed_samples` and `embed_imgs` are not read by
// any code shown in this script.
var paused = false;
var embed_samples = [];
var embed_imgs = [];
// NOTE(review): nothing shown here increments step_num, so the "load another
// batch every 5000 steps" check in sample_training_instance() never fires.
var step_num = 0;
// Asynchronously load batch image `batch_num` (path taken from `url`) and
// decode its pixels into img_data[batch_num] through an offscreen canvas.
// Once decoded, the batch is marked loaded, its number is alerted, and, unless
// it is the last batch (held out), it becomes available for training samples.
// A failed request is reported on the console and leaves the batch unloaded.
var load_data_batch = function(batch_num) {
    data_img_elts[batch_num] = new Image();
    var data_img_elt = data_img_elts[batch_num];
    data_img_elt.onload = function() { 
        var data_canvas = document.createElement('canvas');
        data_canvas.width = data_img_elt.width;
        data_canvas.height = data_img_elt.height;
        var data_ctx = data_canvas.getContext("2d");
        data_ctx.drawImage(data_img_elt, 0, 0); 
        img_data[batch_num] = data_ctx.getImageData(0, 0, data_canvas.width, data_canvas.height);
        loaded[batch_num] = true;
        // NOTE(review): presumably a "batch ready" notice for the user, since
        // run()/test() need at least one training batch — confirm.
        alert(batch_num);
        // The last batch is held out of training (was hard-coded as `< 20`).
        // Also avoid listing a batch twice if it gets loaded again, which would
        // double its chance of being sampled.
        if (batch_num < num_batches - 1 && loaded_train_batches.indexOf(batch_num) === -1)
        {
            loaded_train_batches.push(batch_num);
        }
    };
    data_img_elt.onerror = function() {
        // Without this handler a failed request was silent: the batch simply
        // never became available.
        console.error('failed to load data batch ' + batch_num + ' from ' + url[batch_num]);
    };
    data_img_elt.src = url[batch_num];
};
// Draw one random training example from a loaded training batch as a
// 28x28x1 convnetjs.Vol with pixel values divided by 255.
// Every 5000 steps (per step_num) it also starts loading the first batch that
// is not yet loaded.
// Returns {x: Vol}. Throws an Error if no training batch has loaded yet.
var sample_training_instance = function() {
    if (loaded_train_batches.length === 0)
    {
        // Previously this fell through to img_data[undefined].data and died
        // with an opaque TypeError when called before any batch had loaded.
        throw new Error('sample_training_instance: no training batch has finished loading yet');
    }
    var b = loaded_train_batches[Math.floor(Math.random() * loaded_train_batches.length)];
    // NOTE(review): assumes each batch image holds 3000 examples, example k
    // being the 784 consecutive pixels starting at pixel 784*k — confirm
    // against the batch PNGs.
    var k = Math.floor(Math.random() * 3000);
    if (step_num % 5000 === 0 && step_num > 0) 
    {
        for (var i = 0; i < num_batches; i++)
        {
            if (!loaded[i])
            {
                load_data_batch(i);
                break;
            }
        }
    }
    var p = img_data[b].data;
    var W = 28 * 28;
    var x = new convnetjs.Vol(28, 28, 1, 0.0);
    for (var j = 0; j < W; j++)
    {
        // ImageData is RGBA (4 bytes per pixel); only the first (red) byte is
        // read, presumably because the images are grayscale.
        x.w[j] = p[((W * k) + j) * 4] / 255.0;
    }
    return {
        x: x
    };
};
// Mark every batch as not yet loaded, then start fetching batch 0 and batch 20.
// Batch 20 is the one load_data_batch() keeps out of the training list.
// Loading is asynchronous; the data is usable only after each onload has run.
var k = 0;
while (k < loaded.length)
{
    loaded[k] = false;
    k += 1;
}
load_data_batch(0);
load_data_batch(20); 

// Page canvas that draw2() paints sample/reconstruction thumbnails onto.
var canvas = document.getElementById('canvas');
var ctx = canvas.getContext('2d');

// Paint a 28x28 grayscale thumbnail of `w` (28*28 values, nominally in [0, 1])
// onto the page canvas `ctx`, with its top-left corner at (x*30, y*50).
// Values are scaled by 255; ImageData's Uint8ClampedArray storage clamps
// anything outside [0, 255] (e.g. reconstructions that overshoot).
function draw2(w, x, y) {
    var size = 28;
    // No scratch <canvas> is needed just to obtain an ImageData buffer; the
    // page context can allocate one directly.
    var g = ctx.createImageData(size, size);
    for (var j = 0; j < size * size; j++)
    {
        var pp = j * 4;
        var d = w[j] * 255;
        g.data[pp] = d;       // R
        g.data[pp + 1] = d;   // G
        g.data[pp + 2] = d;   // B
        g.data[pp + 3] = 255; // fully opaque
    }
    ctx.putImageData(g, x * 30, y * 50);
}
// Visualise the autoencoder on 10 random training samples: each input is
// drawn in row 1 and the network's output for it directly below in row 3.
function test() {
    var column = 0;
    while (column < 10)
    {
        var input = sample_training_instance().x;
        draw2(input.w, column, 1);
        var reconstruction = net.forward(input);
        draw2(reconstruction.w, column, 3);
        column += 1;
    }
}
// Mean training loss reported by the most recent run(), and the number of
// samples that mean was taken over.
var loss = 0;
var lossi = 0;
// Train for 20000 single-sample steps (input reconstructs itself), alert the
// mean loss over those steps, then redraw the thumbnails via test().
function run() {
    // Reset per run. Previously the accumulators carried over between calls,
    // so a second run() added the earlier *average* to the new *sum* and
    // divided by the cumulative count, reporting a meaningless loss.
    loss = 0;
    lossi = 0;
    for (var i = 0; i < 20000; i++)
    {
        var sample = sample_training_instance();
        // Autoencoder: the regression target is the input's own values.
        var stats = trainer.train(sample.x, sample.x.w);
        loss += stats.loss;
        lossi += 1;
    }
    loss /= lossi;
    alert(loss);
    test();
}

成果物

以上。

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do by signing up
0
0