概要

TensorFlow.jsで手書き数字の分類をやってみた。

以下のリポジトリから、学習済みモデルをお借りした。

https://github.com/CreativeGP/tensorflowjs-mnist

サンプルコード

// Drawing canvas: the user sketches a digit with mouse or touch; the strokes
// are later downsampled by getImageBuffer() and fed to the model.
var canvas = $("#canvas").get(0);
// Kept for compatibility only. Pointer handling below decides per event,
// because touch-capable laptops still deliver plain mouse events.
var touchableDevice = ('ontouchstart' in window);
if (canvas.getContext)
{
    // NOTE: these stay `var` (not let/const) on purpose: `context` and
    // `getImageBuffer` must remain visible to the test_predict() closure that
    // is created in the model-loading script further down.
    var context = canvas.getContext('2d');
    var drawing = false;
    var prev = {};
    // Back the canvas with 2x pixels for smoother strokes; scale() lets the
    // drawing code keep working in CSS-pixel coordinates.
    canvas.width = 2 * $("#canvas").width();
    canvas.height = 2 * $("#canvas").height();
    context.scale(2.0, 2.0);
    context.lineJoin = "round";
    context.lineCap = "round";
    context.lineWidth = 20;
    context.strokeStyle = 'rgb(0, 0, 0)';

    /**
     * Converts a mouse or touch event into canvas-local CSS-pixel coordinates.
     * @param {Element} elem - The canvas element.
     * @param {Object} e - jQuery event wrapping a native mouse or touch event.
     * @returns {{x: number, y: number}} Point relative to the canvas' top-left.
     */
    var getPointOnCanvas = function(elem, e) {
        // Use the event's own native object rather than the non-standard
        // global `event`. Touch events carry changedTouches; mouse events
        // carry clientX/clientY directly.
        var src = e.originalEvent || e;
        var point = (src.changedTouches && src.changedTouches.length > 0)
            ? src.changedTouches[0]
            : src;
        // clientX/Y are viewport-relative, so subtract the viewport-relative
        // rect. jQuery's offset() is document-relative and is wrong once scrolled.
        var rect = elem.getBoundingClientRect();
        return {
            x: point.clientX - rect.left,
            y: point.clientY - rect.top
        };
    };

    $("#canvas").on('touchstart mousedown', function(e) {
        e.preventDefault();
        prev = getPointOnCanvas(this, e);
        drawing = true;
    });
    $("#canvas").on('touchmove mousemove', function(e) {
        if (!drawing) return;
        e.preventDefault();
        // Local, so it does not leak as an implicit global.
        var curr = getPointOnCanvas(this, e);
        context.beginPath();
        context.moveTo(prev.x, prev.y);
        context.lineTo(curr.x, curr.y);
        context.stroke();
        prev = curr;
    });
    $("#canvas").on('touchend mouseup mouseleave', function(e) {
        drawing = false;
    });
    $("#run_button").click(function() {
        // The model is loaded asynchronously; ignore clicks until it is ready
        // instead of throwing "test_predict is not a function".
        if (typeof test_predict === 'function') {
            test_predict();
        }
    });
    $("#delete_button").click(function() {
        // Clear the entire backing store regardless of canvas size. The 2x
        // scale is reset temporarily so clearRect works in raw pixels.
        context.save();
        context.setTransform(1, 0, 0, 1, 0, 0);
        context.clearRect(0, 0, canvas.width, canvas.height);
        context.restore();
    });

    /**
     * Downsamples the drawing into a width x height grayscale buffer.
     * Strokes are opaque black on a transparent canvas, so RGB is 0 and alpha
     * carries the ink. Summing all four channels and clamping at 255 therefore
     * yields ink intensity in [0, 255].
     * @param {CanvasRenderingContext2D} context - Source drawing context.
     * @param {number} width - Output width in pixels.
     * @param {number} height - Output height in pixels.
     * @returns {number[]} Row-major array of length width * height.
     */
    var getImageBuffer = function(context, width, height) {
        var tmpCanvas = $('<canvas>').get(0);
        tmpCanvas.width = width;
        tmpCanvas.height = height;
        var tmpContext = tmpCanvas.getContext('2d');
        tmpContext.drawImage(context.canvas, 0, 0, width, height);
        var image = tmpContext.getImageData(0, 0, width, height);
        var buffer = [];
        for (var i = 0; i < image.data.length; i += 4)
        {
            var sum = image.data[i + 0] + image.data[i + 1] + image.data[i + 2] + image.data[i + 3];
            buffer.push(Math.min(sum, 255));
        }
        return buffer;
    };
}
// Assigned once the model has finished loading; the Run button handler checks
// for it before calling.
var test_predict;
// rawgit.com has been shut down; jsDelivr serves the same GitHub files.
// tf.loadModel was renamed tf.loadLayersModel in tfjs 1.x, so use whichever
// this tfjs build provides.
var loadModelFn = tf.loadLayersModel || tf.loadModel;
loadModelFn('https://cdn.jsdelivr.net/gh/CreativeGP/tensorflowjs-mnist@master/model/model.json').then((model) => {
    /**
     * Reads the current drawing as a 28x28 image, runs the model and alerts
     * the index of the highest-scoring output class.
     */
    test_predict = function() {
        const pixels = getImageBuffer(context, 28, 28);
        // tidy() disposes every intermediate tensor (input, raw scores) so
        // repeated runs do not leak GPU/CPU memory.
        const classIndex = tf.tidy(() => {
            const buffer = tf.buffer([1, 28, 28, 1]);
            for (let i = 0; i < 28; i++)
            {
                for (let j = 0; j < 28; j++)
                {
                    // Normalize ink intensity from [0, 255] to [0, 1].
                    buffer.set(pixels[i * 28 + j] / 255.0, 0, i, j, 0);
                }
            }
            // NOTE(review): assumes predict() returns scores shaped [1, numClasses].
            // argMax() with no argument reduces axis 0 (the batch axis) and
            // would yield all zeros, so reduce the class axis explicitly.
            return model.predict(buffer.toTensor()).argMax(1);
        });
        const digit = classIndex.dataSync()[0];
        classIndex.dispose();
        alert(digit);
    };
}).catch((err) => {
    console.error('Failed to load model', err);
});

成果物

http://jsdo.it/ohisama1/UrzR

モデルを以下のものに変更した。

https://github.com/yukagil/tfjs-mnist-cnn-demo

成果物

http://jsdo.it/ohisama1/SUxj

以上。

Sign up for free and join this conversation.
Sign Up
If you already have a Qiita account log in.