#概要
jsdoに、keras.jsで学習したグラフを使おうとしたが、だめだった。
だめなやつ
deeplearn.jsに、鞍替えするため、kerasのモデルをtensorflowに、変換した。
手順を記載する。
#環境
windows 7 sp1 64bit
anaconda3
tensorflow 1.2
#kerasでモデルを学習して、セーブする。
import numpy as np
from tensorflow.contrib.keras.python.keras.models import Model
from tensorflow.contrib.keras.python.keras.layers import Input, Dense, Activation
from tensorflow.contrib.keras.python.keras.optimizers import SGD
import matplotlib.pyplot as plt
x = np.arange(200).reshape(-1, 1) / 30
y = np.sin(x)
inputs = Input(shape = (1, ))
m = Dense(30)(inputs)
m = Activation('sigmoid')(m)
m = Dense(10)(m)
m = Activation('sigmoid')(m)
m = Dense(1)(m)
model = Model(inputs, m)
sgd = SGD(lr = 0.1)
model.compile(loss = 'mean_squared_error', optimizer = sgd)
model.summary();
model.fit(x, y, epochs = 400, batch_size = 10, verbose = 0)
preds = model.predict(x)
plt.plot(x, y, 'b', x, preds, 'r--')
plt.savefig("keras13.png")
plt.show()
with open('keras13_arch.json', 'w') as f:
f.write(model.to_json())
model.save('keras13.h5')
print ("save model")
#keras_to_tensorflowで、pbファイルに変換する。
#pbファイルを、読んで、inputとoutputを確認。
import os
import os.path
import tensorflow as tf
from tensorflow.python.platform import gfile
with tf.Session() as sess:
with gfile.FastGFile("keras13.pb", 'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
_ = tf.import_graph_def(graph_def, name = '')
for op in tf.get_default_graph().get_operations():
print (op.name)
for output in op.outputs:
print (' ', output.name)
#pbファイルを、読み込んで動作確認。
import os
import os.path
import tensorflow as tf
from tensorflow.python.platform import gfile
import numpy as np
import matplotlib.pyplot as plt
x = np.arange(200).reshape(-1, 1) / 30
y = np.sin(x)
g = []
with tf.Session() as sess:
with gfile.FastGFile("keras13.pb", 'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
_ = tf.import_graph_def(graph_def, name = '')
for i in range(200):
result = sess.run('output_node0:0', feed_dict = {
'input_1:0': [x[i]]
})
print (result)
g.append(result)
plt.plot(g)
plt.savefig("predic13.png")
plt.show()
以上。