LoginSignup
0
0

More than 5 years have passed since last update.

jsdoでdeeplearn.js その3

Posted at

概要

jsdoでdeeplearn.jsを使ってみた。
今回は、y = exp(x) を近似する小さなニューラルネットワークを学習させてみた。

参考にしたページ

写真

image.png

サンプルコード

/**
 * Appends one log line to the element with id "helloWorld".
 *
 * The value is written through innerHTML followed by a "<BR>" line break,
 * so any markup contained in `txt` is interpreted as HTML.
 *
 * @param {*} txt - Value to log; converted with string `+` concatenation.
 */
function mylog(txt) {
    const output = document.getElementById('helloWorld');
    // Read the current content first, then build the new line (same order as `+=`).
    const previous = output.innerHTML;
    output.innerHTML = previous + (txt + "<BR>");
}
// Short alias for the deeplearn.js namespace; assumes the library exposes a
// global `deeplearn` (e.g. loaded via a <script> tag) — TODO confirm.
var dl = deeplearn;
// Computation graph that every placeholder, variable and op below is added to.
var g = new dl.Graph();
// CPU math backend; it is handed to the Session and provides math.scope() for training.
var math = new dl.NDArrayMathCPU();
/**
 * Adds a leaky-ReLU activation to graph `g`: x for x >= 0, alpha * x for x < 0.
 *
 * Built only from constant/multiply/add/relu graph ops using the identity
 *   leaky_relu(x) = relu(x) - alpha * relu(-x)
 *
 * @param {*} x - Graph tensor to activate.
 * @param {number} [alpha=0.2] - Slope of the negative part (0.2 matches the original).
 * @returns {*} Graph tensor holding the activated value.
 */
function leaky_relu(x, alpha = 0.2) {
    // Negate the *input*, not relu(x): relu(-x) is the magnitude of the negative part.
    // Negating relu(x) (the previous code) made relu(...) always 0, i.e. plain ReLU.
    const negX = g.multiply(g.constant(-1), x);
    const negPart = g.multiply(g.constant(-alpha), g.relu(negX));
    return g.add(g.relu(x), negPart);
}
// Model: a fully connected 1 -> 16 -> 32 -> 1 network mapping a scalar x to a scalar y.
// Scalar input placeholder (shape []).
var x = g.placeholder('x', []);
// Parameter initial values. Layers 2 and 3 draw weights from N(0, sqrt(1 / fan_in));
// layer 1 (fan_in = 1) uses randNormal's default stddev (presumably 1 — TODO confirm).
// All biases start at zero.
var W1data = dl.Array2D.randNormal([16, 1]);
var b1data = dl.Array2D.zeros([16, 1]);
var W2data = dl.Array2D.randNormal([32, 16], 0, Math.sqrt(1.0 / 16.0));
var b2data = dl.Array2D.zeros([32, 1]);
var W3data = dl.Array2D.randNormal([1, 32], 0, Math.sqrt(1.0 / 32.0));
var b3data = dl.Array2D.zeros([1, 1]);
// Register the initial values as trainable graph variables.
var W1 = g.variable('W1', W1data);
var b1 = g.variable('b1', b1data);
var W2 = g.variable('W2', W2data);
var b2 = g.variable('b2', b2data);
var W3 = g.variable('W3', W3data);
var b3 = g.variable('b3', b3data);
// Layer 1: x is a scalar, so it is scaled elementwise (multiply, not matmul) into a
// 16x1 column — relies on multiply accepting a scalar operand.
var h1 = leaky_relu(g.add(g.multiply(W1, x), b1));
// Layer 2: (32x16) x (16x1) -> 32x1.
var h2 = leaky_relu(g.add(g.matmul(W2, h1), b2));
// Output layer: (1x32) x (32x1) -> 1x1. NOTE(review): the activation is applied to the
// regression output as well; targets here are positive, but a linear output is more usual.
var y_ = leaky_relu(g.add(g.matmul(W3, h2), b3));
// Flatten the 1x1 output to a scalar so it matches the scalar label.
var y = g.reshape(y_,[]);
var yLabel = g.placeholder('y label', []);
// Training objective: mean squared error between prediction and label.
var cost = g.meanSquaredCost(y, yLabel);
// Session that runs graph g on the CPU math backend.
var sess = new dl.Session(g, math);
// Generates a synthetic dataset for y = exp(x), x in [0, 1), trains the network with
// full-batch SGD, then logs one prediction check plus simple timing benchmarks.
// NDArrays created here are registered with `track` so deeplearn.js can release them
// with this scope.
math.scope((keep, track) => {
    // --- Synthetic training data ---
    const NUM_SAMPLES = 100;
    const xs = [];
    const ys = [];
    for (let i = 0; i < NUM_SAMPLES; i++) {
        const xr = Math.random();
        xs.push(track(dl.Scalar.new(xr)));
        ys.push(track(dl.Scalar.new(Math.exp(xr))));
    }
    // Both providers come from one builder, which is meant to shuffle x/y pairs
    // consistently (per deeplearn.js docs).
    const [xProvider, yProvider] =
        new dl.InCPUMemoryShuffledInputProviderBuilder([xs, ys]).getInputProviders();

    // --- Training: every batch contains all samples (full-batch SGD) ---
    const NUM_BATCHES = 100;
    const BATCH_SIZE = xs.length;
    const LEARNING_RATE = 0.01;
    const optimizer = new dl.SGDOptimizer(LEARNING_RATE);
    let lastCost = NaN;
    const trainStart = Date.now();
    for (let i = 0; i < NUM_BATCHES; i++) {
        const costValue = sess.train(cost, [
            { tensor: x, data: xProvider },
            { tensor: yLabel, data: yProvider },
        ], BATCH_SIZE, optimizer, dl.CostReduction.MEAN);
        lastCost = costValue.get();
    }
    const trainMs = Date.now() - trainStart;

    // --- Single prediction compared against the true function value ---
    const probeX = 0.2;
    const predicted = sess.eval(y, [{
        tensor: x,
        data: track(dl.Scalar.new(probeX)),
    }]).getValues();
    mylog('--- prediction check ---');
    mylog(`predicted : ${predicted}`);
    mylog(`truth     : ${Math.exp(probeX)}`);
    mylog(`final cost: ${lastCost}`);
    mylog('');

    // --- Timing benchmarks ---
    mylog('--- benchmark result ---');
    mylog(`elapsed time for training: ${trainMs / 1000}[sec]`);
    const NUM_PREDICTIONS = 1000;
    const predictStart = Date.now();
    for (let i = 0; i < NUM_PREDICTIONS; i++) {
        // getValues() reads the result back so the full evaluation is included in the timing.
        sess.eval(y, [{
            tensor: x,
            data: track(dl.Scalar.new(Math.random())),
        }]).getValues();
    }
    const predictMs = Date.now() - predictStart;
    // Total milliseconds / number of predictions = milliseconds per prediction.
    mylog(`elapsed time for prediction: ${predictMs / NUM_PREDICTIONS}[msec/cycle]`);
});




成果物

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do by signing up
0
0