LoginSignup
2
0

More than 5 years have passed since last update.

keras_to_tensorflow その2

Last updated at Posted at 2017-12-09

概要

kerasのモデルをtensorflowに、変換して、deeplearn.jsで、使ってみた。
手順を記載する。

環境

windows 7 sp1 64bit
anaconda3
tensorflow 1.2

kerasのモデルをtensorflowに変換したpbファイルから、manifest.jsonとvars(変数)のファイルをつくるコード。

import os
import os.path
import tensorflow as tf
from tensorflow.python.platform import gfile
import string
import json

FILENAME_CHARS = string.ascii_letters + string.digits + '_'
def _var_name_to_filename(var_name):
    chars = []
    for c in var_name:
        if c in FILENAME_CHARS:
            chars.append(c)
        elif c == '/':
            chars.append('_')
    return ''.join(chars)

# Export every Const tensor of the frozen graph "keras13.pb" that carries raw
# bytes in `tensor_content` as one binary file under `output_dir`, and describe
# the files (file name + shape) in `output_dir`/manifest.json.
chk_fpath = "./"  # not read anywhere in this script; kept as-is
output_dir = "./deep"
tf.gfile.MakeDirs(output_dir)
manifest = {}
var_filenames_strs = []

# Read and parse the serialized GraphDef; the file handle is only needed here.
with gfile.FastGFile("keras13.pb", 'rb') as pb_file:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(pb_file.read())
_ = tf.import_graph_def(graph_def, name = '')

for n0 in graph_def.node:
    print (n0.op)
    if n0.op != 'Const':
        continue
    print (n0.name)
    tensor = n0.attr['value'].tensor
    size = len(tensor.tensor_content)
    print (size)
    # NOTE(review): Const nodes with an empty `tensor_content` are skipped.
    # TensorFlow may store small tensors in typed fields (e.g. `float_val`)
    # instead, which is presumably why no dense_3 bias file is produced --
    # confirm before relying on the export being complete.
    if size == 0:
        continue
    name = n0.name
    var_filename = _var_name_to_filename(name)
    manifest[name] = {
        'filename': var_filename,
        # Record every dimension, so tensors of any rank (not only 1-D and
        # 2-D) get their full shape and a rank-0 tensor no longer raises.
        'shape': [dim.size for dim in tensor.tensor_shape.dim]
    }
    print ('Writing variable ' + name + '...')
    with open(os.path.join(output_dir, var_filename), 'wb') as var_file:
        var_file.write(tensor.tensor_content)
    var_filenames_strs.append("\"" + var_filename + "\"")

manifest_fpath = os.path.join(output_dir, 'manifest.json')
print ('Writing manifest to ' + manifest_fpath)
with open(manifest_fpath, 'w') as manifest_file:
    manifest_file.write(json.dumps(manifest, indent = 2, sort_keys = True))
print ("ok")



jsdoにファイルアップロード

以下をアップする。

dense_1_bias
dense_1_kernel
dense_2_bias
dense_2_kernel
dense_3_kernel
manifest.json

モデルを作る。

kerasのモデル

# Fully connected network: 1 input value -> 30 units -> 10 units -> 1 output.
inputs = Input(shape = (1, ))
# First hidden layer: 30 units followed by a sigmoid activation.
m = Dense(30)(inputs)
m = Activation('sigmoid')(m)
# Second hidden layer: 10 units followed by a sigmoid activation.
m = Dense(10)(m)
m = Activation('sigmoid')(m)
# Output layer: a single unit with no Activation layer after it.
m = Dense(1)(m)
model = Model(inputs, m)
# Compile with SGD (learning rate 0.1) and mean squared error as the loss.
sgd = SGD(lr = 0.1)
model.compile(loss = 'mean_squared_error', optimizer = sgd)
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_1 (InputLayer)         (None, 1)                 0         
_________________________________________________________________
dense_1 (Dense)              (None, 30)                60        
_________________________________________________________________
activation_1 (Activation)    (None, 30)                0         
_________________________________________________________________
dense_2 (Dense)              (None, 10)                310       
_________________________________________________________________
activation_2 (Activation)    (None, 10)                0         
_________________________________________________________________
dense_3 (Dense)              (None, 1)                 11        
=================================================================
Total params: 381
Trainable params: 381
Non-trainable params: 0
_________________________________________________________________

deeplearn.jsのモデル

    // Excerpt of the graph-building function body: rebuilds the dense layers
    // from the exported variables, keyed by their original graph node names.
    var input = g.placeholder('input', [1]);
    // Layer 1: sigmoid(input * kernel + bias)
    var hidden1W = g.constant(vars['dense_1/kernel']);
    var hidden1B = g.constant(vars['dense_1/bias']);
    var hidden1 = g.sigmoid(g.add(g.matmul(input, hidden1W), hidden1B));
    // Layer 2: sigmoid(hidden1 * kernel + bias)
    var hidden2W = g.constant(vars['dense_2/kernel']);
    var hidden2B = g.constant(vars['dense_2/bias']);
    var hidden2 = g.sigmoid(g.add(g.matmul(hidden1, hidden2W), hidden2B));
    // Layer 3: only the kernel is applied.
    // NOTE(review): hidden3B is created but never added to the output -- confirm
    // whether the dense_3 bias should be applied here.
    var hidden3W = g.constant(vars['dense_3/kernel']);
    var hidden3B = g.constant(vars['dense_3/bias']);
    var logits = g.matmul(hidden2, hidden3W);
    // Returns the input placeholder together with the output tensor.
    return [input, logits];

写真

image.png

サンプルコード

deeplearnjs

/**
 * Loader for a manifest and the variable files it lists, all fetched
 * relative to a base URL.
 *
 * @param {string} urlPath Base URL; stored on `this.urlPath` with a trailing
 *     '/' appended when it does not already end with one.
 */
function CheckpointLoader(urlPath) {
    var endsWithSlash = urlPath.charAt(urlPath.length - 1) === '/';
    this.urlPath = endsWithSlash ? urlPath : urlPath + '/';
}
/**
 * Fetches the manifest (urlPath + MANIFEST_FILE) and stores the parsed JSON
 * on `this.checkpointManifest`.
 *
 * @return {Promise} Resolves (with no value) once the manifest is stored.
 *     Rejects when the request fails or the response is not valid JSON, so
 *     callers are no longer left with a promise that never settles.
 */
CheckpointLoader.prototype.loadManifest = function() {
    var _this = this;
    return new Promise(function(resolve, reject) {
        var xhr = new XMLHttpRequest();
        xhr.open('GET', _this.urlPath + MANIFEST_FILE);
        xhr.onload = function() {
            try {
                _this.checkpointManifest = JSON.parse(xhr.responseText);
            } catch (parseError) {
                // A throw inside the event handler would leave the promise pending.
                reject(parseError);
                return;
            }
            resolve();
        };
        xhr.onerror = function(error) {
            alert(MANIFEST_FILE + " not found at " + _this.urlPath + ". " + error);
            reject(new Error(MANIFEST_FILE + " not found at " + _this.urlPath));
        };
        xhr.send();
    });
};
/**
 * Returns the manifest, loading it first when it has not been fetched yet.
 *
 * @return {Promise<Object>} Resolves with `this.checkpointManifest`. Chains
 *     directly on loadManifest() so a rejection there reaches the caller
 *     instead of being dropped by an extra promise wrapper.
 */
CheckpointLoader.prototype.getCheckpointManifest = function() {
    var _this = this;
    if (this.checkpointManifest == null)
    {
        return this.loadManifest().then(function() {
            return _this.checkpointManifest;
        });
    }
    return Promise.resolve(this.checkpointManifest);
};
/**
 * Loads every variable listed in the manifest.
 *
 * The result is cached on `this.variables`; later calls resolve with the
 * cached object without issuing new requests.
 *
 * @return {Promise<Object>} Resolves with an object mapping each manifest key
 *     to the value getVariable() resolved with for that key. Rejects if the
 *     manifest or any variable fails to load (the promises are chained, so
 *     failures propagate instead of leaving the result pending).
 */
CheckpointLoader.prototype.getAllVariables = function() {
    var _this = this;
    if (this.variables != null)
    {
        return Promise.resolve(this.variables);
    }
    return this.getCheckpointManifest().then(function(manifest) {
        var variableNames = Object.keys(manifest);
        var variablePromises = variableNames.map(function(name) {
            return _this.getVariable(name);
        });
        return Promise.all(variablePromises).then(function(variables) {
            var loaded = {};
            // Promise.all keeps input order, so index i pairs name and value.
            for (var i = 0; i < variables.length; i++)
            {
                loaded[variableNames[i]] = variables[i];
            }
            _this.variables = loaded;
            return loaded;
        });
    });
};
/**
 * Fetches one variable file and wraps its contents in an NDArray.
 *
 * The manifest is loaded first when it is not available yet. The existence
 * check runs only after that, because applying `in` to a missing manifest
 * throws a TypeError and made the lazy-load path unreachable.
 *
 * @param {string} varName Key of the variable in the manifest.
 * @return {Promise} Resolves with the result of dl.NDArray.make() for the
 *     manifest entry's shape and the file's Float32 values. Rejects when the
 *     variable is not in the manifest, the request fails, or the response
 *     cannot be turned into an NDArray.
 */
CheckpointLoader.prototype.getVariable = function(varName) {
    var _this = this;
    var variableRequestPromiseMethod = function(resolve, reject) {
        if (!(varName in _this.checkpointManifest))
        {
            alert('Cannot load non-existant variable ' + varName);
            reject(new Error('Cannot load non-existant variable ' + varName));
            return;
        }
        var entry = _this.checkpointManifest[varName];
        var xhr = new XMLHttpRequest();
        xhr.open('GET', _this.urlPath + entry.filename);
        // Set after open(): some older browsers reject responseType before it.
        xhr.responseType = 'arraybuffer';
        xhr.onload = function() {
            try {
                var values = new Float32Array(xhr.response);
                var ndarray = dl.NDArray.make(entry.shape, {
                    values: values
                });
                resolve(ndarray);
            } catch (buildError) {
                // A throw inside the event handler would leave the promise pending.
                reject(buildError);
            }
        };
        xhr.onerror = function(error) {
            alert('Could not fetch variable ' + varName + ': ' + error);
            reject(new Error('Could not fetch variable ' + varName));
        };
        xhr.send();
    };
    if (this.checkpointManifest == null)
    {
        return this.loadManifest().then(function() {
            return new Promise(variableRequestPromiseMethod);
        });
    }
    return new Promise(variableRequestPromiseMethod);
};


// Manifest location; CheckpointLoader appends it to its urlPath when fetching.
// NOTE(review): urlPath always ends with '/' and this value starts with '/',
// so the request URL contains '//' -- confirm the host accepts that.
var MANIFEST_FILE = '/assets/g/m/b/V/gmbVz';
// Short alias for the global `deeplearn` object.
var dl = deeplearn;
// Graph that buildModelGraphAPI() adds its nodes to.
var g = new dl.Graph();
// Math backend passed to dl.Session and used for math.scope().
var math = new dl.NDArrayMathCPU();
// Set to the loaded variables once getAllVariables() resolves; it is not read
// anywhere in this listing.
var vars2;
/**
 * Adds the three dense layers to the global graph `g`.
 *
 * @param {Object} vars Loaded variables keyed by name ('dense_1/kernel',
 *     'dense_1/bias', ...).
 * @return {Array} Two elements: the input placeholder and the output tensor.
 */
function buildModelGraphAPI(vars) {
    var x = g.placeholder('input', [1]);

    // Layer 1: sigmoid(x * W1 + b1)
    var w1 = g.constant(vars['dense_1/kernel']);
    var b1 = g.constant(vars['dense_1/bias']);
    var pre1 = g.add(g.matmul(x, w1), b1);
    var act1 = g.sigmoid(pre1);

    // Layer 2: sigmoid(act1 * W2 + b2)
    var w2 = g.constant(vars['dense_2/kernel']);
    var b2 = g.constant(vars['dense_2/bias']);
    var pre2 = g.add(g.matmul(act1, w2), b2);
    var act2 = g.sigmoid(pre2);

    // Layer 3: only the kernel is applied. The bias constant is still created
    // (and left unused) so the same graph calls are made in the same order.
    var w3 = g.constant(vars['dense_3/kernel']);
    var b3 = g.constant(vars['dense_3/bias']);
    var output = g.matmul(act2, w3);

    return [x, output];
}
// Load all variables from http://jsrun.it, build the model graph, evaluate it
// for 200 inputs (t / 30, t = 0..199) and plot the outputs on the canvas.
var reader = new CheckpointLoader('http://jsrun.it');
reader.getAllVariables().then(function(vars) {
    vars2 = vars;
    var _a = buildModelGraphAPI(vars);
    var input = _a[0];
    var probs = _a[1];
    var sess = new dl.Session(input.node.graph, math);
    math.scope(function() {
        var t;
        var s = 200;
        var sin = new Float32Array(s);
        for (t = 0; t < s; t++)
        {
            var data = [t / 30];
            var inputData = dl.Array1D.new(data);
            var probsVal = sess.eval(probs, [{
                tensor: input,
                data: inputData
            }]);
            // Take the single output value explicitly. Storing the whole
            // getValues() result into one slot only worked through implicit
            // array -> string -> number coercion.
            // NOTE(review): assumes getValues() returns an array-like with
            // exactly one element for this model -- confirm.
            sin[t] = probsVal.getValues()[0];
        }
        draw(sin, 0);
    });
});

var canvas = document.getElementById("canvas");
var ctx = canvas.getContext("2d");
/**
 * Plots `data` as a red polyline on the canvas, one sample per pixel column.
 *
 * @param {ArrayLike<number>} data Samples to plot; each is scaled by 30 and
 *     subtracted from the baseline, so positive values are drawn above it.
 * @param {number} n Row index; the baseline is at y = n * 100 + 150.
 */
function draw(data, n) {
    var hc = n * 100 + 150;
    // Stop at the end of the data as well as at the canvas edge, so the loop
    // never reads undefined samples.
    var count = Math.min(canvas.width, data.length);
    ctx.strokeStyle = "#f00";
    ctx.lineWidth = 1;
    // Start a fresh path so repeated calls do not re-stroke earlier curves.
    ctx.beginPath();
    // The line starts on the baseline at x = 0; data[0] itself is not plotted.
    ctx.moveTo(0, hc);
    for (var i = 1; i < count; i++)
    {
        ctx.lineTo(i, hc - data[i] * 30);
    }
    ctx.stroke();
}

成果物

以上。

2
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
0