#概要
raspberry pi 1でtensorflow liteやってみた。
tfliteファイルを作ってみた。
kerasモデルから作ってみた。
データセットは、fizzbuzz.
#環境
tensorflow 1.12
#kerasモデルを学習してセーブする。
import numpy as np
from tensorflow.contrib.keras.api.keras.models import Sequential
from tensorflow.contrib.keras.api.keras.layers import Dense, Activation
from tensorflow.contrib.keras.api.keras.models import Model
def binary_encode(i, num_digits):
return np.array([i >> d & 1 for d in range(num_digits)])
def fizz_buzz_encode(i):
if i % 15 == 0:
return np.array([0, 0, 0, 1])
elif i % 5 == 0:
return np.array([0, 0, 1, 0])
elif i % 3 == 0:
return np.array([0, 1, 0, 0])
else:
return np.array([1, 0, 0, 0])
def fizz_buzz(i, prediction):
return [str(i), "fizz", "buzz", "fizzbuzz"][prediction]
NUM_DIGITS = 7
trX = np.array([binary_encode(i, NUM_DIGITS) for i in range(1, 101)])
trY = np.array([fizz_buzz_encode(i) for i in range(1, 101)])
model = Sequential()
model.add(Dense(64, input_dim = 7))
model.add(Activation('tanh'))
model.add(Dense(4, input_dim = 64))
model.add(Activation('softmax'))
model.compile(loss = 'categorical_crossentropy', optimizer = 'adam', metrics = ['accuracy'])
model.fit(trX, trY, epochs = 3600, batch_size = 64)
model.save('fizzbuzz5.h5')
print ("save model")
#kerasファイルからtfliteファイルを作る。
import tensorflow as tf
import tensorflow.contrib.lite as lite
converter = lite.TFLiteConverter.from_keras_model_file("fizzbuzz5.h5")
tflite_model = converter.convert()
open("fizzbuzz.tflite", "wb").write(tflite_model)
print ("ok")
#tfliteファイルを検証する。
import numpy as np
import tensorflow as tf
import tensorflow.contrib.lite as lite
def binary_encode(i, num_digits):
return np.array([i >> d & 1 for d in range(num_digits)])
def fizz_buzz_encode(i):
if i % 15 == 0:
return np.array([0, 0, 0, 1])
elif i % 5 == 0:
return np.array([0, 0, 1, 0])
elif i % 3 == 0:
return np.array([0, 1, 0, 0])
else:
return np.array([1, 0, 0, 0])
def fizz_buzz(i, prediction):
return [str(i), "fizz", "buzz", "fizzbuzz"][prediction]
NUM_DIGITS = 7
trX = np.array([binary_encode(i, NUM_DIGITS) for i in range(1, 101)])
interpreter = lite.Interpreter(model_path = "fizzbuzz.tflite")
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
#print (input_details)
#print (output_details)
for i in range(1, 100):
input_data = np.array([trX[i - 1]], dtype = np.float32)
interpreter.set_tensor(input_details[0]['index'], input_data)
interpreter.invoke()
output_data = interpreter.get_tensor(output_details[0]['index'])
print (fizz_buzz(i, np.argmax(output_data[0])))
以上。