LoginSignup
0
0

More than 5 years have passed since last update.

jsdoでdeeplearn.js その4

Last updated at Posted at 2017-08-23

概要

jsdo.itでdeeplearn.jsを試してみた。
XOR（排他的論理和）を学習させてみた。
前回うまくいかなかったので、リベンジした。ただし、学習しても、学習済みの重みをセーブする手段が無い。

サンプルコード

// Build a 2-2-1 sigmoid network for XOR on the deeplearn.js Graph API,
// plus a binary cross-entropy cost (cost2) over its output h1.
const dl = deeplearn;
const g = new dl.Graph();
const math = new dl.NDArrayMathCPU();

// Per-example placeholders: a 2-element input and a 1-element label.
const inputTensor = g.placeholder('input', [2]);
const labelTensor = g.placeholder('label', [1]);

// Hidden layer (2 -> 2) and output layer (2 -> 1) parameters, randomly initialised.
// Every trainable variable gets a unique graph name; the output bias is 'b1',
// distinct from the hidden-layer bias 'b' (mirroring the w / w1 naming).
const w = g.variable('w', dl.Array2D.randNormal([2, 2]));
const w1 = g.variable('w1', dl.Array2D.randNormal([2, 1]));
const b = g.variable('b', dl.Array1D.randNormal([2]));
const b1 = g.variable('b1', dl.Array1D.randNormal([1]));
const h = g.sigmoid(g.add(g.matmul(inputTensor, w), b));
const h1 = g.sigmoid(g.add(g.matmul(h, w1), b1));

// Binary cross-entropy with y = labelTensor and p = h1:
//   cost2 = -sum(y * log(p) + (1 - y) * log(1 - p))
// NOTE(review): if the sigmoid output rounds to exactly 0 or 1, log() yields
// -Infinity and the cost becomes non-finite; consider clamping p if that occurs.
const positiveTerm = g.multiply(labelTensor, g.log(h1));
const negativeTerm = g.multiply(
    g.subtract(g.constant(1.0), labelTensor),
    g.log(g.subtract(g.constant(1.0), h1))
);
const cost2 = g.subtract(g.constant(0), g.reduceSum(g.add(positiveTerm, negativeTerm)));

// Train the XOR network with plain SGD, reporting the batch-averaged cost periodically.
const learningRate = 0.1;
const optimizer = new dl.SGDOptimizer(learningRate);

// The complete XOR truth table: inputs[k] is paired with labels[k].
const inputs = [
    dl.Array1D.new([1.0, 0.0]),
    dl.Array1D.new([0.0, 1.0]),
    dl.Array1D.new([0.0, 0.0]),
    dl.Array1D.new([1.0, 1.0]),
];
const labels = [
    dl.Array1D.new([1.0]),
    dl.Array1D.new([1.0]),
    dl.Array1D.new([0.0]),
    dl.Array1D.new([0.0]),
];
// Inputs and labels are handed to one builder together; presumably it shuffles
// them in lockstep so each input stays paired with its label (verify in docs).
const shuffledInputProviderBuilder = new dl.InCPUMemoryShuffledInputProviderBuilder([inputs, labels]);
const [inputProvider, labelProvider] = shuffledInputProviderBuilder.getInputProviders();
const feedEntries = [
    { tensor: inputTensor, data: inputProvider },
    { tensor: labelTensor, data: labelProvider },
];
const sess = new dl.Session(g, math);
const batchSize = 2;

// 4001 iterations (0..4000) so the last progress report lands on iteration 4000.
const numIterations = 4001;
const reportEvery = 2000;
for (let i = 0; i < numIterations; i++) {
    // Each training step runs inside math.scope so arrays allocated during the
    // step are cleaned up by the library when the scope exits.
    math.scope(() => {
        const cost = sess.train(cost2, feedEntries, batchSize, optimizer, dl.CostReduction.MEAN);
        if (i % reportEvery === 0) {
            alert(`last average cost: ${cost.get()}`);
        }
    });
}
// Evaluate the trained network on every XOR input and show each prediction.
// Cases are listed in the order the results are reported: 1 1, 1 0, 0 1, 0 0.
const xorTestCases = [
    [1.0, 1.0],
    [1.0, 0.0],
    [0.0, 1.0],
    [0.0, 0.0],
];
for (const testCase of xorTestCases) {
    math.scope((_keep, track) => {
        // track() registers the input array with this scope.
        const testInput = track(dl.Array1D.new(testCase));
        const testFeedEntries = [{ tensor: inputTensor, data: testInput }];
        const testOutput = sess.eval(h1, testFeedEntries);
        // join(' ') renders [1.0, 0.0] as "1 0", matching the original labels.
        alert(`xor ${testCase.join(' ')} -> ${testOutput.getValues()}`);
    });
}

成果物

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0