LoginSignup
0

More than 5 years have passed since last update.

raspberry pi 1でtensorflow lite その6

Posted at

概要

raspberry pi 1でtensorflow liteやってみた。
tfliteファイルを作ってみた。
freeze(変数を定数化)した pb ファイルから tflite ファイルを作ってみた。
データセットは、xor.

環境

tensorflow 1.12

モデルを学習し、変数を定数化(freeze)した pb ファイルとして保存する。

import tensorflow as tf
import tensorflow.contrib.lite as lite
import numpy as np
from tensorflow.python.framework import graph_util

# Train a tiny 2-2-2 network on XOR and save it as a frozen GraphDef
# ("xor2_graph.pb") in which the trained variables are baked in as constants.

# XOR truth table. Labels are one-hot: [1, 0] = XOR false, [0, 1] = XOR true.
X = [[0, 0], [0, 1], [1, 0], [1, 1]]
Y = [[1, 0], [0, 1], [0, 1], [1, 0]]

# The input placeholder and the softmax op get explicit names ("input",
# "output") so they can be referenced by name after the graph is frozen.
x = tf.placeholder(tf.float32, shape = [None, 2], name = "input")
y = tf.placeholder(tf.float32, shape = [None, 2])
w1 = tf.Variable(tf.random_uniform([2, 2], -1, 1, seed = 0))
w2 = tf.Variable(tf.random_uniform([2, 2], -1, 1, seed = 0))
b1 = tf.Variable(tf.zeros([2]))
b2 = tf.Variable(tf.zeros([2]))
h1 = tf.sigmoid(tf.matmul(x, w1) + b1)
h2 = tf.nn.softmax(tf.matmul(h1, w2) + b2, name = "output")

# Cross-entropy between the one-hot labels and the softmax output,
# summed over the whole batch.
cost = -tf.reduce_sum(y * tf.log(h2))
opti = tf.train.GradientDescentOptimizer(0.1).minimize(cost)

with tf.Session() as sess:
  # tf.initialize_all_variables() is deprecated; this is its replacement.
  sess.run(tf.global_variables_initializer())

  # Full-batch gradient descent over the four XOR samples.
  for _ in range(10000):
    sess.run(opti, feed_dict = {
      x: X,
      y: Y
    })

  # Show the trained network's prediction for every input combination.
  for sample in [[1, 1], [1, 0], [0, 1], [0, 0]]:
    print (sample, sess.run(h2, feed_dict = {
      x: [sample],
    }))

  # Replace variables with constants. Only the output node has to be listed:
  # everything it depends on (including "input") is kept automatically.
  output_graph_def = graph_util.convert_variables_to_constants(sess, sess.graph_def, ["output"])
  with tf.gfile.GFile("xor2_graph.pb", "wb") as f:
    f.write(output_graph_def.SerializeToString())

pbファイルからtfliteファイルに変換する。

# Convert the frozen GraphDef into a TensorFlow Lite flatbuffer.
# The array names must match the node names given when the graph was built.
graph_def_file = "xor2_graph.pb"
input_arrays = ["input"]
output_arrays = ["output"]
converter = lite.TFLiteConverter.from_frozen_graph(graph_def_file, input_arrays, output_arrays)
tflite_model = converter.convert()

# Use a context manager so the file is flushed and closed deterministically.
with open("xor2_model.tflite", "wb") as f:
  f.write(tflite_model)

tfliteファイルを用いて、検証する。

import numpy as np
import tensorflow as tf
import tensorflow.contrib.lite as lite

# Load the converted model with the TFLite interpreter and run it on all four
# XOR input combinations, printing each input followed by the model output.
interpreter = lite.Interpreter(model_path = "xor2_model.tflite")
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
print (input_details)
print (output_details)

# Tensor indices are fixed once tensors are allocated, so look them up once.
input_index = input_details[0]['index']
output_index = output_details[0]['index']

for sample in ([0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]):
  # Each sample is fed as a float32 batch of size 1.
  input_data = np.array([sample], dtype = np.float32)
  print(input_data)
  interpreter.set_tensor(input_index, input_data)
  interpreter.invoke()
  output_data = interpreter.get_tensor(output_index)
  print (output_data)

以上。

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0