LoginSignup
1
0

More than 1 year has passed since last update.

flutterとfirebaseとtflite

Last updated at Posted at 2022-12-03

この記事は鈴鹿高専Advent Calendar 2022 24日目の記事です

今回はPCKに向けたアプリ開発で、DBや画像認識などのユーザーからは直接見えない処理の実装を担当したのでそれについて書いていきます

0. はじめに

今回は全てをクラスとして実装し、UI側(MVCのView?)では直接パッケージを叩かずに済むよう隠蔽化しました
こうするとDB設計を変えてもインターフェースを変えずに済んだり、同じ処理を内部で使いまわしたりできるのでとてもやりやすかったです

1. firebaseAuth

ではまず始めにfirebase authentication、特にgoogleアカウントでの認証の実装についてです
使うパッケージはfirebase_authgoogle_sign_inの2つです
firebase_authにgoogleSignInの関数が用意されているという浅い考えを持っていたので、実装当初は少しめんどくさく感じました

ログイン処理

まずはログイン処理の流れを説明します

  1. googleアカウントへのサインイン画面を出してログインするアカウントを選択してもらう

    これはgoogle_sign_inパッケージはを使います

  2. どのアカウントでログインしたかの情報が取れるので、そのアカウントのauthentication(認証)を取得し、firebaseにログインするために必要なcredential(資格証明)を取得
    これはfirebase_authに必要な関数が全てあります

  3. 最後に取得したcredentialをfirebaseAuthの認証する関数に渡してログイン完了

ここまでのコードを書くと

import 'package:firebase_auth/firebase_auth.dart';
import 'package:google_sign_in/google_sign_in.dart';

Future<User?> signIn() async {
    GoogleSignInAuthentication googleAuth;
    AuthCredential credential;
    GoogleSignInAccount? googleUser;
    final auth = FirebaseAuth.instance;
    
    final googleSignIn = GoogleSignIn(scope: ['email']);
    googleUser = await googleSignIn.signIn(); //サインインするアカウントを選択してもらう
    
    //アカウントが取得できなかった場合
    if(googleUser == null){
        return null;
    }

    //アカウント情報から認証情報を取得
    googleAuth = await googleUser!.authentication;
    credential = GoogleAuthProvider.credential(
        idToken: googleAuth.idToekn,
        accessToekn: googleAuth.accessToken,
    );
    
    //firebaseAuthで認証(サインイン)する
    final result = await auth.signInWithCredential(credential);
    return result.user;
}

ログイン中のユーザー情報の取得

これはとても簡単なのでコードを紹介するだけにします

User? currentUser(){
    return FIrebaseAuth.instance.currentUser;
}

ログアウト処理

最後にログアウト処理です
firebaseAuthでログアウトすれば良いだけと思っていましたが違いました(ここでバグを生みました...)
google_sign_inにはログインに用いたgoogleアカウントの情報が、firebase_authには認証情報でログインしたという情報が保存されています、そのため両方のパッケージでログアウト処理をしないといけません

処理としてはとても単純で

Future<void> signOut() async {
    awiat googleSignIn.signOut();
    await FirebaseAuth.instance.signOut();
    return;
}

これでログアウトが実装できます

アプリに実装したクラス

ログイン認証のためのインスタンスが生えまくっても色々めんどうなのでシングルトンパターンを採用しました

コードはこちら
import 'package:firebase_auth/firebase_auth.dart';
import 'package:google_sign_in/google_sign_in.dart';

class Auth{
    static final Auth _instance = Auth._internal();
    final _auth = FIrebasAuth.instance;
    Auth._internal();
    factory Auth() => _instance;

    GoogleSignInAccount? googleUser;
    
    final googleSignIn = GoogleSignIn(scope: ['email']);

    Future<User?> signIn() async {
        GoogleSignInAuthentication googleAugh;
        AuthCredential credential;
        
        googleUser = await googleSignIn.signIn();
        
        if(googleUser == null){
            return null;
        }
        
        googleAuth = await googleUser!.authentication;
        credential = GoogleAuthProvider.credential(
            idToken: googleAuth.idToekn,
            accessToekn: googleAuth.accessToken,
        );
        
        final result = await auth.signInWithCredential(credential);
        return result.user;
    }
    
    Future<void> signOut() async{
        await googleSignIn.signOut();
        await _auth.signOut();
        return;
    }
    
    User? currentUser() {
        return _auth.currentUser;
    }
}

2. Firestore

次はFirestoreです、使うパッケージはcloud_firestoreだけです

Firestoreとは何か、についてはこちらを見てください、dart以外の言語で使いたい場合にも参考になります

データ構造はこのような感じです(Firebaseの公式ドキュメントの画像です)

collection(フォルダ)の中にdocument(データの塊)があり、その中にdataが入っているというものです、documentの中には更にsubCollectionを追加することもできます

今回のアプリではユーザー情報などのアプリを使うのに必要な情報を全てFirestoreに保存しました
そのため情報の種類毎にクラスを定義してオブジェクトを生成してUI側に渡すという設計をしています
オブジェクトの取得のために静的メンバ関数を定義しました

情報毎にcollectionを作成しdocumentを追加しています、今回は汎用的な例としてhogecollectionの中にkeyがfugakey: string element1: string element2: int element3: arrayの要素を持ったdocumentの取得(と更新及び登録)についてのコードを示します

import 'package:cloud_firestore/cloud_firestore.dart';

class Hoge{
    String key;
    String element1;
    int element2;
    List<dynamic> element3;

    Hoge({
          required this.key,
          required this.element1,
          required this.element2,
          required this.element3
    });

    Hoge.fromJson(Map<String, Object?> json)
        : key = json['key'] as String,
          element1 = json['element1'] as String,
          element2 = json['element2'] as int,
          element3 = json['element3'] as List<dynamic>;
    
    Map<String, Object?> toJson(){
        return {
            'key': key,
            'element1': element1,
            'element2': element2,
            'element3': element3,
        }
    }

    static DocumentRederence<Hoge> _getRef(String key){
        return FirebaseFirestore.instance
            .collection('hoge')
            .doc(key)
            .withConverter<Hoge>(
                fromFirestore: (snapshot, _) => Hoge.fromJson(snapshot.data()!),
                toFirestore: (value, _) => value.toJson()
            );
    }

    Future<void> save() async {
        await _getRef(key).set(this);
    }

    static Future<Hoge?> getHoge(String key) async {
        final doc = await _getRef(key).get();
        return doc.data();
    }
}

//取得
final data = await Hoge.getHoge("fuga");

//更新、登録
final newData = Hoge(key="key", element1="element1", element2=2, element3=["hogehoge"]);
await newData.save();

converterを使うことでDocumentReferenceから直接オブジェクトを生成でき、登録や更新をするときも自分自身のオブジェクトを関数に渡すだけでDBを操作できます
fromJsonの名前付きコンストラクタやtoJsonなどの関数はconverterを使うのに必要な関数です、要素の増減はこの2つの関数を調整してあげるだけでできるのでとても楽になります

3. tensorflowLiteと画像処理

続いてはtensorflowLiteを用いた画像分類とそのための画像処理(前処理)に関してです
tensorflowLiteのモデルは自作したtensorflowモデルをtflite形式に変換して実装しました
tensorflowモデルについては鈴鹿高専Advent Calendar 2022 3日目の記事を見てください

tensorflowモデル to tflite

まずはflutterでtensorflowを使うためにモデルを変換するところを説明します
変換はpythonで簡単にできるのでコードを示して終わりです

import tensorflow as tf
import os
from keras.models import load_model
import tensorflow_addons as tfa

modelPath = "./model.h5" #kerasモデルファイルのパス

#modelの読み込み rrelu関数を使ったモデルだがtfliteには無いのでカスタムオブジェクトとして追加
model = load_model(modelPath, custom_objects={"rrelu": tfa.activations.rrelu})

#convert
converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
tflite_model = converter.convert()

open("./model.tflite", "wb").write(tflite_model)

これで終わりです

前処理と推論

前処理と推論ではdart:math package:image/image.dart package:tflite_flutter/tflite_flutter.dart package:tflite_flutter_helper/tflite_flutter_helper.dart package:collection/collection.dartを使います

前処理

前処理では入力画像を推論モデルに入力できる形式に変換します
今回は推論モデルの入力サイズが(height, width) = (384, 216)だったので、縦横の縮尺率が等しくなるよう計算どちらに合わせるかを計算して変換させています

import 'dart:math';
import 'package:image/image.dart';
import 'package:tflite_flutter/tflite_flutter.dart';
import 'package:tflite_flutter_helper/tflite_flutter_helper.dart';

TensorImage _preProcess(Image _inputImage) {
    final width = 216;
    final height = 384;
    //リサイズする倍率の計算
    final resizeRatio =
        max(height / _inputImage.height, width / _inputImage.width);
    //リサイズ後のサイズ
    final resizedWidth = (_inputImage.width * resizeRatio).floor();
    final resizedHeight = (_inputImage.height * resizeRatio).floor();
    return ImageProcessorBuilder()
        .add(ResizeOp(
            resizedHeight, resizedWidth, ResizeMethod.NEAREST_NEIGHBOUR))
        .add(ResizeWithCropOrPadOp(
          height,
          width,
        ))
        .build()
        .process(_inputImage);
}

推論

これで前処理ができたので推論をさせます
まずはtfliteのモデルをインポートし、推論するためのオブジェクトのようなものを生成します
推論モデルのトレーニングで、各ピクセルの値を255で割って正規化していたのでその処理を一緒に埋め込んで読み込みます
分類ラベルの読み込みもこのタイミングでやります

// 必要な変数
bool _isInited = false;
bool _isModelLoaded = false;
bool _isLabelsLoaded = false;
late Interpreter _interpreter;
final _interpreterOptions = InterpreterOptions();
late List<int> _inputShape;
late List<int> _outputShape;
late TensorBuffer _outputBuffer;
late TfLiteType _inputType;
late TfLiteType _outputType;
late TensorImage _inputImage;
NormalizeOp get _postProcessNormalizeOp => NormalizeOp(0, 1);
late List<String> labels;
late SequentialProcessor<TensorBuffer> _probabilityProcessor;

/// モデルの読み込み
Future<void> loadModel() async {
    if (_isModelLoaded) return;
    try {
        _interpreter = await Interpreter.fromAsset("model.tflite",
            options: _interpreterOptions);
        _inputShape = _interpreter.getInputTensor(0).shape;
        _inputType = _interpreter.getInputTensor(0).type;
        _outputShape = _interpreter.getOutputTensor(0).shape;
        _outputType = _interpreter.getOutputTensor(0).type;
        _outputBuffer = TensorBuffer.createFixedSize(_outputShape, _outputType);
        _probabilityProcessor =
            TensorProcessorBuilder().add(_postProcessNormalizeOp).build();
        _isModelLoaded = true;
        return;
    } catch (e) {
        throw Exception("Failed to load model");
    }
}

/// ラベルの読み込み
Future<void> loadLabels() async {
    if (_isLabelsLoaded) return;
    labels = await FileUtil.loadLabels("assets/labels.txt");
    _isLabelsLoaded = true;
    return;
}

これで読み込みはできたので推論する処理を書きます、変数は上記のものが生きているとします

/// 推論
Future<List<Category>> predict(Object image) async {
    if (!_isInited) await init();
    if (image is CameraImage) {
        image = ImageUtils.convertYUV420ToImage(image);
    }
    if (image is! Image) {
        throw Exception("Invalid image type");
    }
    _inputImage = TensorImage(_inputType);
    _inputImage.loadImage(image);
    _inputImage = _preProcess();

    _interpreter.run(_inputImage.buffer, _outputBuffer.getBuffer());
    Map<String, double> labeledProb = TensorLabel.fromList(
            labels, _probabilityProcessor.process(_outputBuffer))
        .getMapWithFloatValue();
    final pred = getSortedProbability(labeledProb);
    List<Category> categories = [];
    for (var result in pred) {
        categories.add(Category(result.key, result.value));
    }
    return categories;
}

/// 確率が高い順にソート
List<MapEntry<String, double>> getSortedProbability(
        Map<String, double> labeledProb) {
    var pq = PriorityQueue<MapEntry<String, double>>(compare);
    pq.addAll(labeledProb.entries);

    // sort
    List<MapEntry<String, double>> sorted = [];
    while (pq.isNotEmpty) {
        sorted.add(pq.removeFirst());
    }

    return sorted;
}

/// ソートする際の基準
int compare(MapEntry<String, double> e1, MapEntry<String, double> e2) {
    if (e1.value > e2.value) {
        return -1;
    } else if (e1.value == e2.value) {
        return 0;
    } else {
        return 1;
    }
}

今回は推論結果を扱いやすいよう確率が高い順にソートされたCategoryクラスの配列を返しています
Categoryクラスはtflite_flutter_helperに定義されているクラスで、ラベルと確率をメンバに持っているものです

4. おわりに

今回の開発ではDB設計を変更することが何回かありましたが隠蔽化していたおかげでUI側への影響を最小限に抑えることができました
でも、そもそもDB設計を何回も変更することが良くないので初めにきちんと設計をするべきだったと思います

去年はRealtime Databaseを使い今年はFirestoreを使いましたが、クエリをそんなに使わないので正直どちらでも良いのではと思ってしまいました

tfliteを用いた画像分類では、tflite_flutterパッケージの公式がサンプルコードを載せていたので実装自体はすぐに終わりました
ただ、機械学習についての知識が乏しく前処理で何をしているかを理解していなかったので、学習時と推論時で前処理のやり方を統一していないなどというミスをしていました...

まだまだよわよわだなと改めて感じた開発でした

1
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
0