LoginSignup
0

More than 5 years have passed since last update.

raspberry pi 1でtensorflow lite その15

Posted at

概要

raspberry pi 1でtensorflow liteやってみた。
マイク入力を扱うので、alsaを叩いてみた。
入力をそのまま出力する（恒等写像の）モデルを学習してみた。

データセットの作成。

import numpy as np

# Build a toy dataset for an "identity" model: a sine tone sampled at 44.1 kHz,
# duplicated into two input channels (x0, x1), with the same signal as the
# target (y). Saved to sin.csv with columns x0, x1, y.
sample_rate = 44100.
nsamples = 320
frequency = 440.  # tone frequency in Hz
t = np.arange(nsamples) / sample_rate
vin = np.sin(2 * np.pi * frequency * t)
# Two identical input channels, shape (nsamples, 2). Vectorised instead of a
# per-sample Python loop.
vsa = np.column_stack((vin, vin))
data = np.c_[vsa[:, 0], vsa[:, 1], vin]
# savetxt writes the header as a "# "-prefixed comment line, so np.loadtxt
# skips it when the file is read back.
np.savetxt('sin.csv', data, delimiter=',', header="x0,x1,y")
print("ok")

モデルを学習してtfliteファイルを作る。

import numpy as np
import tensorflow as tf
import tensorflow.contrib.lite as lite
import matplotlib.pyplot as plt

# Train a small 2 -> 8 -> 1 tanh network on sin.csv (inputs x0, x1; target y),
# export it to voice1.tflite, and plot input / target / prediction.
data = np.loadtxt('sin.csv', delimiter=',', unpack=True)
x = np.c_[data[0], data[1]]  # inputs, shape (n, 2)
y = data[2]
# Target as a column vector to match the (None, 1) placeholder. It is derived
# from the loaded data rather than a hard-coded 320 or the dataset script's
# `vin`, which does not exist in this script.
y_col = y.reshape(-1, 1)
fig = plt.figure(1)
# Plots show samples 1..299 (index 0 is skipped, as in the other plots).
ax = fig.add_subplot(311)
ax.plot(data[0][1:300])
ax = fig.add_subplot(312)
ax.plot(y[1:300])
ax = fig.add_subplot(313)
x_in = tf.placeholder("float", [None, 2])
y_out = tf.placeholder("float", [None, 1])
w1 = tf.Variable(tf.random_uniform([2, 8]))
b1 = tf.Variable(tf.zeros([8]))
y1 = tf.nn.tanh(tf.matmul(x_in, w1) + b1)
w2 = tf.Variable(tf.random_uniform([8, 1]))
b2 = tf.Variable(tf.zeros([1]))
y2 = tf.nn.tanh(tf.matmul(y1, w2) + b2)
loss = tf.nn.l2_loss(y2 - y_out)
train = tf.train.AdamOptimizer(0.001).minimize(loss)
# Full-batch training: the same feed is used for every step and for the
# periodic loss report, so build it once.
feed = {x_in: x, y_out: y_col}
with tf.Session() as sess:
  # global_variables_initializer replaces the deprecated
  # initialize_all_variables.
  sess.run(tf.global_variables_initializer())
  for i in range(10001):
    sess.run(train, feed_dict=feed)
    if i % 1000 == 0:
      loss_value = sess.run(loss, feed_dict=feed)
      print(i, loss_value)
  test_y = sess.run(y2, feed_dict={x_in: x})
  # Convert while the session still holds the trained variable values.
  converter = lite.TFLiteConverter.from_session(sess, [x_in], [y2])
  tflite_model = converter.convert()
  with open("voice1.tflite", "wb") as f:
    f.write(tflite_model)
  ax.plot(test_y[1:300])
  fig.set_tight_layout(True)
  plt.show()

tfliteファイルを検証する。

import numpy as np
import tensorflow as tf
import tensorflow.contrib.lite as lite
import matplotlib.pyplot as plt

# Run the exported voice1.tflite model one sample at a time over sin.csv and
# plot input, target and prediction one above another.
data = np.loadtxt('sin.csv', delimiter=',', unpack=True)
x = np.c_[data[0], data[1]]  # inputs, shape (n, 2)
y = data[2]
fig = plt.figure(1)
# Plots show samples 1..299 (index 0 is skipped, as in the other plots).
ax = fig.add_subplot(311)
ax.plot(data[0][1:300])
ax = fig.add_subplot(312)
ax.plot(y[1:300])
ax = fig.add_subplot(313)
interpreter = lite.Interpreter(model_path="voice1.tflite")
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
# Loop-invariant tensor metadata, looked up once.
input_index = input_details[0]['index']
input_dtype = input_details[0]['dtype']
output_index = output_details[0]['index']
pred = []
for sample in x:
  # Feed one sample as a batch of one, shape (1, 2), in the model's input dtype.
  interpreter.set_tensor(input_index, np.array([sample], dtype=input_dtype))
  interpreter.invoke()
  output_data = interpreter.get_tensor(output_index)
  pred.append(output_data[0][0])
ax.plot(pred[1:300])
fig.set_tight_layout(True)
plt.show()

v1.png

以上。

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0