LoginSignup
0

More than 5 years have passed since last update.

tensorflow.jsでautoencoder

Last updated at Posted at 2018-04-05

概要

TensorFlow.js でオートエンコーダ（autoencoder）を試してみた。

写真

image

サンプルコード

// Site-relative image URLs for the data batches, keyed '0'..'20' (21 entries,
// matching num_batches). load_data_batch(n) looks up url[n]; a numeric n is
// coerced to the corresponding string key.
// NOTE(review): the pixel indexing in run() suggests each PNG packs flattened
// 28x28 samples as consecutive 784-pixel runs — confirm against the assets.
var url = {
    '0': '/assets/A/j/W/3/AjW3t.png',
    '1': '/assets/e/l/w/h/elwhh.png',
    '2': '/assets/Y/0/k/e/Y0kel.png',
    '3': '/assets/8/O/x/F/8OxFx.png',
    '4': '/assets/q/P/x/6/qPx60.png',
    '5': '/assets/6/j/8/t/6j8tg.png',
    '6': '/assets/S/H/y/u/SHyuZ.png',
    '7': '/assets/q/Z/G/o/qZGoq.png',
    '8': '/assets/o/D/k/E/oDkE9.png',
    '9': '/assets/2/N/L/G/2NLGe.png',
    '10': '/assets/Y/b/w/I/YbwIP.png',
    '11': '/assets/w/e/S/9/weS9c.png',
    '12': '/assets/K/9/6/V/K96Ve.png',
    '13': '/assets/G/2/Y/x/G2Yxd.png',
    '14': '/assets/K/X/8/d/KX8dt.png',
    '15': '/assets/e/8/E/6/e8E6f.png',
    '16': '/assets/S/N/1/Z/SN1Z2.png',
    '17': '/assets/C/5/G/G/C5GGt.png',
    '18': '/assets/s/8/n/V/s8nVW.png',
    '19': '/assets/2/E/a/d/2EadT.png',
    '20': '/assets/K/c/m/W/KcmWd.png',
};
// Fully-connected autoencoder: 784 inputs -> 20-unit ReLU bottleneck ->
// 784 linear outputs, optimized with Adam on mean squared error.
// Each layer is created and added in turn, in encoder-then-decoder order.
const model = tf.sequential();
[
    { inputShape: [784], units: 20, activation: 'relu' },  // encoder
    { units: 784, activation: 'linear' },                  // decoder
].forEach((layerConfig) => model.add(tf.layers.dense(layerConfig)));
model.compile({ loss: 'meanSquaredError', optimizer: 'adam' });
// Number of data batches; sized to match the 21 entries in `url`.
var num_batches = 21;
// Per-batch Image elements created by load_data_batch.
var data_img_elts = new Array(num_batches);
// Per-batch ImageData captured in the image onload handler.
var img_data = new Array(num_batches);
// Per-batch "image has loaded" flags; set to true in the onload handler.
var loaded = new Array(num_batches);
// NOTE(review): the variables below are not read anywhere in the visible code —
// possibly leftovers from a larger example; confirm before removing.
var loaded_train_batches = [];
var paused = false;
var embed_samples = [];
var embed_imgs = [];
var step_num = 0;
/**
 * Asynchronously load the image for one data batch and capture its pixels.
 *
 * Creates an Image for url[batch_num]. When it loads, the image is drawn onto
 * an offscreen canvas of the same size, the resulting ImageData is stored in
 * img_data[batch_num], loaded[batch_num] is set to true, and the user is
 * notified with an alert. If the image fails to load, the failure is logged
 * and loaded[batch_num] stays false.
 *
 * @param {number} batch_num - index into url, data_img_elts, img_data and loaded.
 */
var load_data_batch = function(batch_num) {
    var data_img_elt = new Image();
    data_img_elts[batch_num] = data_img_elt;
    data_img_elt.onload = function() {
        var data_canvas = document.createElement('canvas');
        data_canvas.width = data_img_elt.width;
        data_canvas.height = data_img_elt.height;
        var data_ctx = data_canvas.getContext("2d");
        data_ctx.drawImage(data_img_elt, 0, 0);
        // getImageData throws a SecurityError if the canvas is tainted by a
        // cross-origin image served without CORS headers.
        img_data[batch_num] = data_ctx.getImageData(0, 0, data_canvas.width, data_canvas.height);
        loaded[batch_num] = true;
        alert("loaded ok");
    };
    // Without this handler, a missing or broken asset fails silently and
    // img_data[batch_num] simply stays undefined.
    data_img_elt.onerror = function() {
        loaded[batch_num] = false;
        console.error("failed to load data batch " + batch_num + " from " + url[batch_num]);
    };
    data_img_elt.src = url[batch_num];
};
// Every batch starts out unloaded; load_data_batch flips loaded[n] to true
// once that batch's image has loaded.
loaded.fill(false);
// Only batch 1 is fetched up front; it is the batch run() reads.
load_data_batch(1);

// Target canvas for draw2(); the page must provide <canvas id="canvas">.
var canvas = document.getElementById('canvas');
if (canvas === null) {
    throw new Error('autoencoder demo: no <canvas id="canvas"> element found');
}
var ctx = canvas.getContext('2d');
/**
 * Draw a 28x28 grayscale image onto the shared canvas context `ctx`.
 *
 * @param {ArrayLike<number>} w - 784 pixel intensities in row-major order.
 *   Each value is multiplied by 255, so [0, 1] maps to [0, 255]. Out-of-range
 *   results are clamped by the Uint8ClampedArray that backs ImageData.
 * @param {number} x - horizontal slot; the image is drawn at pixel x * 30.
 * @param {number} y - vertical slot; the image is drawn at pixel y * 50.
 */
function draw2(w, x, y) {
    // Allocate the ImageData from the destination context directly instead of
    // creating a throwaway canvas element on every call.
    var g = ctx.createImageData(28, 28);
    for (var j = 0; j < 784; j++) {
        var pp = j * 4;
        var d = w[j] * 255;
        g.data[pp] = d;       // R
        g.data[pp + 1] = d;   // G
        g.data[pp + 2] = d;   // B
        g.data[pp + 3] = 255; // fully opaque
    }
    ctx.putImageData(g, x * 30, y * 50);
}
/**
 * Train the autoencoder on the first 10 samples of data batch 1, then draw
 * the inputs (canvas row 1) and the model's reconstructions (canvas row 2).
 *
 * Requires batch 1 to have finished loading. If loaded[1] is not yet true, or
 * img_data[1] is missing, it logs a warning and returns instead of throwing.
 * Training is asynchronous: the reconstructions and the final-epoch loss
 * (logged to the console) appear once model.fit resolves. Training errors are
 * logged, and the training tensor is released on both paths.
 *
 * NOTE(review): the pixel indexing assumes each sample is a run of 784
 * consecutive RGBA pixels in the batch image, and that the image is grayscale
 * (only the red channel is read) — confirm against the assets.
 */
function run() {
    if (!loaded[1] || !img_data[1]) {
        console.warn("run(): data batch 1 has not finished loading yet");
        return;
    }
    const numSamples = 10;
    const numPixels = 784;
    const buffer = tf.buffer([numSamples, numPixels]);
    const rgba = img_data[1].data;
    for (let i = 0; i < numSamples; i++) {
        const sample = [];
        for (let j = 0; j < numPixels; j++) {
            // Red channel of pixel j of sample i (4 bytes per RGBA pixel), scaled to [0, 1].
            const v = rgba[(i * numPixels + j) * 4] / 255.0;
            buffer.set(v, i, j);
            sample.push(v);
        }
        draw2(sample, i, 1);
    }
    const xs = buffer.toTensor();
    // Autoencoder: the input is also the training target.
    model.fit(xs, xs, {
        batchSize: numSamples,
        epochs: 100
    }).then((result) => {
        const losses = result.history.loss;
        // history.loss holds one entry per epoch; report the last, not the first.
        console.log("loss = " + losses[losses.length - 1]);
        const pre = model.predict(xs);
        const f = pre.dataSync();
        // tfjs tensor memory (e.g. WebGL textures) is not reclaimed by GC; free it explicitly.
        pre.dispose();
        for (let a = 0; a < numSamples; a++) {
            draw2(f.subarray(a * numPixels, (a + 1) * numPixels), a, 2);
        }
    }).catch((err) => {
        console.error("run(): training failed", err);
    }).then(() => {
        // Reached on both success and failure: release the training tensor.
        xs.dispose();
    });
}

成果物

以上。

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0