概要
jsdo.itでdeeplearn.jsを試してみた。
XORを学習させてみた。
しかし結果が安定しない。どこかが間違っているようだ。
サンプルコード
// Builds a 2-20-1 network for XOR on the deeplearn.js Graph API
// (loaded as the global `deeplearn`), plus its MSE cost and SGD optimizer.
const dl = deeplearn;
const gr = new dl.Graph();
const math = new dl.NDArrayMathCPU();

// One example = two input bits in, one target value out.
const inputShape = [2];
const inputTensor = gr.placeholder('input', inputShape);
const labelShape = [1];
const labelTensor = gr.placeholder('label', labelShape);

// Hidden layer parameters: 2 -> 20.
const w = gr.variable('w', dl.Array2D.randNormal([2, 20]));
const b = gr.variable('b', dl.Array1D.randNormal([20]));
// Output layer parameters: 20 -> 1. The bias needs its own name; it must not
// reuse the hidden-layer bias name 'b'.
const w1 = gr.variable('w1', dl.Array2D.randNormal([20, 1]));
const b1 = gr.variable('b1', dl.Array1D.randNormal([1]));

// Hidden activations: relu(x·w + b).
const h = gr.relu(gr.add(gr.matmul(inputTensor, w), b));
// Network output, left linear. A ReLU here clamps negative pre-activations to 0
// with zero gradient, so an unlucky random init can freeze the output at 0 for
// every input and training never recovers. That produces run-to-run instability.
const h1 = gr.add(gr.matmul(h, w1), b1);
const cost2 = gr.meanSquaredCost(h1, labelTensor);

// 0.0001 barely moves the weights within a few thousand steps. 0.01 is a more
// usual starting point for this size of net.
// NOTE(review): tune if training still stalls or diverges.
const learningRate = 0.01;
const optimizer = new dl.SGDOptimizer(learningRate);
// XOR truth table: the label is 1 exactly when the two input bits differ.
const inputs = [
  dl.Array1D.new([1.0, 0.0]),
  dl.Array1D.new([0.0, 1.0]),
  dl.Array1D.new([0.0, 0.0]),
  dl.Array1D.new([1.0, 1.0]),
];
const labels = [
  dl.Array1D.new([1.0]),
  dl.Array1D.new([1.0]),
  dl.Array1D.new([0.0]),
  dl.Array1D.new([0.0]),
];

// Inputs and labels are handed to the builder together, presumably so both are
// shuffled with the same permutation and stay paired (see deeplearn.js docs).
const shuffledInputProviderBuilder = new dl.InCPUMemoryShuffledInputProviderBuilder([inputs, labels]);
const [inputProvider, labelProvider] = shuffledInputProviderBuilder.getInputProviders();
const feedEntries = [
  { tensor: inputTensor, data: inputProvider },
  { tensor: labelTensor, data: labelProvider },
];

const sess = new dl.Session(gr, math);
// Full-batch training: each step uses every example in the dataset.
const batchSize = inputs.length;
const numTrainSteps = 4001;
const logEverySteps = 2000;
for (let i = 0; i < numTrainSteps; i++) {
  // Run each step inside math.scope so arrays created during the step are
  // scoped to it rather than accumulating across iterations.
  math.scope(() => {
    const cost = sess.train(cost2, feedEntries, batchSize, optimizer, dl.CostReduction.MEAN);
    if (i % logEverySteps === 0) {
      alert(`last average cost: ${cost.get()}`);
    }
  });
}
// Show the trained network's prediction for every XOR input pair,
// in the order 1 1, 1 0, 0 1, 0 0.
const xorTestInputs = [
  [1.0, 1.0],
  [1.0, 0.0],
  [0.0, 1.0],
  [0.0, 0.0],
];
for (const [x0, x1] of xorTestInputs) {
  math.scope((keep, track) => {
    // The test array is registered with `track` so it belongs to this scope.
    const testInput = track(dl.Array1D.new([x0, x1]));
    const testFeedEntries = [{ tensor: inputTensor, data: testInput }];
    const testOutput = sess.eval(h1, testFeedEntries);
    alert(`xor ${x0} ${x1} -> ${testOutput.getValues()}`);
  });
}
成果物