# 概要
kerasのモデルをtensorflowに変換して、deeplearn.jsで使ってみた。
手順を記載する。
# 環境
windows 7 sp1 64bit
anaconda3
tensorflow 1.2
# kerasのモデルをtensorflowに変換したpbファイルから、manifest.jsonと変数(vars)のファイルを作るコード
import os
import os.path
import tensorflow as tf
from tensorflow.python.platform import gfile
import string
import json
FILENAME_CHARS = string.ascii_letters + string.digits + '_'
def _var_name_to_filename(var_name):
chars = []
for c in var_name:
if c in FILENAME_CHARS:
chars.append(c)
elif c == '/':
chars.append('_')
return ''.join(chars)
# Export every float32 Const node of the frozen graph as a raw little-endian
# float32 file in output_dir, plus manifest.json mapping
# node name -> {'filename': ..., 'shape': [...]} for the deeplearn.js loader.
from tensorflow.python.framework import tensor_util

output_dir = "./deep"
tf.gfile.MakeDirs(output_dir)
manifest = {}
with gfile.FastGFile("keras13.pb", 'rb') as pb_file:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(pb_file.read())
_ = tf.import_graph_def(graph_def, name='')
for node in graph_def.node:
    if node.op != 'Const':
        continue
    tensor = node.attr['value'].tensor
    # The browser side decodes every file as a Float32Array, so constants of
    # any other dtype would be read back as garbage; skip them.
    if tensor.dtype != tf.float32.as_datatype_enum:
        continue
    # MakeNdarray decodes both tensor_content and the typed value fields
    # (float_val, ...). TF may store small (e.g. single-element) constants in
    # float_val with an empty tensor_content; reading tensor_content alone
    # dropped those (likely why dense_3/bias was missing).
    values = tensor_util.MakeNdarray(tensor).astype('<f4')
    if values.size == 0:
        continue
    name = node.name
    var_filename = _var_name_to_filename(name)
    manifest[name] = {
        'filename': var_filename,
        # Full shape of any rank (previously truncated to two dimensions).
        'shape': list(values.shape),
    }
    print('Writing variable ' + name + '...')
    with open(os.path.join(output_dir, var_filename), 'wb') as var_file:
        var_file.write(values.tobytes())
manifest_fpath = os.path.join(output_dir, 'manifest.json')
print('Writing manifest to ' + manifest_fpath)
with open(manifest_fpath, 'w') as manifest_file:
    manifest_file.write(json.dumps(manifest, indent=2, sort_keys=True))
print("ok")
# jsdo.itにファイルをアップロード
以下をアップする。
dense_1_bias
dense_1_kernel
dense_2_bias
dense_2_kernel
dense_3_kernel
manifest.json
# モデルを作る
kerasのモデル
# 1 input -> Dense(30) -> sigmoid -> Dense(10) -> sigmoid -> Dense(1),
# trained with plain SGD on mean squared error.
inputs = Input(shape=(1,))
x = Dense(30)(inputs)
x = Activation('sigmoid')(x)
x = Dense(10)(x)
x = Activation('sigmoid')(x)
outputs = Dense(1)(x)
model = Model(inputs, outputs)
sgd = SGD(lr=0.1)
model.compile(loss='mean_squared_error', optimizer=sgd)
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
input_1 (InputLayer) (None, 1) 0
_________________________________________________________________
dense_1 (Dense) (None, 30) 60
_________________________________________________________________
activation_1 (Activation) (None, 30) 0
_________________________________________________________________
dense_2 (Dense) (None, 10) 310
_________________________________________________________________
activation_2 (Activation) (None, 10) 0
_________________________________________________________________
dense_3 (Dense) (None, 1) 11
=================================================================
Total params: 381
Trainable params: 381
Non-trainable params: 0
_________________________________________________________________
deeplearn.jsのモデル
var input = g.placeholder('input', [1]);
var hidden1W = g.constant(vars['dense_1/kernel']);
var hidden1B = g.constant(vars['dense_1/bias']);
var hidden1 = g.sigmoid(g.add(g.matmul(input, hidden1W), hidden1B));
var hidden2W = g.constant(vars['dense_2/kernel']);
var hidden2B = g.constant(vars['dense_2/bias']);
var hidden2 = g.sigmoid(g.add(g.matmul(hidden1, hidden2W), hidden2B));
var hidden3W = g.constant(vars['dense_3/kernel']);
var hidden3B = g.constant(vars['dense_3/bias']);
var logits = g.matmul(hidden2, hidden3W);
return [input, logits];
# 写真
# サンプルコード
deeplearnjs
/**
 * Loads a checkpoint (a manifest plus one file per variable) over HTTP.
 *
 * @param {string} urlPath Base URL of the checkpoint; a trailing '/' is
 *     appended when it is missing.
 */
function CheckpointLoader(urlPath) {
  var base = urlPath;
  if (!base.endsWith('/')) {
    base += '/';
  }
  this.urlPath = base;
}
/**
 * Fetches urlPath + MANIFEST_FILE, parses it as JSON and caches the result
 * on this.checkpointManifest.
 *
 * @returns {Promise<void>} Resolves once the manifest is cached; rejects on a
 *     network error, a non-2xx HTTP status, or a body that is not valid JSON.
 */
CheckpointLoader.prototype.loadManifest = function() {
  var _this = this;
  return new Promise(function(resolve, reject) {
    var xhr = new XMLHttpRequest();
    // Keep the user-facing alert, but also settle the promise so callers
    // are not left waiting forever.
    var notFound = function(reason) {
      alert(MANIFEST_FILE + " not found at " + _this.urlPath + ". " + reason);
      reject(new Error('Manifest request failed: ' + reason));
    };
    xhr.open('GET', _this.urlPath + MANIFEST_FILE);
    xhr.onload = function() {
      // onload also fires for HTTP errors (e.g. a 404 page), so check the
      // status before parsing the body.
      if (xhr.status < 200 || xhr.status >= 300) {
        notFound('HTTP ' + xhr.status);
        return;
      }
      try {
        _this.checkpointManifest = JSON.parse(xhr.responseText);
      } catch (err) {
        reject(err);
        return;
      }
      resolve();
    };
    xhr.onerror = function(error) {
      notFound(error);
    };
    xhr.send();
  });
};
/**
 * Returns the checkpoint manifest, loading it on first use.
 *
 * @returns {Promise<Object>} Resolves with this.checkpointManifest; rejects
 *     if loading the manifest fails.
 */
CheckpointLoader.prototype.getCheckpointManifest = function() {
  var _this = this;
  if (this.checkpointManifest != null) {
    return Promise.resolve(this.checkpointManifest);
  }
  // Chain rather than wrap in `new Promise` so load failures propagate.
  return this.loadManifest().then(function() {
    return _this.checkpointManifest;
  });
};
/**
 * Loads every variable listed in the manifest (in parallel) and caches them.
 *
 * @returns {Promise<Object>} Resolves with an object mapping each manifest
 *     key to the NDArray returned by getVariable; rejects if the manifest or
 *     any variable fails to load.
 */
CheckpointLoader.prototype.getAllVariables = function() {
  var _this = this;
  if (this.variables != null) {
    return Promise.resolve(this.variables);
  }
  return this.getCheckpointManifest().then(function(manifest) {
    var variableNames = Object.keys(manifest);
    var requests = variableNames.map(function(name) {
      return _this.getVariable(name);
    });
    return Promise.all(requests).then(function(values) {
      _this.variables = {};
      values.forEach(function(value, i) {
        _this.variables[variableNames[i]] = value;
      });
      return _this.variables;
    });
  });
};
/**
 * Fetches one variable file and wraps it in a deeplearn.js NDArray using the
 * shape recorded in the manifest. Loads the manifest first if needed.
 *
 * @param {string} varName Manifest key, e.g. 'dense_1/kernel'.
 * @returns {Promise<NDArray>} Rejects if the variable is not in the
 *     manifest, the request fails, or the data does not fit the shape.
 */
CheckpointLoader.prototype.getVariable = function(varName) {
  var _this = this;
  var fetchVariable = function() {
    // Checked only after the manifest is available; checking earlier threw a
    // TypeError when the manifest had not been loaded yet.
    if (!(varName in _this.checkpointManifest)) {
      alert('Cannot load non-existant variable ' + varName);
      return Promise.reject(new Error('Unknown variable ' + varName));
    }
    var entry = _this.checkpointManifest[varName];
    return new Promise(function(resolve, reject) {
      var xhr = new XMLHttpRequest();
      xhr.responseType = 'arraybuffer';
      xhr.open('GET', _this.urlPath + entry.filename);
      xhr.onload = function() {
        if (xhr.status < 200 || xhr.status >= 300) {
          alert('Could not fetch variable ' + varName + ': HTTP ' + xhr.status);
          reject(new Error('Could not fetch variable ' + varName + ': HTTP ' + xhr.status));
          return;
        }
        try {
          // The file body is decoded as raw float32 values.
          var values = new Float32Array(xhr.response);
          resolve(dl.NDArray.make(entry.shape, {
            values: values
          }));
        } catch (err) {
          reject(err);
        }
      };
      xhr.onerror = function(error) {
        alert('Could not fetch variable ' + varName + ': ' + error);
        reject(new Error('Could not fetch variable ' + varName));
      };
      xhr.send();
    });
  };
  if (this.checkpointManifest == null) {
    return this.loadManifest().then(fetchVariable);
  }
  return fetchVariable();
};
// Manifest path appended to the loader's base URL (presumably the uploaded
// manifest.json asset on jsdo.it). NOTE(review): the base URL already ends in
// '/', so the request URL contains '//' — apparently tolerated by the server.
var MANIFEST_FILE = '/assets/g/m/b/V/gmbVz';
// Alias for the deeplearn.js global namespace.
var dl = deeplearn;
// Graph that buildModelGraphAPI adds its nodes to.
var g = new dl.Graph();
// CPU math backend passed to dl.Session and used for math.scope.
var math = new dl.NDArrayMathCPU();
// Holds the loaded variables once available; not read elsewhere in this
// file (presumably kept for inspection from the console).
var vars2;
/**
 * Builds the deeplearn.js graph that mirrors the Keras model
 * (1 -> Dense(30) -> sigmoid -> Dense(10) -> sigmoid -> Dense(1)).
 *
 * @param {Object} vars Map from TF variable names ('dense_1/kernel', ...) to
 *     NDArrays, as produced by CheckpointLoader#getAllVariables.
 * @returns {Array} [input placeholder tensor, output tensor].
 */
function buildModelGraphAPI(vars) {
  var input = g.placeholder('input', [1]);
  var hidden1W = g.constant(vars['dense_1/kernel']);
  var hidden1B = g.constant(vars['dense_1/bias']);
  var hidden1 = g.sigmoid(g.add(g.matmul(input, hidden1W), hidden1B));
  var hidden2W = g.constant(vars['dense_2/kernel']);
  var hidden2B = g.constant(vars['dense_2/bias']);
  var hidden2 = g.sigmoid(g.add(g.matmul(hidden1, hidden2W), hidden2B));
  var hidden3W = g.constant(vars['dense_3/kernel']);
  var logits = g.matmul(hidden2, hidden3W);
  // The final Keras Dense(1) layer has a bias (11 params = 10 weights + 1
  // bias). Some exported manifests may lack it, so add it only when present.
  if (vars['dense_3/bias'] != null) {
    logits = g.add(logits, g.constant(vars['dense_3/bias']));
  }
  return [input, logits];
}
// Load the weights, build the graph, evaluate the network at x = t / 30 for
// t = 0..199 and plot the outputs on row 0 of the canvas.
var reader = new CheckpointLoader('http://jsrun.it');
reader.getAllVariables().then(function(vars) {
  vars2 = vars;
  var built = buildModelGraphAPI(vars);
  var input = built[0];
  var output = built[1];
  var sess = new dl.Session(input.node.graph, math);
  math.scope(function() {
    var s = 200;
    var ys = new Float32Array(s);
    for (var t = 0; t < s; t++) {
      var inputData = dl.Array1D.new([t / 30]);
      var outputVal = sess.eval(output, [{
        tensor: input,
        data: inputData
      }]);
      // getValues() returns a typed array; take its single element explicitly
      // instead of relying on array-to-number coercion.
      ys[t] = outputVal.getValues()[0];
    }
    draw(ys, 0);
  });
}).catch(function(err) {
  // Surface loading/evaluation failures instead of dropping them silently.
  console.error(err);
});
// Drawing surface and its 2D context, shared as globals with draw().
var canvas = document.getElementById('canvas');
var ctx = canvas.getContext('2d');
/**
 * Plots `data` as a red polyline on the global 2D context `ctx`.
 *
 * @param {ArrayLike<number>} data Samples; sample i is drawn at x = i.
 * @param {number} n Row index; the baseline is at y = n * 100 + 150 and each
 *     unit of data moves the curve 30 px up.
 */
function draw(data, n) {
  var hc = n * 100 + 150;
  // Stop at whichever ends first, the canvas or the data (samples past the
  // end of data would be NaN).
  var count = Math.min(canvas.width, data.length);
  ctx.strokeStyle = "#f00";
  ctx.lineWidth = 1;
  // Fresh path so repeated calls do not re-stroke earlier curves.
  ctx.beginPath();
  // Start at the first sample (not the baseline) so data[0] is plotted.
  ctx.moveTo(0, hc - data[0] * 30);
  for (var i = 1; i < count; i++) {
    ctx.lineTo(i, hc - data[i] * 30);
  }
  ctx.stroke();
}
# 成果物
http://jsdo.it/ohisama1/uxpK
以上。