この記事はRust Advent Calendar 2023 シリーズ3の22日目の記事です.
TensorFlow Lite for Microcontrollers インタープリターをRustで再実装しました.
(ずっと25日に登録したと勘違いしていて,大急ぎで書いています.)
BerryLiteのリポジトリは以下です.
また,BerryLiteを使用したiOSデモアプリのリポジトリは以下です.
この記事では,まずBerryLiteの概要と使用方法について説明します.
最後に, BerryLiteを組み込んだiOSでもアプリを実装したので,紹介します.
BerryLiteを開発した動機
TensorFlow Lite for Microcontrollersとは,組み込み機器向けのTensorFlow Lite ランタイムです.
TensorFlow で学習したモデルファイルを使用して,IoT機器などにDeep Learningモデルを組み込むことが可能であり,非常に便利なソフトウェアです.
しかし,TensorFlow Lite MicrocontrollersはC++で実装されており,Rustで書かれた組み込み機器用のソフトウェアで使うには,以下のデメリットがあります.
-
cargo
を使用して,ビルド・クレートの管理ができない - 追加でFFI用のC インターフェースを作成する必要がある
- FFI で関数を呼び出す部分がunsafeになってしまう (呼び出し部分の安全性をプログラマが担保する必要がある)
BerryLiteは全てRustで実装されているので,上記のデメリットは解決されます.
BerryLiteの使用方法
TensorFlow Lite for Microcontrollersは,person detectionというサンプルモデルと実装を提供しています.
person detection は人かどうかを判定するモデルであり,アーキテクチャとしてはmobilenet v1ベースのモデルです. また,モデルは8bit量子化されています.
BerryLiteによる,person detectionの実装例を説明します.
実装の詳細はリポジトリのexampleディレクトリにあります.
use berrylite::kernel::micro_operator::i8::avg_pool2d_i8::OpAvgPool2DInt8;
use berrylite::kernel::micro_operator::i8::conv2d_i8::OpConv2DInt8;
use berrylite::kernel::micro_operator::i8::depthwise_conv2d_i8::OpDepthWiseConv2DInt8;
use berrylite::kernel::micro_operator::i8::fully_connected_i8::OpFullyConnectedInt8;
use berrylite::kernel::micro_operator::i8::reshape_i8::OpReshapeInt8;
use berrylite::kernel::micro_operator::i8::softmax_i8::OpSoftMaxInt8;
use berrylite::micro_allocator::{ArenaAllocator, BumpArenaAllocator};
use berrylite::micro_errors::Result;
use berrylite::micro_interpreter::BLiteInterpreter;
use berrylite::micro_op_resolver::BLiteOpResolver;
use berrylite::tflite_schema_generated::tflite;
const BUFFER: &[u8; 300568] = include_bytes!("../resources/models/person_detect.tflite");
const ARENA_SIZE: usize = 130 * 1024;
static mut ARENA: [u8; ARENA_SIZE] = [0; ARENA_SIZE];
fn set_input(
interpreter: &mut BLiteInterpreter<'_, i8>,
input_h: usize,
input_w: usize,
_input_zero_point: i32,
image: &[u8],
) {
for h in 0..input_h {
for w in 0..input_w {
let v = image[h * input_w + w];
// println!("{} {}", v, v as i8);
interpreter.input.data[h * input_w + w] = v as i8;
}
}
}
fn predict(image: &[u8]) -> Result<usize> {
// 1. モデルの読み込み
let model = tflite::root_as_model(BUFFER).unwrap();
// 2. アロケータの初期化
let mut allocator = unsafe { BumpArenaAllocator::new(&mut ARENA) };
// 3. 使用するレイヤーの登録
let mut op_resolver = BLiteOpResolver::<7, _, _>::new();
op_resolver.add_op(OpFullyConnectedInt8::fully_connected_int8())?;
op_resolver.add_op(OpReshapeInt8::reshape_int8())?;
op_resolver.add_op(OpConv2DInt8::conv2d_int8())?;
op_resolver.add_op(OpAvgPool2DInt8::avg_pool2d_int8())?;
op_resolver.add_op(OpSoftMaxInt8::softmax_int8())?;
op_resolver.add_op(OpDepthWiseConv2DInt8::depthwise_conv2d_int8())?;
// 4. インタープリタの初期化
let mut interpreter = BLiteInterpreter::new(&mut allocator, &op_resolver, &model)?;
let (_input_scale, input_zero_point) = interpreter.get_input_quantization_params().unwrap();
let (output_scale, output_zero_point) = interpreter.get_output_quantization_params().unwrap();
println!("{:?}", allocator.description());
// 5. 入力のセット
set_input(&mut interpreter, 96, 96, input_zero_point, image);
// 6. 推論の実行
interpreter.invoke()?;
// 7. 推論結果の取得
let output = interpreter.output;
dbg!(output);
let mut num_prob = 0.;
let mut num = 0;
for (i, &y_pred) in output.data.iter().enumerate() {
let prob = output_scale * (y_pred as i32 - output_zero_point) as f32;
dbg!(prob);
if prob > num_prob {
num_prob = prob;
num = i;
}
}
Ok(num)
}
実装は以下の7ステップからなります.
これらのステップはTensorFlow Lite for Microcontrollersとほぼ同等のAPIです.
- モデルの読み込み
TensorFlow Liteモデルを flatbuffersを用いて読み込みます. - インタープリタ用のメモリアロケータの初期化
組み込み機器用のTensroFlow runtimeなので,メモリアロケータは独自に実装してあります. - モデルないで使用するレイヤーを登録
person detectionモデルが使用する 全結合層やDepthWiseConvを登録します. - インタープリタの初期化
BLiteInterpreterを初期化し,作成します. - 入力のセット
入力画像をインタプリタの入力用テンソルにセットします. - 推論の実行
- 推論結果の取得
推論結果をインタプリタの出力用テンソルから取得します.ここで,モデルが量子化されているので,確率を得るためには出力を浮動小数点に変換する必要があります.(2値分類なので,単純に値が大きい方でも良いですが...)
BerryLiteを使用したiOSデモアプリ
BerryLiteを使用して,iOSデモアプリを作成しました.
リポジトリは公開しています.
モデルが小さくて性能がそんなに無いので,画像の中心で人を捉えないと正しく判定してくれないです...
また,TensorFlow Lite Hubで公開されているモデルは量子化のフォーマットがuint8で未対応なので,そのうち対応して動かしたいです.
iOSへの組み込みはBerryLite (Rust) --> (cxx) --> BerryLite (Rust) + C++ interface --> (ios.coolchain.cmake) --> BerryLiteCxx.framework の順で最終的にiOSのframeworkにして,使用しています.
この時,今年のWWDC2023で発表された,SwiftからC++を直接呼ぶ機構を使用して,C++ラッパー関数をSwiftで直接呼んでいます.
最後に
TensorFlow Lite runtimeは非常に言語処理系的な実装(まんまインタープリタ)になっており,違いはASTが入力されるかモデルのグラフが入力されるかの違いでしか無いです.
言語処理系の自作が好きな人はDeep Learing runtimeの自作にハマる可能性は大きいと思います.
ただ,言語処理系とは異なり,Deep Learning runtimeは多少の誤差ならそれっぽい出力をするので,デバッグが難しいと感じました.