概要
convnetjsを使ってautoencoder（自己符号化器）を試してみた。
写真
サンプルコード
// Image URL of every data batch, keyed by batch number (0-20).
// load_data_batch(n) fetches url[n]. Batches 0-19 feed the training pool;
// batch 20 is loaded but never added to it.
var url = {
  0: '/assets/A/j/W/3/AjW3t.png',
  1: '/assets/e/l/w/h/elwhh.png',
  2: '/assets/Y/0/k/e/Y0kel.png',
  3: '/assets/8/O/x/F/8OxFx.png',
  4: '/assets/q/P/x/6/qPx60.png',
  5: '/assets/6/j/8/t/6j8tg.png',
  6: '/assets/S/H/y/u/SHyuZ.png',
  7: '/assets/q/Z/G/o/qZGoq.png',
  8: '/assets/o/D/k/E/oDkE9.png',
  9: '/assets/2/N/L/G/2NLGe.png',
  10: '/assets/Y/b/w/I/YbwIP.png',
  11: '/assets/w/e/S/9/weS9c.png',
  12: '/assets/K/9/6/V/K96Ve.png',
  13: '/assets/G/2/Y/x/G2Yxd.png',
  14: '/assets/K/X/8/d/KX8dt.png',
  15: '/assets/e/8/E/6/e8E6f.png',
  16: '/assets/S/N/1/Z/SN1Z2.png',
  17: '/assets/C/5/G/G/C5GGt.png',
  18: '/assets/s/8/n/V/s8nVW.png',
  19: '/assets/2/E/a/d/2EadT.png',
  20: '/assets/K/c/m/W/KcmWd.png',
};
/*
 * Autoencoder definition: a 28x28x1 input volume passes through a
 * 32-neuron fully connected layer and is expanded back to 28 * 28
 * regression outputs. run() trains it with each sample's own pixels as
 * the target.
 */
var layer_defs = [
  { type: 'input', out_sx: 28, out_sy: 28, out_depth: 1 },
  // Was 'tahn'. ConvNetJS only recognizes 'tanh'. For an unknown name it
  // logs an "unsupported activation" error and adds no activation layer,
  // which silently left the encoder purely linear.
  { type: 'fc', num_neurons: 32, activation: 'tanh' },
  { type: 'regression', num_neurons: 28 * 28 },
];
var net = new convnetjs.Net();
net.makeLayers(layer_defs);
var trainer = new convnetjs.SGDTrainer(net, {
  learning_rate: 0.01,
  method: 'adagrad',
  batch_size: 20,
  l2_decay: 0.001,
  l1_decay: 0.001,
});
// Number of data batches (the keys 0-20 of url).
var num_batches = 21;
// Per-batch slots indexed by batch number. They start out sparse and are
// filled as batches load.
var data_img_elts = Array(num_batches); // Image element of each batch
var img_data = Array(num_batches);      // ImageData read back from each batch
var loaded = Array(num_batches);        // true once a batch's image has loaded
// Batch numbers that sample_training_instance() may draw examples from.
var loaded_train_batches = [];
// The next three are not read anywhere in this file.
var paused = false;
var embed_samples = [];
var embed_imgs = [];
// Checked by sample_training_instance(); nothing in this file increments it.
var step_num = 0;
/**
 * Start loading data batch `batch_num` from url[batch_num] asynchronously.
 *
 * When the image arrives, this happens in order:
 * - it is drawn onto an off-screen canvas;
 * - its pixels are stored in img_data[batch_num];
 * - loaded[batch_num] is set to true;
 * - unless this is the last (held-out) batch, the batch number is appended
 *   to loaded_train_batches so sample_training_instance() can draw from it.
 *
 * @param {number} batch_num - Index into url and the per-batch arrays.
 */
var load_data_batch = function(batch_num) {
  const data_img_elt = new Image();
  data_img_elts[batch_num] = data_img_elt;

  data_img_elt.onload = function() {
    const data_canvas = document.createElement('canvas');
    data_canvas.width = data_img_elt.width;
    data_canvas.height = data_img_elt.height;
    const data_ctx = data_canvas.getContext('2d');
    data_ctx.drawImage(data_img_elt, 0, 0);
    img_data[batch_num] = data_ctx.getImageData(0, 0, data_canvas.width, data_canvas.height);
    loaded[batch_num] = true;
    // Progress notification kept from the original (modal: blocks until dismissed).
    alert(batch_num);
    // The last batch is held out of training. This was a hard-coded 20,
    // which would silently diverge if num_batches changed.
    if (batch_num < num_batches - 1) {
      loaded_train_batches.push(batch_num);
    }
  };

  // Without this a failed download was silent: loaded[batch_num] simply
  // never became true.
  data_img_elt.onerror = function() {
    console.error(`load_data_batch: failed to load batch ${batch_num} from ${url[batch_num]}`);
  };

  // Assign src last so both handlers are in place before loading starts.
  data_img_elt.src = url[batch_num];
};
/**
 * Draw one random training example from the batches loaded so far.
 *
 * Picks a random batch from loaded_train_batches and a random 784-pixel
 * span inside that batch's ImageData. It returns the span's red-channel
 * bytes, divided by 255, as a 28x28x1 convnetjs.Vol.
 *
 * Side effect: when step_num is a positive multiple of 5000, it starts
 * loading the first batch whose loaded[] flag is not yet set.
 *
 * @returns {{x: convnetjs.Vol}} The sampled input volume.
 * @throws {Error} If no training batch has finished loading yet.
 */
var sample_training_instance = function() {
  const PIXELS_PER_SAMPLE = 28 * 28;

  if (loaded_train_batches.length === 0) {
    // Previously this fell through to img_data[undefined].data and died
    // with an opaque TypeError.
    throw new Error('sample_training_instance: no training batch has been loaded yet');
  }

  const b = loaded_train_batches[Math.floor(Math.random() * loaded_train_batches.length)];
  const p = img_data[b].data;
  // Derive the number of samples from the data (4 RGBA bytes per pixel)
  // instead of assuming 3000 per batch, so k can never index past the end.
  const numSamples = Math.floor(p.length / (4 * PIXELS_PER_SAMPLE));
  const k = Math.floor(Math.random() * numSamples);

  // Lazily fetch the next unloaded batch every 5000 steps.
  // NOTE(review): nothing in this file increments step_num, so this branch
  // currently never runs.
  if (step_num % 5000 === 0 && step_num > 0) {
    for (let i = 0; i < num_batches; i++) {
      if (!loaded[i]) {
        load_data_batch(i);
        break;
      }
    }
  }

  const x = new convnetjs.Vol(28, 28, 1, 0.0);
  const base = k * PIXELS_PER_SAMPLE;
  for (let i = 0; i < PIXELS_PER_SAMPLE; i++) {
    // Red byte of the RGBA quadruple for pixel i of sample k.
    x.w[i] = p[(base + i) * 4] / 255.0;
  }
  return { x };
};
// Flag every batch as not yet loaded. Then fetch training batch 0 and the
// held-out last batch up front. The remaining batches are left for
// sample_training_instance() to request later.
loaded.fill(false);
[0, num_batches - 1].forEach(function(batch) {
  load_data_batch(batch);
});

// Main drawing surface that draw2() paints onto.
var canvas = document.getElementById('canvas');
var ctx = canvas.getContext('2d');
/**
 * Paint a 28x28 grayscale tile onto the main canvas context `ctx`.
 *
 * @param {ArrayLike<number>} w - 784 intensities, nominally in [0, 1].
 *   Each is scaled by 255 and written to R, G and B. ImageData stores bytes
 *   in a Uint8ClampedArray, so out-of-range values are clamped.
 * @param {number} x - Tile column; the tile is drawn at x * 30 px.
 * @param {number} y - Tile row; the tile is drawn at y * 50 px.
 */
function draw2(w, x, y) {
  const SIDE = 28;
  // createImageData only needs a 2D context. The original created a
  // throwaway off-screen canvas for it on every call.
  const g = ctx.createImageData(SIDE, SIDE);
  for (let j = 0; j < SIDE * SIDE; j++) {
    const d = w[j] * 255;
    const pp = j * 4;
    g.data[pp] = d;
    g.data[pp + 1] = d;
    g.data[pp + 2] = d;
    g.data[pp + 3] = 255; // fully opaque
  }
  ctx.putImageData(g, x * 30, y * 50);
}
/**
 * Visual check. For ten random training samples (one column each), draw
 * the input tile on row 1 and the network's reconstruction on row 3.
 */
function test() {
  let col = 0;
  while (col < 10) {
    const { x: input } = sample_training_instance();
    draw2(input.w, col, 1);
    const reconstruction = net.forward(input);
    draw2(reconstruction.w, col, 3);
    col += 1;
  }
}
// After run(): mean training loss of the most recent run (loss) and the
// number of steps it averaged over (lossi).
var loss = 0;
var lossi = 0;

/**
 * Train the autoencoder for 20000 steps. Each step uses a random sample as
 * both input and target. Afterwards, alert the mean loss over those steps
 * and redraw the sample/reconstruction tiles via test().
 */
function run() {
  // Reset the accumulators on every run. Previously they carried over: the
  // second run added fresh losses to the previous run's *average* and then
  // divided by the cumulative step count, so every report after the first
  // was wrong.
  loss = 0;
  lossi = 0;
  for (let i = 0; i < 20000; i++) {
    const sample = sample_training_instance();
    // Autoencoder: the regression target is the input's own pixel values.
    const stats = trainer.train(sample.x, sample.x.w);
    loss += stats.loss;
    lossi += 1;
  }
  loss /= lossi;
  alert(loss);
  test();
}
成果物
以上。