Tensorflow lite for Microcontrollersは、メモリが数十キロバイトしかないマイクロコントローラでも機械学習モデルを実行できるように設計されているようで、Arduinoベースのマイクロコントローラでも、加速度計データからのジェスチャー分類、カメラデータを使用した画像分類などができるとのことで、手持ちのWio Terminalで試してみました。
#Arduino Tensorflow liteライブラリのインストール
Arduino IDEのライブラリマネージャ上で、Arduino TensorFlow Liteキーワードで検索してください。下図の通り、Tensorflow liteのライブラリがフィルタリング表示されていますが、precompiledではない最新バージョンをインストールしてください。
#Hello Worldを動かしてみる
ライブラリのインストールが完了すれば、[ファイル]_[スケッチ例]メニューから、Arduino_TesorFlowLite->hello_worldをロードすることができます。
hello_worldのサンプルプログラムは、そのままコンパイルと書き込みに成功するはずです。書き込み後、シリアルプロッタを表示すると下図のようにSin波形が描画されていくのがわかると思います。
##Wio Terminalの画面上にプロット
せっかく画面があるということなので、生成されたSine波を画面上にプロットしてみます。まずは、TFT_eSPI.hをインクルードし、fillCircleでサークルの描画となります。
setup()関数でTFT_eSPIの初期化、loop()関数内のHandleOutputの後あたりに画面表示用のdrawSine関数を追加しました。
#include <TensorFlowLite.h>
#include "main_functions.h"
#include "constants.h"
#include "output_handler.h"
#include "sine_model_data.h"
#include "tensorflow/lite/experimental/micro/kernels/all_ops_resolver.h"
#include "tensorflow/lite/experimental/micro/micro_error_reporter.h"
#include "tensorflow/lite/experimental/micro/micro_interpreter.h"
#include "tensorflow/lite/schema/schema_generated.h"
#include "tensorflow/lite/version.h"
#include"TFT_eSPI.h"
TFT_eSPI tft;
// Globals, used for compatibility with Arduino-style sketches.
namespace {
tflite::ErrorReporter* error_reporter = nullptr;
const tflite::Model* model = nullptr;
tflite::MicroInterpreter* interpreter = nullptr;
TfLiteTensor* input = nullptr;
TfLiteTensor* output = nullptr;
int inference_count = 0;
// Create an area of memory to use for input, output, and intermediate arrays.
// Finding the minimum value for your model may require some trial and error.
constexpr int kTensorArenaSize = 2 * 1024;
uint8_t tensor_arena[kTensorArenaSize];
} // namespace
// The name of this function is important for Arduino compatibility.
void setup() {
tft.begin();
tft.setRotation(3);
tft.fillScreen(TFT_BLACK);
// Set up logging. Google style is to avoid globals or statics because of
// lifetime uncertainty, but since this has a trivial destructor it's okay.
// NOLINTNEXTLINE(runtime-global-variables)
static tflite::MicroErrorReporter micro_error_reporter;
error_reporter = µ_error_reporter;
// Map the model into a usable data structure. This doesn't involve any
// copying or parsing, it's a very lightweight operation.
model = tflite::GetModel(g_sine_model_data);
if (model->version() != TFLITE_SCHEMA_VERSION) {
error_reporter->Report(
"Model provided is schema version %d not equal "
"to supported version %d.",
model->version(), TFLITE_SCHEMA_VERSION);
return;
}
// This pulls in all the operation implementations we need.
// NOLINTNEXTLINE(runtime-global-variables)
static tflite::ops::micro::AllOpsResolver resolver;
// Build an interpreter to run the model with.
static tflite::MicroInterpreter static_interpreter(
model, resolver, tensor_arena, kTensorArenaSize, error_reporter);
interpreter = &static_interpreter;
// Allocate memory from the tensor_arena for the model's tensors.
TfLiteStatus allocate_status = interpreter->AllocateTensors();
if (allocate_status != kTfLiteOk) {
error_reporter->Report("AllocateTensors() failed");
return;
}
// Obtain pointers to the model's input and output tensors.
input = interpreter->input(0);
output = interpreter->output(0);
// Keep track of how many inferences we have performed.
inference_count = 0;
}
int _x, _y = 0;
// The name of this function is important for Arduino compatibility.
void loop() {
tft.fillCircle(_x, _y, 8, TFT_BLACK);
// Calculate an x value to feed into the model. We compare the current
// inference_count to the number of inferences per cycle to determine
// our position within the range of possible x values the model was
// trained on, and use this to calculate a value.
float position = static_cast<float>(inference_count) /
static_cast<float>(kInferencesPerCycle);
float x_val = position * kXrange;
// Place our calculated x value in the model's input tensor
input->data.f[0] = x_val;
// Run inference, and report any error
TfLiteStatus invoke_status = interpreter->Invoke();
if (invoke_status != kTfLiteOk) {
error_reporter->Report("Invoke failed on x_val: %f\n",
static_cast<double>(x_val));
return;
}
// Read the predicted y value from the model's output tensor
float y_val = output->data.f[0];
// Output the results. A custom HandleOutput function can be implemented
// for each supported hardware target.
HandleOutput(error_reporter, x_val, y_val);
drawSine(x_val, y_val);
// Increment the inference_counter, and reset it if we have reached
// the total number per cycle
inference_count += 1;
if (inference_count >= kInferencesPerCycle) inference_count = 0;
}
void drawSine(float x_value, float y_value) {
char header[32];
sprintf(header, "x=%f y=%f", x_value, y_value);
tft.drawString(header, 0, 0);
_x = tft.width() * (x_value / 6.28);
_y = tft.height() * (y_value + 1) / 2;
tft.fillCircle(_x, _y, 8, TFT_WHITE);
delay(10);
}
#動かしてみる
こんな感じで動きました。
Tensorflow liteは、Arduino Nano 33 BLE SenseやSTM32F746 Discovery kit、Espressif ESP32-DevKitCなどいくつかのデバイスで検証されていますが、M5StackやWio Terminalで動かそうとすると、そのままのサンプルコードでは動かず、入力系のコードを変えていく必要がありそうです。