LoginSignup
4
0

More than 3 years have passed since last update.

Flutter で MLKit のカスタムモデルを使う

Posted at

Flutter ではバッテリー、カメラ、イメージピッカーなど、React Native ではサードパーティから提供されているようなプラグインが公式で多数提供されています。公式で提供されているプラグインは flutter/plugins というリポジトリで開発されていますが、README を見るとその充実ぶりが分かるかと思います。公式から提供されていることによって、似たようなサードパーティのプラグインが乱立することが避けられていると思います。

残念なのは、細かいところに微妙に手が届いていないプラグインがあるということです。さらに悪いことに、Flutter の開発者がまだ多くないということが原因で、サードパーティもまだあまり充実していません。今回は MLKit for Firebase のカスタムモデルを Flutter から利用した例を紹介します。

使用したライブラリ

公式からは firebase_ml_vision というライブラリが提供されています。Firebase から提供されている機械学習モデルをそのまま利用する場合はこれでよいのですが、残念ながらカスタムモデルを使いたい場合には対応していません。

サードパーティのライブラリとしては flutter_mlkit というものがあります。このライブラリはカスタムモデルに対応しているので、これを使います。今回は v0.13.1 ベースのバージョンを使いました。カスタムモデルによっては複数入力、複数出力となることもありこれまで対応していなかったのですが、最近のプルリクエスト で複数入力、複数出力にも対応しています。ちなみに複数入力の扱いにバグを見つけたので プルリクエストを送っているところ です。

Flutter から Firebase を使う共通設定を行う

カスタムモデルを使う場合に限らずいつも必要な設定として Flutter から Firebase を使うため共通設定が必要です。以下の手順に従いました。

Flutter アプリに Firebase を追加する  |  Firebase

プロジェクト内の変更としては以下のものが必要になります。

  • com.google.gms:google-services を依存関係に追加(4.2.0 を使いました)
  • ↑ を Gradle プラグインとして有効化 (apply plugin: 'com.google.gms.google-services')

AndroidX を有効化する

flutter_mlkit のほうで AndroidX の機能を使っている箇所があります。具体的には以下のものです。

import androidx.annotation.NonNull;
import androidx.annotation.Nullable;

というわけで、利用している側でも設定を有効にしなければビルドできません。プロジェクトレベルの gradle.properties に以下の設定を追加しました。

 org.gradle.jvmargs=-Xmx1536M
+android.useAndroidX=true
+android.enableJetifier=true

公式の情報としては以下を参照してください。

flutter_mlkit にパッチをあてる

複数入力の扱いにバグがあったので修正しました。本家のほうが治っていなければ以下のパッチをあててください。

サンプルコード

2×5 の大きさの行列の和を計算するカスタムモデルを提供いただいたので、以下のようなサンプルコードを作成しました。

  • initTflite() ではカスタムモデルの読み込みを行っています。実際に処理を実行する直前までモデルのダウンロードを遅延させたいと思って _onActionButtonPressed() の中で await initTflite() のように書いていたのですが、なぜかモデルがダウンロードできませんでした。flutter_mlkit の中で非同期処理にバグがあるのではないかと疑っています。
  • _onActionButtonPressed() の中でモデルの計算処理を行っています。FirebaseModelInputOutputOptions は最近のアップデートで複数入力、複数出力に対応しました。
  • Firebase ML とのデータのやりとりは Uint8List を使って行います。残念なことに、多次元配列を1次元にうまくパックしてやる必要があるだけでなく、複数入力の場合には連結して渡してやるといった処理が必要になります。
  • Float32List から Uint8List への変換でバイトオーダーなどを気にしないといけないかと不安になりますが、ここについては多少マシで Float32List#buffer.asUint8List() というメソッドを使うことができます。
  • FirebaseModelInterpreter#run() の戻り値がかなり厄介で List<dynamic> となっています。実態としては 2×5 の大きさの行列が1つ出力されるので List<List<List<double>>> のようなものがなのですが、これがなぜかキャストできませんでした。これについては Dart の型システムを理解して出直して来たいと思っています。
import 'dart:typed_data';

import 'package:flutter/material.dart';
import 'package:mlkit/mlkit.dart';

void main() => runApp(MyApp());

class MyApp extends StatelessWidget {
  @override
  Widget build(BuildContext context) {
    return MaterialApp(
      title: 'Tflite add',
      theme: ThemeData(
        primarySwatch: Colors.blue,
      ),
      home: MyHomePage(),
    );
  }
}

class MyHomePage extends StatefulWidget {
  @override
  _MyHomePageState createState() => _MyHomePageState();
}

class _MyHomePageState extends State<MyHomePage> {
  String _left = "";
  String _right = "";
  String _output = "";
  FirebaseModelInterpreter interpreter = FirebaseModelInterpreter.instance;
  FirebaseModelManager manager = FirebaseModelManager.instance;

  Future<void> initTflite() {
    return this.manager.registerRemoteModelSource(
          FirebaseRemoteModelSource(
            modelName: "add",
            enableModelUpdates: true,
          ),
        );
  }

  void _onLeftOpChanged(String op) {
    setState(() {
      _left = op;
    });
  }

  void _onRightOpChanged(String op) {
    setState(() {
      _right = op;
    });
  }

  @override
  void initState() {
    super.initState();
    initTflite();
  }

  void _onActionButtonPressed() async {
    final leftNum = double.parse(_left);
    final rightNum = double.parse(_right);

    const ROW = 2;
    const COLUMN = 5;

    const options = FirebaseModelInputOutputOptions([
      FirebaseModelIOOption(FirebaseModelDataType.FLOAT32, [ROW, COLUMN]),
      FirebaseModelIOOption(FirebaseModelDataType.FLOAT32, [ROW, COLUMN]),
    ], [
      FirebaseModelIOOption(FirebaseModelDataType.FLOAT32, [ROW, COLUMN]),
    ]);

    final left = List<double>(ROW * COLUMN);
    final right = List<double>(ROW * COLUMN);
    for (var i = 0; i < ROW; i++) {
      for (var j = 0; j < COLUMN; j++) {
        left[COLUMN * i + j] = leftNum;
        right[COLUMN * i + j] = rightNum;
      }
    }
    final concat = <double>[]..addAll(left)..addAll(right);
    print(concat);
    final input = float32ListToUint8List(Float32List.fromList(concat));
    print(input);

    final output = await this.interpreter.run("add", options, input);
    print(output);

    setState(() {
      _output = (output[0][0][0] as double).toStringAsFixed(2);
    });
  }

  @override
  Widget build(BuildContext context) {
    return Scaffold(
      body: Padding(
        padding: const EdgeInsets.all(16.0),
        child: Center(
          child: Column(
            mainAxisAlignment: MainAxisAlignment.center,
            children: <Widget>[
              TextField(
                decoration: InputDecoration(
                  labelText: 'left-hand operand',
                ),
                onChanged: _onLeftOpChanged,
                keyboardType: TextInputType.numberWithOptions(
                  signed: true,
                  decimal: true,
                ),
              ),
              TextField(
                decoration: InputDecoration(
                  labelText: 'right-hand operand',
                ),
                onChanged: _onRightOpChanged,
                keyboardType: TextInputType.numberWithOptions(
                  signed: true,
                  decimal: true,
                ),
              ),
              Padding(
                padding: const EdgeInsets.only(top: 32.0),
                child: Text(
                  _output,
                  style: Theme.of(context).textTheme.display1,
                ),
              ),
            ],
          ),
        ),
      ),
      floatingActionButton: FloatingActionButton(
        onPressed: _onActionButtonPressed,
        tooltip: 'Execute',
        child: Icon(Icons.add),
      ),
    );
  }
}

Uint8List float32ListToUint8List(Float32List list) {
  return list.buffer.asUint8List();
}

結果

Tensorflow Lite のカスタムモデルを使って計算する簡単なデモができました。

4
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
4
0