概要
jsdo.itでdeeplearn.jsを試してみた。
XORを学習させてみた。
前回うまくいかなかったので、リベンジ（再挑戦）した。ただし、学習しても結果（重み）を保存する手段が無い。
サンプルコード
// Build a small 2-2-1 sigmoid network for XOR with the deeplearn.js Graph API,
// plus a binary cross-entropy cost node (cost2) to minimize.
var dl = deeplearn;
var g = new dl.Graph();
var math = new dl.NDArrayMathCPU();

// One example at a time: two input bits in, one target bit out.
var inputTensor = g.placeholder('input', [2]);
var labelTensor = g.placeholder('label', [1]);

// Trainable parameters, randomly initialized. Each variable gets its own
// distinct name (the output bias is 'b1', not a second 'b') so graph nodes
// can be told apart.
var w = g.variable('w', dl.Array2D.randNormal([2, 2]));
var w1 = g.variable('w1', dl.Array2D.randNormal([2, 1]));
var b = g.variable('b', dl.Array1D.randNormal([2]));
var b1 = g.variable('b1', dl.Array1D.randNormal([1]));

// Hidden layer (2 units) and output layer (1 unit), both sigmoid-activated.
var h = g.sigmoid(g.add(g.matmul(inputTensor, w), b));
var h1 = g.sigmoid(g.add(g.matmul(h, w1), b1));

// Binary cross-entropy:
//   cost2 = -sum( y * log(h1) + (1 - y) * log(1 - h1) )
// NOTE(review): log(h1) / log(1 - h1) go to -Infinity if the sigmoid output
// saturates to exactly 0 or 1; add a small epsilon inside the logs if NaNs appear.
var s0 = g.multiply(labelTensor, g.log(h1));
var s1 = g.subtract(g.constant(1.0), labelTensor);
var s2 = g.log(g.subtract(g.constant(1.0), h1));
s1 = g.multiply(s1, s2);
var cost2 = g.reduceSum(g.add(s0, s1));
// Negate via 0 - x to turn the log-likelihood into a cost to minimize.
cost2 = g.subtract(g.constant(0), cost2);
// Train the graph's variables on the XOR truth table with plain SGD.
var learningRate = 0.1;
var optimizer = new dl.SGDOptimizer(learningRate);

// The complete XOR truth table: inputs[k] is paired with labels[k].
var inputs = [dl.Array1D.new([1.0, 0.0]), dl.Array1D.new([0.0, 1.0]), dl.Array1D.new([0.0, 0.0]), dl.Array1D.new([1.0, 1.0])];
var labels = [dl.Array1D.new([1.0]), dl.Array1D.new([1.0]), dl.Array1D.new([0.0]), dl.Array1D.new([0.0])];

// Serve the examples in shuffled order; inputs and labels are handed to the
// builder together so they are shuffled in unison (per the deeplearn.js API).
var shuffledInputProviderBuilder = new dl.InCPUMemoryShuffledInputProviderBuilder([inputs, labels]);
var [inputProvider, labelProvider] = shuffledInputProviderBuilder.getInputProviders();
var feedEntries = [{
  tensor: inputTensor,
  data: inputProvider
}, {
  tensor: labelTensor,
  data: labelProvider
}];

var sess = new dl.Session(g, math);
var batchSize = 2;

// 4001 iterations so that the cost is reported at i = 0, 2000 and 4000.
for (let i = 0; i < 4001; i++) {
  // math.scope() cleans up the NDArrays allocated during this training step.
  math.scope(() => {
    const cost = sess.train(cost2, feedEntries, batchSize, optimizer, dl.CostReduction.MEAN);
    if (i % 2000 === 0) {
      alert(`last average cost: ${cost.get()}`);
    }
  });
}
// Run the trained network on every XOR input pattern and show the output
// value h1 (the sigmoid output for that input). Order matches the report:
// 1 1, 1 0, 0 1, 0 0.
for (const [a, b] of [[1.0, 1.0], [1.0, 0.0], [0.0, 1.0], [0.0, 0.0]]) {
  // track() registers the test input with the scope so it is cleaned up afterwards.
  math.scope((keep, track) => {
    const testInput = track(dl.Array1D.new([a, b]));
    const testFeedEntries = [{
      tensor: inputTensor,
      data: testInput
    }];
    const testOutput = sess.eval(h1, testFeedEntries);
    alert(`xor ${a} ${b} -> ${testOutput.getValues()}`);
  });
}