jsdo
deeplearn.js

jsdoでdeeplearn.js その4

More than a year has passed since the last update.

概要

jsdo.itでdeeplearn.jsを試してみた。
XORを学習させてみた。
前回に続いてリベンジした。ただし、学習したモデルを保存する手段が無い。

サンプルコード

// Build a 2-2-1 fully connected XOR network with the deeplearn.js Graph API:
// a layer of two sigmoid hidden units followed by one sigmoid output unit.
const dl = deeplearn;
const g = new dl.Graph();
const math = new dl.NDArrayMathCPU();

// Placeholders fed once per example: a 2-element input and a 1-element label.
const inputTensor = g.placeholder('input', [2]);
const labelTensor = g.placeholder('label', [1]);

// Trainable parameters, initialised with randNormal.
// Every variable gets its own name (the output bias previously reused 'b').
const w = g.variable('w', dl.Array2D.randNormal([2, 2]));
const w1 = g.variable('w1', dl.Array2D.randNormal([2, 1]));
const b = g.variable('b', dl.Array1D.randNormal([2]));
const b1 = g.variable('b1', dl.Array1D.randNormal([1]));

// Hidden layer h = sigmoid(x*w + b); output h1 = sigmoid(h*w1 + b1).
const h = g.sigmoid(g.add(g.matmul(inputTensor, w), b));
const h1 = g.sigmoid(g.add(g.matmul(h, w1), b1));

// Binary cross-entropy with label y and prediction p = h1:
//   cost2 = -sum(y * log(p) + (1 - y) * log(1 - p))
const positiveTerm = g.multiply(labelTensor, g.log(h1));
const negativeTerm = g.multiply(
    g.subtract(g.constant(1.0), labelTensor),
    g.log(g.subtract(g.constant(1.0), h1)));
const logLikelihood = g.reduceSum(g.add(positiveTerm, negativeTerm));
// Negate by subtracting from zero to obtain a cost to minimise.
const cost2 = g.subtract(g.constant(0), logLikelihood);

// Train the network with plain SGD on the four XOR examples.
const learningRate = 0.1;
const optimizer = new dl.SGDOptimizer(learningRate);

// The four XOR cases: labels[k] is the XOR of the two entries of inputs[k].
const inputs = [
    dl.Array1D.new([1.0, 0.0]),
    dl.Array1D.new([0.0, 1.0]),
    dl.Array1D.new([0.0, 0.0]),
    dl.Array1D.new([1.0, 1.0]),
];
const labels = [
    dl.Array1D.new([1.0]),
    dl.Array1D.new([1.0]),
    dl.Array1D.new([0.0]),
    dl.Array1D.new([0.0]),
];

// Inputs and labels are handed to one builder together, presumably so that
// shuffling keeps each input paired with its label (confirm in deeplearn.js docs).
const shuffledInputProviderBuilder = new dl.InCPUMemoryShuffledInputProviderBuilder([inputs, labels]);
const [inputProvider, labelProvider] = shuffledInputProviderBuilder.getInputProviders();
const feedEntries = [
    { tensor: inputTensor, data: inputProvider },
    { tensor: labelTensor, data: labelProvider },
];

const sess = new dl.Session(g, math);
const batchSize = 2;
const numIterations = 4001;
// Report the cost on iterations 0, 2000 and 4000.
const logInterval = 2000;

for (let i = 0; i < numIterations; i++) {
    // Each step runs inside math.scope, presumably so the NDArrays it
    // allocates are released when the step finishes.
    math.scope(() => {
        const cost = sess.train(cost2, feedEntries, batchSize, optimizer, dl.CostReduction.MEAN);
        if (i % logInterval === 0) {
            alert(`last average cost: ${cost.get()}`);
        }
    });
}
// Evaluate the trained network on every XOR input and display each output,
// in the order 1 1, 1 0, 0 1, 0 0.
const testCases = [
    [1.0, 1.0],
    [1.0, 0.0],
    [0.0, 1.0],
    [0.0, 0.0],
];
for (const [x0, x1] of testCases) {
    // One scope per case; track() registers the test input with the scope.
    math.scope((keep, track) => {
        const testInput = track(dl.Array1D.new([x0, x1]));
        const testFeedEntries = [{
            tensor: inputTensor,
            data: testInput,
        }];
        const testOutput = sess.eval(h1, testFeedEntries);
        // `${1.0}` renders as "1", so the message matches 'xor 1 1 -> ...'.
        alert(`xor ${x0} ${x1} -> ${testOutput.getValues()}`);
    });
}

成果物

http://jsdo.it/ohisama1/QTRl