MNIST
jsdo
deeplearn.js

jsdoでdeeplearn.js その6

概要

jsdo.itでdeeplearn.jsを試してみた。
TensorFlowで学習したモデルを使ってみた。
手書き数字を推定してみた。

サンプルコード

// Location of the checkpoint manifest, appended to the loader's base URL.
// The manifest maps each variable name to an entry whose `filename` and
// `shape` are read by getVariable().
var MANIFEST_FILE = '/assets/q/K/x/K/qKxKH';
/**
 * Loads checkpoint variables (manifest + raw Float32 files) over HTTP.
 *
 * @param {string} urlPath - Base URL of the manifest and variable files.
 *     A trailing '/' is appended when it is missing.
 */
function CheckpointLoader(urlPath) {
    var hasTrailingSlash = urlPath.endsWith('/');
    this.urlPath = hasTrailingSlash ? urlPath : urlPath + '/';
}
/**
 * Fetches the checkpoint manifest and stores the parsed JSON on
 * this.checkpointManifest.
 *
 * @returns {Promise<void>} Resolves once the manifest has been stored.
 *     Rejects on a network error, a non-2xx HTTP status, or a response body
 *     that is not valid JSON.
 */
CheckpointLoader.prototype.loadManifest = function() {
    var _this = this;
    // The constructor guarantees urlPath ends with '/', so strip any leading
    // '/' from the manifest path to avoid requesting a URL containing '//'.
    var manifestUrl = this.urlPath + MANIFEST_FILE.replace(/^\/+/, '');
    return new Promise(function(resolve, reject) {
        var xhr = new XMLHttpRequest();
        xhr.open('GET', manifestUrl);
        xhr.onload = function() {
            // onload also fires for HTTP error responses such as 404.
            if (xhr.status < 200 || xhr.status >= 300)
            {
                reject(new Error(MANIFEST_FILE + " not found at " + _this.urlPath + ". HTTP " + xhr.status));
                return;
            }
            try
            {
                _this.checkpointManifest = JSON.parse(xhr.responseText);
            }
            catch (e)
            {
                reject(e);
                return;
            }
            resolve();
        };
        // Reject rather than throw: an exception thrown inside an XHR
        // callback never reaches this promise, which would then never settle.
        xhr.onerror = function(error) {
            reject(new Error(MANIFEST_FILE + " not found at " + _this.urlPath + ". " + error));
        };
        xhr.send();
    });
};
/**
 * Returns the checkpoint manifest, loading it on first use.
 *
 * @returns {Promise<Object>} Resolves with this.checkpointManifest. Rejects
 *     if loadManifest() rejects.
 */
CheckpointLoader.prototype.getCheckpointManifest = function() {
    var _this = this;
    if (this.checkpointManifest != null)
    {
        return Promise.resolve(this.checkpointManifest);
    }
    // Chain on loadManifest() directly so that its failures propagate.
    // Wrapping it in a new Promise without wiring reject hangs callers.
    return this.loadManifest().then(function() {
        return _this.checkpointManifest;
    });
};
/**
 * Fetches every variable listed in the manifest and caches the result on
 * this.variables.
 *
 * @returns {Promise<Object>} Resolves with an object mapping each manifest
 *     variable name to the value produced by getVariable(). Rejects if the
 *     manifest or any variable cannot be loaded.
 */
CheckpointLoader.prototype.getAllVariables = function() {
    var _this = this;
    if (this.variables != null)
    {
        return Promise.resolve(this.variables);
    }
    // Return the chain itself so any failure reaches the caller instead of
    // leaving an outer hand-built Promise pending forever.
    return this.getCheckpointManifest().then(function(checkpointManifest) {
        var variableNames = Object.keys(checkpointManifest);
        var variablePromises = variableNames.map(function(name) {
            return _this.getVariable(name);
        });
        return Promise.all(variablePromises).then(function(values) {
            var variables = {};
            for (var i = 0; i < values.length; i++)
            {
                variables[variableNames[i]] = values[i];
            }
            _this.variables = variables;
            return variables;
        });
    });
};
/**
 * Fetches one variable's raw Float32 data and wraps it with
 * dl.NDArray.make using the shape recorded in the manifest.
 *
 * If the manifest has not been loaded yet, it is loaded first.
 *
 * @param {string} varName - Variable name as listed in the manifest.
 * @returns {Promise} Resolves with the value returned by dl.NDArray.make.
 *     Rejects on a network error or a non-2xx HTTP status.
 * @throws {Error} Synchronously, when the manifest is already loaded and has
 *     no entry for varName. When the manifest has to be loaded first, the same
 *     error is delivered as a rejection instead.
 */
CheckpointLoader.prototype.getVariable = function(varName) {
    var _this = this;
    // Check for a missing manifest before using `in` on it: `in` on null
    // throws a TypeError.
    if (this.checkpointManifest == null)
    {
        return this.loadManifest().then(function() {
            return _this.getVariable(varName);
        });
    }
    if (!(varName in this.checkpointManifest))
    {
        throw new Error('Cannot load non-existant variable ' + varName);
    }
    var entry = this.checkpointManifest[varName];
    // urlPath ends with '/', so drop any leading '/' from the filename to
    // avoid a '//' in the request path.
    var url = this.urlPath + String(entry.filename).replace(/^\/+/, '');
    return new Promise(function(resolve, reject) {
        var xhr = new XMLHttpRequest();
        xhr.responseType = 'arraybuffer';
        xhr.open('GET', url);
        xhr.onload = function() {
            // onload also fires for HTTP error responses such as 404.
            if (xhr.status < 200 || xhr.status >= 300)
            {
                reject(new Error('Could not fetch variable ' + varName + ': HTTP ' + xhr.status));
                return;
            }
            try
            {
                var values = new Float32Array(xhr.response);
                resolve(dl.NDArray.make(entry.shape, {
                    values: values
                }));
            }
            catch (e)
            {
                reject(e);
            }
        };
        // Reject rather than throw: a throw inside an XHR callback never
        // reaches this promise.
        xhr.onerror = function(error) {
            reject(new Error('Could not fetch variable ' + varName + ': ' + error));
        };
        xhr.send();
    });
};
// Shared deeplearn.js state: the graph `g` that the model builders below add
// nodes to, and the CPU math backend used to evaluate it.
var dl = deeplearn;
var g = new dl.Graph();
var math = new dl.NDArrayMathCPU();
// Checkpoint variables; set once loading finishes and read later by test().
var vars2;
var reader = new CheckpointLoader('http://jsrun.it');
// Load the checkpoint, then fetch the sample data set and write the model's
// accuracy on it (in percent) into #helloWorld.
reader.getAllVariables().then(function(vars) {
    vars2 = vars;
    var xhr = new XMLHttpRequest();
    xhr.open('GET', 'http://jsrun.it/assets/O/U/w/U/OUwUs');
    xhr.onload = function() {
        // onload also fires for HTTP error responses such as 404.
        if (xhr.status < 200 || xhr.status >= 300)
        {
            alert('Could not load test data: HTTP ' + xhr.status);
            return;
        }
        // Report failures (bad JSON, evaluation errors) the same way as
        // network errors instead of leaving them uncaught in the callback.
        try
        {
            // NOTE(review): the data is assumed to be {images: [...], labels: [...]}
            // with one flattened 784-value image per label, as read below.
            var data = JSON.parse(xhr.responseText);
            var model = buildModelLayersAPI(data, vars);
            var input = model[0];
            var probs = model[1];
            var sess = new dl.Session(input.node.graph, math);
            math.scope(function() {
                var numCorrect = 0;
                for (var i = 0; i < data.images.length; i++)
                {
                    var inputData = dl.Array1D.new(data.images[i]);
                    var probsVal = sess.eval(probs, [{
                        tensor: input,
                        data: inputData
                    }]);
                    // probs is the argmax node, so get() yields the predicted digit.
                    if (data.labels[i] === probsVal.get())
                    {
                        numCorrect++;
                    }
                }
                var accuracy = numCorrect * 100 / data.images.length;
                document.getElementById('helloWorld').textContent = accuracy + '%';
            });
        }
        catch (err)
        {
            alert(err);
        }
    };
    xhr.onerror = function (err) {
        return alert(err);
    };
    xhr.send();
}).catch(function(err) {
    // Checkpoint loading failed (manifest or a variable could not be fetched).
    alert(err);
});
/**
 * Builds an inference function with deeplearn.js's immediate-mode math API:
 * two ReLU fully connected layers followed by a linear output layer.
 *
 * NOTE(review): not called anywhere in this file; `data` is unused.
 *
 * @param math - deeplearn.js math backend (e.g. an NDArrayMathCPU instance).
 * @param data - Unused.
 * @param vars - Checkpoint variables keyed by checkpoint variable name.
 * @returns {function} Maps an input vector x to math.argMax of the logits.
 */
function buildModelMathAPI(math, data, vars) {
    // Every weight and bias is looked up once, when the model is built.
    var params = {
        first: { w: vars['hidden1/weights'], b: vars['hidden1/biases'] },
        second: { w: vars['hidden2/weights'], b: vars['hidden2/biases'] },
        output: { w: vars['softmax_linear/weights'], b: vars['softmax_linear/biases'] }
    };
    // x * W + b for a single fully connected layer.
    function affine(x, layer) {
        return math.add(math.vectorTimesMatrix(x, layer.w), layer.b);
    }
    return function(x) {
        return math.scope(function() {
            var h1 = math.relu(affine(x, params.first));
            var h2 = math.relu(affine(h1, params.second));
            return math.argMax(affine(h2, params.output));
        });
    };
}
/**
 * Builds the model on the global graph `g` with constant weight nodes:
 * two ReLU fully connected layers followed by a linear output layer.
 *
 * NOTE(review): not called anywhere in this file; `data` is unused.
 *
 * @param data - Unused.
 * @param vars - Checkpoint variables keyed by checkpoint variable name.
 * @returns {Array} [input placeholder tensor, g.argmax of the logits].
 */
function buildModelGraphAPI(data, vars) {
    var input = g.placeholder('input', [784]);
    // Adds constant nodes for one layer's weights and biases and returns
    // x * W + b. Nodes are created in the same order as before.
    var affine = function(x, weightKey, biasKey) {
        var weightNode = g.constant(vars[weightKey]);
        var biasNode = g.constant(vars[biasKey]);
        return g.add(g.matmul(x, weightNode), biasNode);
    };
    var h1 = g.relu(affine(input, 'hidden1/weights', 'hidden1/biases'));
    var h2 = g.relu(affine(h1, 'hidden2/weights', 'hidden2/biases'));
    var logits = affine(h2, 'softmax_linear/weights', 'softmax_linear/biases');
    return [input, g.argmax(logits)];
}
/**
 * Builds the model on the global graph `g` with the layers API. Each dense
 * layer's kernel and bias are initialized from the checkpoint values.
 *
 * The layers are two ReLU dense layers ('hidden1', 'hidden2') followed by a
 * linear dense layer ('softmax').
 *
 * @param data - Unused.
 * @param vars - Checkpoint variables keyed by checkpoint variable name.
 * @returns {Array} [input placeholder tensor, g.argmax of the logits].
 */
function buildModelLayersAPI(data, vars) {
    var reluActivation = function(x) {
        return g.relu(x);
    };
    // Adds one dense layer (with bias) whose unit count is the number of
    // columns of `kernel`.
    var denseFromCheckpoint = function(layerName, x, kernel, bias, activation) {
        return g.layers.dense(layerName, x, kernel.shape[1], activation, true,
            new dl.NDArrayInitializer(kernel), new dl.NDArrayInitializer(bias));
    };
    var input = g.placeholder('input', [784]);
    var hidden1 = denseFromCheckpoint('hidden1', input,
        vars['hidden1/weights'], vars['hidden1/biases'], reluActivation);
    var hidden2 = denseFromCheckpoint('hidden2', hidden1,
        vars['hidden2/weights'], vars['hidden2/biases'], reluActivation);
    var logits = denseFromCheckpoint('softmax', hidden2,
        vars['softmax_linear/weights'], vars['softmax_linear/biases'], null);
    return [input, g.argmax(logits)];
}

// Drawing pad: the user draws a digit on #canvas, #run_button classifies it
// via test(), and #delete_button clears it. `context` and `getImageBuffer`
// stay global (var at top level) because test() uses them.
var canvas = $("#canvas").get(0);
// NOTE(review): no longer used by the pointer handling below, which inspects
// each event instead; kept in case other code reads this global.
var touchableDevice = ('ontouchstart' in window);
if (canvas.getContext)
{
    var context = canvas.getContext('2d');
    var drawing = false;
    var prev = {};
    // Back the canvas with twice as many pixels as its CSS size and scale the
    // context, so drawing code works in CSS-pixel coordinates.
    canvas.width = 2 * $("#canvas").width();
    canvas.height = 2 * $("#canvas").height();
    context.scale(2.0, 2.0);
    context.lineJoin = "round";
    context.lineCap = "round";
    context.lineWidth = 20;
    context.strokeStyle = 'rgb(0, 0, 0)';
    // Returns the pointer position of a jQuery mouse or touch event relative
    // to elem's top-left corner, in CSS pixels.
    var getPointOnCanvas = function(elem, e) {
        var nativeEvent = e.originalEvent || e;
        // Decide per event, not per device: touch-capable devices also fire
        // mouse events, and those have no changedTouches.
        var point = (nativeEvent.changedTouches && nativeEvent.changedTouches.length > 0)
            ? nativeEvent.changedTouches[0]
            : nativeEvent;
        // clientX/clientY are viewport coordinates, so subtract the element's
        // viewport rectangle (not its document offset); this keeps strokes
        // aligned when the page is scrolled.
        var rect = elem.getBoundingClientRect();
        return {
            x : point.clientX - rect.left,
            y : point.clientY - rect.top
        };
    };
    $("#canvas").bind('touchstart mousedown', function(e) {
        e.preventDefault();
        prev = getPointOnCanvas(this, e);
        drawing = true;
    });
    $("#canvas").bind('touchmove mousemove', function(e) {
        if (!drawing) return;
        e.preventDefault();
        var curr = getPointOnCanvas(this, e);
        context.beginPath();
        context.moveTo(prev.x, prev.y);
        context.lineTo(curr.x, curr.y);
        context.stroke();
        prev = curr;
    });
    $("#canvas").bind('touchend mouseup mouseleave', function(e) {
        drawing = false;
    });
    $("#run_button").click(function() {
        test();
    });
    $("#delete_button").click(function() {
        // The context is scaled by 2, so a canvas.width x canvas.height rect
        // in scaled coordinates covers the entire canvas, whatever its size.
        context.clearRect(0, 0, canvas.width, canvas.height);
    });
    // Downsamples the drawing to width x height pixels and returns one value
    // per pixel: r + g + b + alpha, capped at 255. Strokes are black, so the
    // value is effectively the pixel's alpha (0 for undrawn pixels).
    // NOTE(review): values are in 0..255; confirm this matches the input
    // scale the model expects.
    var getImageBuffer = function(context, width, height) {
        var tmpCanvas = $('<canvas>').get(0);
        tmpCanvas.width = width;
        tmpCanvas.height = height;
        var tmpContext = tmpCanvas.getContext('2d');
        tmpContext.drawImage(context.canvas, 0, 0, width, height);
        var image = tmpContext.getImageData(0, 0, width, height);
        var buffer = [];
        for (var i = 0; i < image.data.length; i += 4)
        {
            var sum = image.data[i + 0] + image.data[i + 1] + image.data[i + 2] + image.data[i + 3];
            buffer.push(Math.min(sum, 255));
        }
        return buffer;
    };
}
// Classifies the digit drawn on the canvas and shows the predicted digit in
// an alert. The model graph and its Session are built on first use and then
// reused: rebuilding them on every click kept adding duplicate layers to the
// global graph `g`.
var test = (function() {
    var model = null;
    return function() {
        // vars2 is only set once the checkpoint has finished loading.
        if (vars2 == null)
        {
            alert('Model weights are still loading. Please try again shortly.');
            return;
        }
        var p = getImageBuffer(context, 28, 28);
        if (model === null)
        {
            var built = buildModelLayersAPI(p, vars2);
            model = {
                input: built[0],
                probs: built[1],
                session: new dl.Session(built[0].node.graph, math)
            };
        }
        math.scope(function() {
            var inputData = dl.Array1D.new(p);
            var probsVal = model.session.eval(model.probs, [{
                tensor: model.input,
                data: inputData
            }]);
            alert(probsVal.get());
        });
    };
})();




成果物

http://jsdo.it/ohisama1/kHXF