Keras Advent Calendar 2017

Day 5

keras_to_tensorflow

Posted at 2017-12-06

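#Overview
I tried to use a graph trained with Keras on jsdo through keras.js, but it did not work.
The attempt that failed
To switch to deeplearn.js instead, I converted the Keras model to TensorFlow.
The steps are described below.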

#Environment

Windows 7 SP1 64-bit
Anaconda3
TensorFlow 1.2

#Train a model with Keras and save it.

import numpy as np
import matplotlib.pyplot as plt
from tensorflow.contrib.keras.python.keras.models import Model
from tensorflow.contrib.keras.python.keras.layers import Input, Dense, Activation
from tensorflow.contrib.keras.python.keras.optimizers import SGD

x = np.arange(200).reshape(-1, 1) / 30  # 200 inputs from 0 to 199/30
y = np.sin(x)  # regression target

inputs = Input(shape=(1,))
m = Dense(30)(inputs)
m = Activation('sigmoid')(m)
m = Dense(10)(m)
m = Activation('sigmoid')(m)
m = Dense(1)(m)  # linear output
model = Model(inputs, m)
sgd = SGD(lr=0.1)
model.compile(loss='mean_squared_error', optimizer=sgd)
model.summary()
model.fit(x, y, epochs=400, batch_size=10, verbose=0)
preds = model.predict(x)
plt.plot(x, y, 'b', x, preds, 'r--')  # target in blue, prediction as red dashes
plt.savefig("keras13.png")
plt.show()
with open('keras13_arch.json', 'w') as f:
    f.write(model.to_json())  # architecture only
model.save('keras13.h5')  # architecture, weights and optimizer state
print("save model")

Figure: keras13.png (sin(x) in blue, model prediction as a red dashed line)
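To make the structure of the network above concrete, here is a minimal pure-Python sketch of the same 1 -> 30 -> 10 -> 1 layout with sigmoid hidden layers and a linear output. The helper names (`init_layer`, `dense`, `forward`, `param_count`) are made up for this illustration, and the weights are random placeholders, not the trained values stored in keras13.h5.

```python
import math
import random


def sigmoid(v):
    return 1.0 / (1.0 + math.exp(-v))


def init_layer(n_in, n_out, rng):
    # Placeholder weights (NOT the trained ones); weights[i][j] links input i to unit j.
    weights = [[rng.uniform(-1.0, 1.0) for _ in range(n_out)] for _ in range(n_in)]
    biases = [0.0] * n_out
    return weights, biases


def dense(inputs, weights, biases):
    # Fully connected layer: out[j] = sum_i inputs[i] * weights[i][j] + biases[j]
    return [
        sum(x * row[j] for x, row in zip(inputs, weights)) + biases[j]
        for j in range(len(biases))
    ]


def forward(x, layers):
    # Sigmoid after every layer except the last, which stays linear.
    h = [x]
    for idx, (weights, biases) in enumerate(layers):
        h = dense(h, weights, biases)
        if idx < len(layers) - 1:
            h = [sigmoid(v) for v in h]
    return h[0]


def param_count(layers):
    return sum(len(w) * len(w[0]) + len(b) for w, b in layers)


rng = random.Random(0)
sizes = [1, 30, 10, 1]
layers = [init_layer(a, b, rng) for a, b in zip(sizes, sizes[1:])]

print(param_count(layers))  # 60 + 310 + 11 = 381
print(forward(0.5, layers))
```

The parameter count is what `model.summary()` should report for this layout: 60 for the first Dense layer, 310 for the second, and 11 for the output layer.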

#Convert to a pb file with keras_to_tensorflow.

#Read the pb file and check the input and output names.

import tensorflow as tf
from tensorflow.python.platform import gfile

with tf.Session() as sess:
    with gfile.FastGFile("keras13.pb", 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
        tf.import_graph_def(graph_def, name='')  # name='' keeps the original op names
    for op in tf.get_default_graph().get_operations():
        print(op.name)
        for output in op.outputs:
            print('  ', output.name)  # tensor name of each op output
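The listing prints operation names and, indented below them, the names of their output tensors. In TensorFlow a tensor name is the operation name followed by `:` and the output index, which is why the session is later run with names such as `input_1:0` and `output_node0:0`. The small helper below is hypothetical (not part of TensorFlow) and only illustrates that naming convention:

```python
def split_tensor_name(tensor_name):
    """Split '<op name>:<output index>' into (op_name, index)."""
    op_name, sep, index = tensor_name.rpartition(':')
    if not sep or not index.isdigit():
        raise ValueError(
            'expected "<op name>:<output index>", got %r' % tensor_name)
    return op_name, int(index)


print(split_tensor_name('input_1:0'))
print(split_tensor_name('output_node0:0'))
```

Passing a bare operation name such as `input_1` raises `ValueError`, which mirrors the common mistake of feeding an op name where a tensor name is expected.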


#Load the pb file and verify that it works.

import tensorflow as tf
from tensorflow.python.platform import gfile
import numpy as np
import matplotlib.pyplot as plt

x = np.arange(200).reshape(-1, 1) / 30  # same inputs as used for training
y = np.sin(x)
g = []
with tf.Session() as sess:
    with gfile.FastGFile("keras13.pb", 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
        tf.import_graph_def(graph_def, name='')
    for i in range(200):
        result = sess.run('output_node0:0', feed_dict={
            'input_1:0': [x[i]]  # one sample, shape (1, 1)
        })
        print(result)
        g.append(result[0, 0])  # keep the scalar so the plot gets a flat list
    plt.plot(g)
    plt.savefig("predic13.png")
    plt.show()

Figure: predic13.png (output of the pb graph for the 200 inputs)
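The check above is visual. As a possible complement, the outputs could also be compared numerically against sin(x). The sketch below is pure Python, and its `fake_preds` list is a made-up stand-in for the values collected in `g`, not real output from the pb graph.

```python
import math


def input_grid(n=200, scale=30.0):
    # Same grid as np.arange(200) / 30 in the scripts above.
    return [i / scale for i in range(n)]


def mean_squared_error(preds, targets):
    return sum((p - t) ** 2 for p, t in zip(preds, targets)) / len(targets)


xs = input_grid()
targets = [math.sin(v) for v in xs]

# Hypothetical predictions: the target shifted by 0.01 (stand-in for `g`).
fake_preds = [t + 0.01 for t in targets]

print(mean_squared_error(fake_preds, targets))
```

With real predictions in place of `fake_preds`, a value close to the training loss would suggest that the converted graph reproduces the Keras model.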

That's all.
