LoginSignup
0
0

More than 5 years have passed since last update.

jsdoでdeeplearn.js その6

Last updated at Posted at 2017-09-26

概要

jsdoでdeeplearn.jsやってみた。
TensorFlowで学習したモデルを使ってみた。
手書き数字を推定してみた。

サンプルコード

// Location of the checkpoint manifest (parsed as JSON by loadManifest), appended to the loader's base URL.
var MANIFEST_FILE = '/assets/q/K/x/K/qKxKH';
/**
 * Loader that fetches a checkpoint manifest and the variable files it lists.
 * @param {string} urlPath - Base URL of the checkpoint; a '/' is appended when it does not already end with one.
 */
function CheckpointLoader(urlPath) {
    var endsWithSlash = urlPath.charAt(urlPath.length - 1) === '/';
    this.urlPath = endsWithSlash ? urlPath : urlPath + '/';
}
/**
 * Downloads and parses the checkpoint manifest, storing it on this.checkpointManifest.
 * @returns {Promise<void>} Resolves once the manifest is stored. Rejects on network
 *     failure, on an HTTP error status, or when the response is not valid JSON.
 */
CheckpointLoader.prototype.loadManifest = function() {
    var _this = this;
    // The constructor makes urlPath end with '/', while MANIFEST_FILE may start with '/';
    // drop the leading slashes so the request URL does not contain '//'.
    var manifestUrl = this.urlPath + MANIFEST_FILE.replace(/^\/+/, '');
    return new Promise(function(resolve, reject) {
        var xhr = new XMLHttpRequest();
        xhr.open('GET', manifestUrl);
        xhr.onload = function() {
            // HTTP error responses (404, 500, ...) still fire onload, so check the status explicitly.
            if (xhr.status >= 400) {
                reject(new Error(MANIFEST_FILE + " not found at " + _this.urlPath + ". HTTP " + xhr.status));
                return;
            }
            try {
                _this.checkpointManifest = JSON.parse(xhr.responseText);
            } catch (parseError) {
                reject(parseError);
                return;
            }
            resolve();
        };
        xhr.onerror = function(error) {
            // A throw here would escape the Promise and leave it pending forever; reject instead.
            reject(new Error(MANIFEST_FILE + " not found at " + _this.urlPath + ". " + error));
        };
        xhr.send();
    });
};
/**
 * Returns the checkpoint manifest, downloading it on first use.
 * @returns {Promise<Object>} The parsed manifest (variable name -> entry). Rejects if
 *     loadManifest() rejects, instead of leaving the caller waiting forever.
 */
CheckpointLoader.prototype.getCheckpointManifest = function() {
    var _this = this;
    if (this.checkpointManifest != null) {
        return Promise.resolve(this.checkpointManifest);
    }
    // Chain on loadManifest directly so its rejection propagates to the caller.
    return this.loadManifest().then(function() {
        return _this.checkpointManifest;
    });
};
/**
 * Fetches every variable listed in the manifest and caches the result on this.variables.
 * @returns {Promise<Object>} Map of variable name -> value produced by getVariable().
 *     Rejects if the manifest or any single variable fails to load.
 */
CheckpointLoader.prototype.getAllVariables = function() {
    var _this = this;
    if (this.variables != null) {
        return Promise.resolve(this.variables);
    }
    var variableNames;
    return this.getCheckpointManifest().then(function(manifest) {
        variableNames = Object.keys(manifest);
        // Download all variables in parallel. A synchronous throw from getVariable inside
        // this callback becomes a rejection of the returned chain.
        return Promise.all(variableNames.map(function(name) {
            return _this.getVariable(name);
        }));
    }).then(function(values) {
        var variables = {};
        for (var i = 0; i < values.length; i++) {
            variables[variableNames[i]] = values[i];
        }
        _this.variables = variables;
        return variables;
    });
};
/**
 * Downloads one variable's raw float32 data and wraps it with dl.NDArray.make.
 * Loads the manifest first if it has not been loaded yet.
 * @param {string} varName - Manifest key of the variable.
 * @returns {Promise<*>} The NDArray built from the manifest entry's shape and the file contents.
 * @throws {Error} Synchronously, when the manifest is already loaded and has no entry
 *     named varName. When the manifest still has to be loaded, the same error arrives
 *     as a rejection instead.
 */
CheckpointLoader.prototype.getVariable = function(varName) {
    var _this = this;
    // Requires this.checkpointManifest to be populated.
    var fetchVariable = function() {
        var manifest = _this.checkpointManifest;
        // hasOwnProperty rather than `in`, so inherited names such as 'toString' are rejected.
        if (!Object.prototype.hasOwnProperty.call(manifest, varName)) {
            throw new Error('Cannot load non-existant variable ' + varName);
        }
        var entry = manifest[varName];
        return new Promise(function(resolve, reject) {
            var xhr = new XMLHttpRequest();
            xhr.responseType = 'arraybuffer';
            xhr.open('GET', _this.urlPath + entry.filename);
            xhr.onload = function() {
                // HTTP error responses still fire onload, so check the status explicitly.
                if (xhr.status >= 400) {
                    reject(new Error('Could not fetch variable ' + varName + ': HTTP ' + xhr.status));
                    return;
                }
                try {
                    var values = new Float32Array(xhr.response);
                    resolve(dl.NDArray.make(entry.shape, {
                        values: values
                    }));
                } catch (conversionError) {
                    // e.g. a byte length that is not a multiple of 4, or a shape mismatch.
                    reject(conversionError);
                }
            };
            xhr.onerror = function(error) {
                // Reject instead of throwing so callers are notified of the failure.
                reject(new Error('Could not fetch variable ' + varName + ': ' + error));
            };
            xhr.send();
        });
    };
    // Check for a missing manifest before touching it, so a cold loader fetches the
    // manifest instead of failing with a TypeError.
    if (this.checkpointManifest == null) {
        return this.loadManifest().then(fetchVariable);
    }
    return fetchVariable();
};
// Shared deeplearn.js state, also used by the model builders and by test() below.
var dl = deeplearn;
var g = new dl.Graph();
var math = new dl.NDArrayMathCPU();
// Checkpoint variables; stays undefined until getAllVariables() resolves.
var vars2;
var reader = new CheckpointLoader('http://jsrun.it');
reader.getAllVariables().then(function(vars) {
    vars2 = vars;
    // Fetch the labelled evaluation set and report the model's accuracy on it.
    var xhr = new XMLHttpRequest();
    xhr.open('GET', 'http://jsrun.it/assets/O/U/w/U/OUwUs');
    xhr.onload = function() {
        // Uses data.images[i] as model input and data.labels[i] as the expected digit.
        // NOTE(review): each image presumably has 784 values to match the 'input' placeholder; confirm.
        var data = JSON.parse(xhr.responseText);
        var _a = buildModelLayersAPI(data, vars),
            input = _a[0],
            probs = _a[1];
        var sess = new dl.Session(input.node.graph, math);
        math.scope(function() {
            var numCorrect = 0;
            for (var i = 0; i < data.images.length; i++)
            {
                var inputData = dl.Array1D.new(data.images[i]);
                var probsVal = sess.eval(probs, [{
                    tensor: input,
                    data: inputData
                }]);
                if (data.labels[i] === probsVal.get())
                {
                    numCorrect++;
                }
            }
            var accuracy = numCorrect * 100 / data.images.length;
            // Plain text only, so textContent is enough.
            document.getElementById('helloWorld').textContent = accuracy + '%';
        });
    };
    xhr.onerror = function (err) {
        return alert(err);
    };
    xhr.send();
}).catch(function(err) {
    // Surface manifest/variable download failures instead of leaving an unhandled rejection.
    alert(err);
});
/**
 * Builds an eager (math-API) classifier from the checkpoint tensors.
 * Two ReLU dense layers are followed by a linear layer, and the result is the argmax of the logits.
 * @param {Object} math - deeplearn math backend providing scope/vectorTimesMatrix/add/relu/argMax.
 * @param {*} data - Unused; kept for signature parity with the other builders.
 * @param {Object} vars - Checkpoint variables keyed by TensorFlow variable name.
 * @returns {function(*): *} Predictor mapping an input vector to the index of the largest logit.
 */
function buildModelMathAPI(math, data, vars) {
    var layerStack = [
        { weights: vars['hidden1/weights'], biases: vars['hidden1/biases'], useRelu: true },
        { weights: vars['hidden2/weights'], biases: vars['hidden2/biases'], useRelu: true },
        { weights: vars['softmax_linear/weights'], biases: vars['softmax_linear/biases'], useRelu: false }
    ];
    return function(x) {
        return math.scope(function() {
            var activation = x;
            for (var layerIndex = 0; layerIndex < layerStack.length; layerIndex++) {
                var layer = layerStack[layerIndex];
                var affine = math.add(math.vectorTimesMatrix(activation, layer.weights), layer.biases);
                activation = layer.useRelu ? math.relu(affine) : affine;
            }
            // After the final (linear) layer, `activation` holds the logits.
            return math.argMax(activation);
        });
    };
}
/**
 * Adds the classifier to the shared graph `g` using constant (frozen) checkpoint tensors.
 * @param {*} data - Unused; kept for signature parity with the other builders.
 * @param {Object} vars - Checkpoint variables keyed by TensorFlow variable name.
 * @returns {Array} [inputPlaceholder, argmaxOfLogitsTensor].
 */
function buildModelGraphAPI(data, vars) {
    // Adds act(x * W + b) to `g`, creating the weight constant before the bias constant.
    function frozenDense(x, weightValues, biasValues, applyRelu) {
        var weights = g.constant(weightValues);
        var biases = g.constant(biasValues);
        var affine = g.add(g.matmul(x, weights), biases);
        return applyRelu ? g.relu(affine) : affine;
    }
    var input = g.placeholder('input', [784]);
    var hidden1 = frozenDense(input, vars['hidden1/weights'], vars['hidden1/biases'], true);
    var hidden2 = frozenDense(hidden1, vars['hidden2/weights'], vars['hidden2/biases'], true);
    var logits = frozenDense(hidden2, vars['softmax_linear/weights'], vars['softmax_linear/biases'], false);
    return [input, g.argmax(logits)];
}
/**
 * Adds the classifier to the shared graph `g` via g.layers.dense. Each layer's kernel
 * and bias are initialized from the checkpoint tensors.
 * @param {*} data - Unused; kept for signature parity with the other builders.
 * @param {Object} vars - Checkpoint variables keyed by TensorFlow variable name; weight tensors expose `.shape`.
 * @returns {Array} [inputPlaceholder, argmaxOfLogitsTensor].
 */
function buildModelLayersAPI(data, vars) {
    var reluActivation = function(x) {
        return g.relu(x);
    };
    // One dense layer; the unit count is the weight matrix's second dimension.
    var pretrainedDense = function(layerName, x, weights, biases, activation) {
        var units = weights.shape[1];
        var kernelInit = new dl.NDArrayInitializer(weights);
        var biasInit = new dl.NDArrayInitializer(biases);
        return g.layers.dense(layerName, x, units, activation, true, kernelInit, biasInit);
    };
    var input = g.placeholder('input', [784]);
    var hidden1 = pretrainedDense('hidden1', input, vars['hidden1/weights'], vars['hidden1/biases'], reluActivation);
    var hidden2 = pretrainedDense('hidden2', hidden1, vars['hidden2/weights'], vars['hidden2/biases'], reluActivation);
    // Output layer: no activation, so it yields raw logits.
    var logits = pretrainedDense('softmax', hidden2, vars['softmax_linear/weights'], vars['softmax_linear/biases'], null);
    return [input, g.argmax(logits)];
}

// Drawing surface: the user sketches a digit on #canvas with the mouse or a finger.
var canvas = $("#canvas").get(0);
// Kept for compatibility with other code; getPointOnCanvas now inspects each event for touch data.
var touchableDevice = ('ontouchstart' in window);
if (canvas && canvas.getContext)
{
    var context = canvas.getContext('2d');
    var drawing = false;
    var prev = {};
    // Use a 2x backing store and scale the context so drawing coordinates stay in CSS pixels.
    canvas.width = 2 * $(canvas).width();
    canvas.height = 2 * $(canvas).height();
    context.scale(2.0, 2.0);
    context.lineJoin = "round";
    context.lineCap = "round";
    context.lineWidth = 20;
    context.strokeStyle = 'rgb(0, 0, 0)';

    /**
     * Converts a mouse or touch event into CSS-pixel coordinates relative to elem.
     * @param {Element} elem - The canvas element.
     * @param {Event} [windowEvent] - The native DOM event (jQuery's e.originalEvent).
     * @param {Object} touchEvent - The jQuery event object.
     * @returns {{x: number, y: number}}
     */
    var getPointOnCanvas = function(elem, windowEvent, touchEvent) {
        var nativeEvent = windowEvent || (touchEvent && touchEvent.originalEvent);
        var touches = nativeEvent && nativeEvent.changedTouches;
        // Touch events carry coordinates on changedTouches; mouse events carry them directly.
        // Deciding per event also handles mouse input on touch-capable devices.
        var source = (touches && touches.length > 0) ? touches[0] : touchEvent;
        // clientX/Y are viewport-relative, so subtract the viewport-relative bounding rect
        // (jQuery's offset() is document-relative and goes wrong once the page is scrolled).
        var rect = elem.getBoundingClientRect();
        return {
            x: source.clientX - rect.left,
            y: source.clientY - rect.top
        };
    };

    $(canvas).on('touchstart mousedown', function(e) {
        e.preventDefault();
        prev = getPointOnCanvas(this, e.originalEvent, e);
        drawing = true;
    });
    $(canvas).on('touchmove mousemove', function(e) {
        if (!drawing) {
            return;
        }
        e.preventDefault();
        var curr = getPointOnCanvas(this, e.originalEvent, e);
        context.beginPath();
        context.moveTo(prev.x, prev.y);
        context.lineTo(curr.x, curr.y);
        context.stroke();
        prev = curr;
    });
    $(canvas).on('touchend mouseup mouseleave', function() {
        drawing = false;
    });
    $("#run_button").click(function() {
        test();
    });
    $("#delete_button").click(function() {
        // The context is scaled 2x, so a backing-store-sized rectangle covers the whole canvas.
        context.clearRect(0, 0, canvas.width, canvas.height);
    });

    /**
     * Downsamples the drawing to width x height and returns one value (0-255) per pixel.
     * The value is R+G+B+A clamped to 255; with black strokes on a cleared canvas this is
     * effectively the alpha channel.
     * @param {CanvasRenderingContext2D} context - Context of the drawing canvas.
     * @param {number} width - Output width in pixels.
     * @param {number} height - Output height in pixels.
     * @returns {number[]} width*height values, in ImageData pixel order.
     */
    var getImageBuffer = function(context, width, height) {
        var tmpCanvas = $('<canvas>').get(0);
        tmpCanvas.width = width;
        tmpCanvas.height = height;
        var tmpContext = tmpCanvas.getContext('2d');
        tmpContext.drawImage(context.canvas, 0, 0, width, height);
        var image = tmpContext.getImageData(0, 0, width, height);
        var buffer = [];
        for (var i = 0; i < image.data.length; i += 4)
        {
            var sum = image.data[i + 0] + image.data[i + 1] + image.data[i + 2] + image.data[i + 3];
            buffer.push(Math.min(sum, 255));
        }
        return buffer;
    };
}
/**
 * Classifies the current drawing and shows the predicted digit in an alert.
 * Builds the model into the shared graph `g` only once, on the first successful call.
 * Rebuilding on every click would keep adding duplicate layers to the graph.
 */
var test = (function() {
    var model = null;
    return function() {
        // vars2 is filled in asynchronously once the checkpoint has downloaded.
        if (vars2 == null) {
            alert('The model is still loading. Please try again in a moment.');
            return;
        }
        var p = getImageBuffer(context, 28, 28);
        if (model === null) {
            var built = buildModelLayersAPI(p, vars2);
            model = {
                input: built[0],
                probs: built[1],
                sess: new dl.Session(built[0].node.graph, math)
            };
        }
        math.scope(function() {
            var inputData = dl.Array1D.new(p);
            var probsVal = model.sess.eval(model.probs, [{
                tensor: model.input,
                data: inputData
            }]);
            alert(probsVal.get());
        });
    };
})();




成果物

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0