# 概要
TensorFlow.js で手書き数字の分類をやってみた。
# 以下から、モデルをお借りした。
# サンプルコード
// Raw DOM node of the drawing surface: the first element matched by "#canvas".
var canvas = $("#canvas")[0];
// True when the window object exposes an `ontouchstart` property.
var touchableDevice = 'ontouchstart' in window;
if (canvas.getContext)
{
// State shared by the pointer handlers: the 2D context that is drawn on,
// whether a stroke is in progress, and the last point of that stroke.
var context = canvas.getContext('2d');
var drawing = false;
var prev = {};

// Give the canvas twice as many backing pixels as the size jQuery measures,
// then scale the context by the same factor. The width is written before the
// height is measured, in that order.
const $surface = $("#canvas");
const backingScale = 2;
canvas.width = backingScale * $surface.width();
canvas.height = backingScale * $surface.height();
context.scale(backingScale, backingScale);

// Pen settings: a 20-unit wide black line with rounded joins and caps.
Object.assign(context, {
    lineJoin: "round",
    lineCap: "round",
    lineWidth: 20,
    strokeStyle: 'rgb(0, 0, 0)',
});
// Begin a stroke: remember where the pointer went down and enter drawing mode.
$("#canvas").bind('touchstart mousedown', function(e) {
    e.preventDefault();
    // Hand over the native event carried by the jQuery event rather than the
    // global `event`, which is non-standard and not defined in every browser.
    prev = getPointOnCanvas(this, e.originalEvent || e, e);
    drawing = true;
});
// Extend the current stroke: draw a segment from the previous point to the
// current pointer position, then make the current point the new start.
$("#canvas").bind('touchmove mousemove', function(e) {
    if (!drawing) {
        return;
    }
    e.preventDefault();
    // Use the native event carried by the jQuery event rather than the global
    // `event`, and keep the point local instead of leaking a global `curr`.
    const curr = getPointOnCanvas(this, e.originalEvent || e, e);
    context.beginPath();
    context.moveTo(prev.x, prev.y);
    context.lineTo(curr.x, curr.y);
    context.stroke();
    prev = curr;
});
// Lifting the finger, releasing the button or leaving the canvas ends the stroke.
$("#canvas").bind('touchend mouseup mouseleave', function() {
    drawing = false;
});
/**
 * Convert a pointer event into coordinates relative to the element's top-left
 * corner.
 *
 * @param {Element|string|Object} elem - Element (or anything `$()` accepts)
 *     the coordinates are measured against.
 * @param {Event} [windowEvent] - Native event; used to find `changedTouches`
 *     when `touchEvent` carries no `originalEvent`.
 * @param {Object} [touchEvent] - jQuery event for the same interaction.
 * @returns {{x: number, y: number}} Position inside the element.
 */
var getPointOnCanvas = function(elem, windowEvent, touchEvent) {
    // Decide touch vs. mouse from the event itself. A device-wide
    // "touch capable" flag misfires when a mouse is used on a touch screen,
    // because mouse events have no `changedTouches`.
    const nativeEvent = (touchEvent && touchEvent.originalEvent) || windowEvent || touchEvent;
    const touches = nativeEvent && nativeEvent.changedTouches;
    const pointer = (touches && touches.length > 0) ? touches[0] : (touchEvent || nativeEvent);

    // clientX/clientY are viewport-relative, so subtract the viewport-relative
    // bounding box; a document-relative offset is wrong once the page scrolls.
    const bounds = $(elem).get(0).getBoundingClientRect();
    return {
        x: pointer.clientX - bounds.left,
        y: pointer.clientY - bounds.top
    };
};
// Run the classifier on the current drawing.
$("#run_button").click(function() {
    // `test_predict` is only assigned after the model has finished loading;
    // calling it earlier would throw a TypeError inside the handler.
    if (typeof test_predict !== 'function') {
        alert('The model is still loading. Please try again in a moment.');
        return;
    }
    test_predict();
});
// Erase the whole drawing.
$("#delete_button").click(function() {
    // Clear the full backing store under an identity transform, so the result
    // does not depend on a hard-coded 280x280 size or on the current scale.
    // save()/restore() puts the drawing transform and pen settings back.
    context.save();
    context.setTransform(1, 0, 0, 1, 0, 0);
    context.clearRect(0, 0, context.canvas.width, context.canvas.height);
    context.restore();
});
/**
 * Redraw the context's canvas at `width` x `height` on a scratch canvas and
 * reduce every pixel to a single number.
 *
 * @param {CanvasRenderingContext2D} context - Context whose canvas is sampled.
 * @param {number} width - Width of the resampled image.
 * @param {number} height - Height of the resampled image.
 * @returns {number[]} One entry per pixel: R + G + B + A, capped at 255.
 */
var getImageBuffer = function(context, width, height) {
    const scratch = $('<canvas>').get(0);
    scratch.width = width;
    scratch.height = height;

    const scratchContext = scratch.getContext('2d');
    scratchContext.drawImage(context.canvas, 0, 0, width, height);
    const rgba = scratchContext.getImageData(0, 0, width, height).data;

    // The pixel data holds four channel values per pixel.
    const pixelCount = Math.ceil(rgba.length / 4);
    return Array.from({ length: pixelCount }, (_, pixel) => {
        const base = pixel * 4;
        const total = rgba[base] + rgba[base + 1] + rgba[base + 2] + rgba[base + 3];
        return Math.min(total, 255);
    });
};
}
// Classifies the current drawing and shows the result in an alert.
// Stays undefined until the model below has finished loading.
var test_predict;
// `tf.loadModel` was replaced by `tf.loadLayersModel` in newer TensorFlow.js
// releases; use whichever one the loaded build provides.
// NOTE(review): rawgit.com has been shut down -- confirm this URL still resolves.
(tf.loadLayersModel || tf.loadModel)('https://rawgit.com/CreativeGP/tensorflowjs-mnist/master/model/model.json').then((model) => {
    test_predict = function() {
        // 28 x 28 values in 0..255, row by row.
        const pixels = getImageBuffer(context, 28, 28);
        // tf.tidy disposes the input and intermediate tensors after each run,
        // so repeated predictions do not accumulate tensor memory.
        const predict = tf.tidy(() => {
            const input = tf.tensor4d(pixels.map((value) => value / 255.0), [1, 28, 28, 1]);
            // argMax() defaults to axis 0, which would reduce over the batch
            // dimension; reduce over the last axis to pick the best class.
            // Assumes predict() returns one row of class scores per input.
            return model.predict(input).argMax(-1).dataSync();
        });
        alert(predict);
    };
}).catch((err) => {
    // Without this a failed download is an unhandled rejection.
    console.error('Failed to load the MNIST model:', err);
});
# 成果物
# モデルを変更した。
# 成果物
以上。