LoginSignup
0
1

More than 5 years have passed since last update.

jsdoでdeeplearn.js その2

Posted at 2017-08-20

概要

jsdo上でdeeplearn.jsを使ってみた。
XORを学習させてみたが、結果が実行ごとに安定しない。
どこかが間違っているようだ。

サンプルコード

// Build a 2 -> 20 -> 1 fully connected network for learning XOR with the
// deeplearn.js Graph API, plus a mean-squared-error cost node.
const dl = deeplearn;
const gr = new dl.Graph();
// Math backend handed to the Session that runs this graph (CPU implementation).
const math = new dl.NDArrayMathCPU();

// Placeholders fed per example: a 2-element input vector and a 1-element target.
const inputShape = [2];
const inputTensor = gr.placeholder('input', inputShape);
const labelShape = [1];
const labelTensor = gr.placeholder('label', labelShape);

// Hidden layer parameters: 2 -> 20.
const w = gr.variable('w', dl.Array2D.randNormal([2, 20]));
const b = gr.variable('b', dl.Array1D.randNormal([20]));
// Output layer parameters: 20 -> 1. The bias gets its own name 'b1' so it does
// not collide with the hidden-layer bias registered as 'b'.
const w1 = gr.variable('w1', dl.Array2D.randNormal([20, 1]));
const b1 = gr.variable('b1', dl.Array1D.randNormal([1]));

const h = gr.relu(gr.add(gr.matmul(inputTensor, w), b));
// Linear output unit. A ReLU here can get stuck at 0: if the pre-activation is
// negative for every example, the output is constant and its gradient is zero,
// so training never recovers. That makes results vary from run to run.
const h1 = gr.add(gr.matmul(h, w1), b1);
const cost2 = gr.meanSquaredCost(h1, labelTensor);

// Train the graph with plain SGD on the four XOR examples, reporting the
// mean cost periodically.
// NOTE(review): 0.0001 looks very small for ~4000 SGD steps on XOR and may
// leave the network under-trained; consider a larger rate (e.g. 0.01) if the
// final cost is still high.
const learningRate = 0.0001;
const optimizer = new dl.SGDOptimizer(learningRate);

// The XOR truth table: inputs[k] has target labels[k].
const inputs = [
  dl.Array1D.new([1.0, 0.0]),
  dl.Array1D.new([0.0, 1.0]),
  dl.Array1D.new([0.0, 0.0]),
  dl.Array1D.new([1.0, 1.0]),
];
const labels = [
  dl.Array1D.new([1.0]),
  dl.Array1D.new([1.0]),
  dl.Array1D.new([0.0]),
  dl.Array1D.new([0.0]),
];
// Inputs and labels go to the builder together, presumably so the shuffle
// keeps each input paired with its label.
const shuffledInputProviderBuilder = new dl.InCPUMemoryShuffledInputProviderBuilder([inputs, labels]);
const [inputProvider, labelProvider] = shuffledInputProviderBuilder.getInputProviders();
const feedEntries = [
  { tensor: inputTensor, data: inputProvider },
  { tensor: labelTensor, data: labelProvider },
];
const sess = new dl.Session(gr, math);
const batchSize = 4;

// Steps 0..4000 inclusive, so the cost is reported at steps 0, 2000 and 4000.
const numSteps = 4001;
const logEvery = 2000;
for (let i = 0; i < numSteps; i++) {
  // Each step runs inside math.scope so arrays created during the step are
  // scoped to that step.
  math.scope(() => {
    const cost = sess.train(cost2, feedEntries, batchSize, optimizer, dl.CostReduction.MEAN);
    if (i % logEvery === 0) {
      alert(`last average cost: ${cost.get()}`);
    }
  });
}
// Evaluate the trained network on every XOR input and alert each prediction.
// The report order is 1 1, 1 0, 0 1, 0 0.
const xorTestCases = [
  [1, 1],
  [1, 0],
  [0, 1],
  [0, 0],
];
for (const [x0, x1] of xorTestCases) {
  // Each evaluation runs in its own scope; the input array is registered with track().
  math.scope((_, track) => {
    const testInput = track(dl.Array1D.new([x0, x1]));
    const testFeedEntries = [{ tensor: inputTensor, data: testInput }];
    const testOutput = sess.eval(h1, testFeedEntries);
    alert(`xor ${x0} ${x1} -> ${testOutput.getValues()}`);
  });
}



成果物

0
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do by signing up
0
1