概要
jsdo.itでdeeplearn.jsを試してみた。
TensorFlowで学習したモデルを読み込んで使ってみた。
手書き数字を推定（認識）してみた。
サンプルコード
/**
 * Location of the checkpoint manifest (a JSON document listing every variable),
 * appended to the loader's base URL when the manifest is requested.
 */
var MANIFEST_FILE = '/assets/q/K/x/K/qKxKH';

/**
 * Loads model variables described by a JSON manifest; each variable's raw
 * float32 data is fetched from its own file under the same base URL.
 *
 * @param {string} urlPath - base URL; a trailing '/' is appended when missing.
 */
function CheckpointLoader(urlPath) {
  var endsWithSlash = urlPath.charAt(urlPath.length - 1) === '/';
  this.urlPath = endsWithSlash ? urlPath : urlPath + '/';
}
/**
 * Downloads the checkpoint manifest (JSON) and stores it on `this.checkpointManifest`.
 *
 * @returns {Promise<void>} resolves once the manifest is parsed. Rejects with an Error
 *   on a network failure, a non-2xx HTTP status, or a response that is not valid JSON.
 */
CheckpointLoader.prototype.loadManifest = function() {
  var _this = this;
  // urlPath normally ends in '/' and MANIFEST_FILE starts with one; drop the
  // duplicate so the request goes to ".../assets/..." rather than "...//assets/...".
  var manifestPath = MANIFEST_FILE;
  if (this.urlPath.charAt(this.urlPath.length - 1) === '/' && manifestPath.charAt(0) === '/') {
    manifestPath = manifestPath.substring(1);
  }
  return new Promise(function(resolve, reject) {
    var xhr = new XMLHttpRequest();
    xhr.open('GET', _this.urlPath + manifestPath);
    xhr.onload = function() {
      // Status 0 is accepted because some environments (e.g. file://) report it on success.
      var ok = xhr.status === 0 || (xhr.status >= 200 && xhr.status < 300);
      if (!ok) {
        reject(new Error(MANIFEST_FILE + " not found at " + _this.urlPath + ". HTTP " + xhr.status));
        return;
      }
      try {
        _this.checkpointManifest = JSON.parse(xhr.responseText);
      } catch (e) {
        reject(new Error('Invalid manifest ' + MANIFEST_FILE + ' at ' + _this.urlPath + ': ' + e.message));
        return;
      }
      resolve();
    };
    // Reject rather than throw: an exception thrown inside an XHR callback does
    // not settle this Promise, which would leave every caller waiting forever.
    xhr.onerror = function(error) {
      reject(new Error(MANIFEST_FILE + " not found at " + _this.urlPath + ". " + error));
    };
    xhr.send();
  });
};
/**
 * Returns the checkpoint manifest, loading it on first use.
 *
 * @returns {Promise<Object>} the parsed manifest. If loading fails, the returned
 *   Promise rejects with the loader's error instead of never settling.
 */
CheckpointLoader.prototype.getCheckpointManifest = function() {
  var _this = this;
  if (this.checkpointManifest != null) {
    return Promise.resolve(this.checkpointManifest);
  }
  // Chain instead of wrapping in `new Promise`, so rejections propagate to the caller.
  return this.loadManifest().then(function() {
    return _this.checkpointManifest;
  });
};
/**
 * Fetches every variable listed in the manifest, in parallel, and caches the result.
 *
 * @returns {Promise<Object>} map from manifest key to the value returned by
 *   `getVariable` for that key. The same object is returned on later calls. The
 *   Promise rejects if the manifest or any variable fails to load; in that case
 *   nothing is cached.
 */
CheckpointLoader.prototype.getAllVariables = function() {
  var _this = this;
  if (this.variables != null) {
    return Promise.resolve(this.variables);
  }
  return this.getCheckpointManifest().then(function(manifest) {
    var variableNames = Object.keys(manifest);
    var requests = variableNames.map(function(name) {
      return _this.getVariable(name);
    });
    return Promise.all(requests).then(function(values) {
      var variables = {};
      for (var i = 0; i < values.length; i++) {
        variables[variableNames[i]] = values[i];
      }
      _this.variables = variables;
      return variables;
    });
  });
};
/**
 * Fetches one checkpoint variable and wraps its raw data in a `dl.NDArray`.
 *
 * The file name and shape come from the manifest entry `checkpointManifest[varName]`.
 * The file body is read as float32 values. If the manifest is not loaded yet, it is
 * loaded first.
 *
 * @param {string} varName - key in the checkpoint manifest.
 * @returns {Promise<dl.NDArray>} rejects with an Error on a network or HTTP failure,
 *   or when the downloaded data cannot be turned into an NDArray.
 * @throws {Error} synchronously, when the manifest is already loaded and has no entry
 *   for varName. When the manifest still has to be loaded, the same error is
 *   delivered as a rejection instead.
 */
CheckpointLoader.prototype.getVariable = function(varName) {
  var _this = this;

  // Throws if the (already loaded) manifest has no entry for varName.
  var assertVariableExists = function() {
    if (!(varName in _this.checkpointManifest)) {
      throw new Error('Cannot load non-existant variable ' + varName);
    }
  };

  // Downloads the variable's binary file and resolves with the resulting NDArray.
  var fetchVariable = function() {
    var entry = _this.checkpointManifest[varName];
    return new Promise(function(resolve, reject) {
      var xhr = new XMLHttpRequest();
      xhr.responseType = 'arraybuffer';
      xhr.open('GET', _this.urlPath + entry.filename);
      xhr.onload = function() {
        // Status 0 is accepted because some environments (e.g. file://) report it on success.
        var ok = xhr.status === 0 || (xhr.status >= 200 && xhr.status < 300);
        if (!ok) {
          reject(new Error('Could not fetch variable ' + varName + ': HTTP ' + xhr.status));
          return;
        }
        try {
          var values = new Float32Array(xhr.response);
          resolve(dl.NDArray.make(entry.shape, {
            values: values
          }));
        } catch (e) {
          // e.g. a byte length that is not a multiple of 4, or a shape mismatch.
          reject(e);
        }
      };
      // Reject rather than throw: a throw inside an XHR callback never settles the Promise.
      xhr.onerror = function(error) {
        reject(new Error('Could not fetch variable ' + varName + ': ' + error));
      };
      xhr.send();
    });
  };

  // The existence check must wait until the manifest is available:
  // `varName in null/undefined` throws a TypeError.
  if (this.checkpointManifest == null) {
    return this.loadManifest().then(function() {
      assertVariableExists();
      return fetchVariable();
    });
  }
  assertVariableExists();
  return fetchVariable();
};
var dl = deeplearn;
var g = new dl.Graph();
var math = new dl.NDArrayMathCPU();
// Checkpoint variables (manifest key -> NDArray); set once loading finishes and read by test().
var vars2;
var reader = new CheckpointLoader('http://jsrun.it');

// Load the checkpoint, then score the model on the sample set and show the accuracy.
reader.getAllVariables().then(function(vars) {
  vars2 = vars;
  var xhr = new XMLHttpRequest();
  xhr.open('GET', 'http://jsrun.it/assets/O/U/w/U/OUwUs');
  xhr.onload = function() {
    // Status 0 is accepted because some environments (e.g. file://) report it on success.
    var ok = xhr.status === 0 || (xhr.status >= 200 && xhr.status < 300);
    if (!ok) {
      alert('Failed to load sample data: HTTP ' + xhr.status);
      return;
    }
    var data = JSON.parse(xhr.responseText);
    var _a = buildModelLayersAPI(data, vars);
    var input = _a[0];
    var probs = _a[1];
    var sess = new dl.Session(input.node.graph, math);
    math.scope(function() {
      var numCorrect = 0;
      for (var i = 0; i < data.images.length; i++) {
        // NOTE(review): assumes data.images[i] is a flat 784-value array (the placeholder
        // shape) and data.labels[i] is the numeric class index compared with argmax.
        var inputData = dl.Array1D.new(data.images[i]);
        var probsVal = sess.eval(probs, [{
          tensor: input,
          data: inputData
        }]);
        if (data.labels[i] === probsVal.get()) {
          numCorrect++;
        }
      }
      var accuracy = numCorrect * 100 / data.images.length;
      // Plain text only; no markup is needed.
      document.getElementById('helloWorld').textContent = accuracy + '%';
    });
  };
  xhr.onerror = function() {
    // The ProgressEvent passed here stringifies as "[object ProgressEvent]", so say what failed.
    alert('Failed to load sample data from http://jsrun.it/assets/O/U/w/U/OUwUs');
  };
  xhr.send();
}).catch(function(err) {
  // Without this handler a failed checkpoint download would go unreported.
  alert('Failed to load model: ' + (err && err.message ? err.message : err));
});
/**
 * Builds the network as an eagerly evaluated function on an NDArrayMath backend:
 * two fully connected ReLU layers followed by a linear output layer.
 *
 * The layer weights are looked up in `vars` once, when the function is built.
 *
 * @param {Object} math - NDArrayMath backend (provides scope, vectorTimesMatrix, add, relu, argMax).
 * @param {*} data - unused.
 * @param {Object} vars - checkpoint variables keyed "<layer>/weights" and "<layer>/biases".
 * @returns {function(*): *} maps an input vector to math.argMax of the output logits.
 */
function buildModelMathAPI(math, data, vars) {
  var layers = ['hidden1', 'hidden2', 'softmax_linear'].map(function(prefix) {
    return {
      weights: vars[prefix + '/weights'],
      biases: vars[prefix + '/biases']
    };
  });
  var lastLayer = layers.length - 1;

  return function(x) {
    return math.scope(function() {
      var activation = x;
      for (var layer = 0; layer < layers.length; layer++) {
        var affine = math.add(
            math.vectorTimesMatrix(activation, layers[layer].weights),
            layers[layer].biases);
        // Hidden layers use ReLU; the final layer's output is used directly as the logits.
        activation = layer < lastLayer ? math.relu(affine) : affine;
      }
      return math.argMax(activation);
    });
  };
}
/**
 * Adds the network to the shared graph `g`, using graph constants for the weights:
 * two fully connected ReLU layers followed by a linear output layer.
 *
 * @param {*} data - unused.
 * @param {Object} vars - checkpoint variables keyed "<layer>/weights" and "<layer>/biases".
 * @returns {Array} [input placeholder (784 values), argmax node over the output logits].
 */
function buildModelGraphAPI(data, vars) {
  var input = g.placeholder('input', [784]);

  // Registers the layer's weights and biases as constants and returns x·W + b.
  var affine = function(x, prefix) {
    var weights = g.constant(vars[prefix + '/weights']);
    var biases = g.constant(vars[prefix + '/biases']);
    return g.add(g.matmul(x, weights), biases);
  };

  var hidden1 = g.relu(affine(input, 'hidden1'));
  var hidden2 = g.relu(affine(hidden1, 'hidden2'));
  var logits = affine(hidden2, 'softmax_linear');
  return [input, g.argmax(logits)];
}
/**
 * Adds the network to the shared graph `g` with `g.layers.dense`. Each layer is
 * initialized from the checkpoint arrays: two ReLU hidden layers ("hidden1",
 * "hidden2") and a linear output layer ("softmax").
 *
 * @param {*} data - unused.
 * @param {Object} vars - checkpoint NDArrays keyed "<layer>/weights" and "<layer>/biases";
 *   each weights array's shape[1] gives the layer's unit count.
 * @returns {Array} [input placeholder (784 values), argmax node over the output logits].
 */
function buildModelLayersAPI(data, vars) {
  var input = g.placeholder('input', [784]);
  var relu = function(x) {
    return g.relu(x);
  };

  // Creates dense layer `layerName` whose kernel and bias come from vars[prefix + '/...'].
  var denseFromCheckpoint = function(layerName, prefix, x, activation) {
    var kernel = vars[prefix + '/weights'];
    var bias = vars[prefix + '/biases'];
    var units = kernel.shape[1];
    var kernelInit = new dl.NDArrayInitializer(kernel);
    var biasInit = new dl.NDArrayInitializer(bias);
    return g.layers.dense(layerName, x, units, activation, true, kernelInit, biasInit);
  };

  var hidden1 = denseFromCheckpoint('hidden1', 'hidden1', input, relu);
  var hidden2 = denseFromCheckpoint('hidden2', 'hidden2', hidden1, relu);
  var logits = denseFromCheckpoint('softmax', 'softmax_linear', hidden2, null);
  return [input, g.argmax(logits)];
}
var canvas = $("#canvas").get(0);
// Kept as a global for compatibility; pointer handling below detects touches per event.
var touchableDevice = ('ontouchstart' in window);
if (canvas.getContext) {
  var context = canvas.getContext('2d');
  var drawing = false;
  var prev = {};

  // The backing store is twice the CSS size; scaling by 2 keeps drawing coordinates in CSS pixels.
  canvas.width = 2 * $("#canvas").width();
  canvas.height = 2 * $("#canvas").height();
  context.scale(2.0, 2.0);
  context.lineJoin = "round";
  context.lineCap = "round";
  context.lineWidth = 20;
  context.strokeStyle = 'rgb(0, 0, 0)';

  $("#canvas").bind('touchstart mousedown', function(e) {
    e.preventDefault();
    // Use the event jQuery passes in; the global `window.event` is not available in every browser.
    prev = getPointOnCanvas(this, e.originalEvent, e);
    drawing = true;
  });
  $("#canvas").bind('touchmove mousemove', function(e) {
    if (!drawing) {
      return;
    }
    e.preventDefault();
    var curr = getPointOnCanvas(this, e.originalEvent, e);
    context.beginPath();
    context.moveTo(prev.x, prev.y);
    context.lineTo(curr.x, curr.y);
    context.stroke();
    prev = curr;
  });
  $("#canvas").bind('touchend mouseup mouseleave', function() {
    drawing = false;
  });

  /**
   * Converts a pointer event to coordinates relative to elem's top-left corner (CSS pixels).
   *
   * @param {Element} elem - the canvas element.
   * @param {Event} [windowEvent] - native event; its first changed touch is used when present.
   * @param {Object} touchEvent - jQuery event; its clientX/clientY are used for mouse input.
   * @returns {{x: number, y: number}}
   */
  var getPointOnCanvas = function(elem, windowEvent, touchEvent) {
    var nativeEvent = windowEvent || (touchEvent && touchEvent.originalEvent);
    var touches = nativeEvent && nativeEvent.changedTouches;
    // Decide per event, not per device: touch-capable devices still deliver mouse events.
    var point = (touches && touches.length > 0) ? touches[0] : touchEvent;
    // clientX/Y are viewport-relative, so subtract the viewport-relative rect
    // (jQuery's offset() is document-relative and would be off by the scroll position).
    var rect = elem.getBoundingClientRect();
    return {
      x: point.clientX - rect.left,
      y: point.clientY - rect.top
    };
  };

  $("#run_button").click(function() {
    test();
  });
  $("#delete_button").click(function() {
    // Clear the whole backing store in device pixels instead of a fixed 280x280 CSS region.
    context.save();
    context.setTransform(1, 0, 0, 1, 0, 0);
    context.clearRect(0, 0, canvas.width, canvas.height);
    context.restore();
  });

  /**
   * Downsamples the drawing to width x height and returns one value per pixel:
   * r + g + b + a, clamped to 255. Strokes are opaque black, so drawn pixels map to
   * high values and the untouched (transparent) background maps to 0.
   *
   * NOTE(review): values are in 0..255; confirm this matches the scale the model was
   * trained on (TensorFlow MNIST inputs are commonly scaled to 0..1).
   *
   * @param {CanvasRenderingContext2D} context - source drawing context.
   * @param {number} width - output width in pixels.
   * @param {number} height - output height in pixels.
   * @returns {number[]} width * height values in row-major order.
   */
  var getImageBuffer = function(context, width, height) {
    var tmpCanvas = $('<canvas>').get(0);
    tmpCanvas.width = width;
    tmpCanvas.height = height;
    var tmpContext = tmpCanvas.getContext('2d');
    tmpContext.drawImage(context.canvas, 0, 0, width, height);
    var image = tmpContext.getImageData(0, 0, width, height);
    var buffer = [];
    for (var i = 0; i < image.data.length; i += 4) {
      var sum = image.data[i + 0] + image.data[i + 1] + image.data[i + 2] + image.data[i + 3];
      buffer.push(Math.min(sum, 255));
    }
    return buffer;
  };
}
/**
 * Click handler for #run_button. Classifies the digit currently drawn on the canvas
 * and shows the predicted class index in an alert.
 *
 * The network is built into the shared graph `g` (and a Session created) only once
 * per loaded checkpoint, then reused. Before this, every click added another copy
 * of every layer to `g`.
 */
var test = (function() {
  var modelVars = null; // checkpoint the cached model was built from
  var input = null;
  var probs = null;
  var sess = null;

  return function() {
    // vars2 is filled asynchronously once the checkpoint download completes.
    if (vars2 == null) {
      alert('Model is still loading. Please try again in a moment.');
      return;
    }
    // 28x28 = 784 values, matching the model's input placeholder.
    var p = getImageBuffer(context, 28, 28);
    if (modelVars !== vars2) {
      var _a = buildModelLayersAPI(p, vars2);
      input = _a[0];
      probs = _a[1];
      sess = new dl.Session(input.node.graph, math);
      modelVars = vars2;
    }
    math.scope(function() {
      var inputData = dl.Array1D.new(p);
      var probsVal = sess.eval(probs, [{
        tensor: input,
        data: inputData
      }]);
      alert(probsVal.get());
    });
  };
})();