Help us understand the problem. What is going on with this article?

FlutterとTFLiteで”Hotdog or Not hotdog”

Hot dog or Not hot dogとは

"What would you say if I told you there is an app on the market . . ."
Silicon Valley Season4 Epi.4で飲食向けのShazamを目指すべくチェン・イェンがアプリを作ったが、"ホットドック"か"ホットドックではない"しか見極められないクソアプリを作り上げる。

環境

  • Ubuntu20.04
  • Python3.7.9(ホットドッグ識別モデル作成用)
  • Flutter(アプリ作成用)

ソースコード全体はこちら
https://github.com/bigface0202/Hotdog_or_NotHotdog

モデルの準備

画像の収集

画像はChrome拡張機能のImage Downloaderで集めます。
Screenshot from 2020-11-29 09-43-11.png
学習用のデータセットの中にhotdogかhotdog以外かのフォルダを作り、そこにダウンロードした画像を入れます。
hotdogにはもちろんホットドッグの写真、not_hotdogには寿司や味噌汁、ピザなどホットドッグ以外を入れましょう。
今回は転移学習の力に期待してhotdog、hotdog以外で各130枚程度画像を用意しましたが、本当は数千枚単位あったほうがちゃんと識別できるようになります。
また、バリデーション用にもデータセットを用意します。なるべく、トレーニングで使った画像と被らないようにしたほうが良いです(トレーニング用の画像の検索フレーズに日本語を使っていたら、バリデーション用の画像検索フレーズを英語にするなど)。

モデルの学習

2クラスのみ分類なので転移学習と微調整を参考にMobileNetv2をベースに学習させていきます。
データセットと同じフォルダの階層にjupyter notebookを作成しましょう。
基本はコピペでいいですが、トレーニング用のデータセットパスとバリデーション用のデータセットパスは自分で設定した名前に変更しましょう。

# 自分で作ったデータセットを使う場合は、ここは要らない
#_URL = 'https://storage.googleapis.com/mledu-datasets/cats_and_dogs_filtered.zip'
#path_to_zip = tf.keras.utils.get_file('cats_and_dogs.zip', origin=_URL, extract=True)
#PATH = os.path.join(os.path.dirname(path_to_zip), 'cats_and_dogs_filtered')

# trainとvalidationに自分で作ったデータセットのパスを指定する
train_dir = "./dataset_hg"
validation_dir = "./valdataset_hg"

BATCH_SIZE = 32
IMG_SIZE = (160, 160)

train_dataset = image_dataset_from_directory(train_dir,
                                             shuffle=True,
                                             batch_size=BATCH_SIZE,
                                             image_size=IMG_SIZE)

validation_dataset = image_dataset_from_directory(validation_dir,
                                                  shuffle=True,
                                                  batch_size=BATCH_SIZE,
                                                  image_size=IMG_SIZE)

あとはひたすら、サンプルのjupyter notebookをコピペしていきます。
もし3クラス以上の分類であれば、分類ヘッドを追加するの部分を変更する必要があります。
Screenshot from 2020-11-29 09-56-02.png
ちょっと怪しい画像も入っていますが、今回はヨシとします。
学習はとりあえず100エポック回しました。画像枚数が少ないこともあって、RTX2070の環境で大体5分くらいで終わりました。
その結果がこちらです。
Screenshot from 2020-11-29 10-12-23.png
Accuracyも0.9に到達してるので、これを使ってテストデータで検証してみましょう。
Screenshot from 2020-11-29 10-15-35.png
結構うまく識別できてますね、良い感じです。
最後に学習済みのモデルを保存しましよう

学習に使ったjupyternotebookで以下を実行
model.save('saved_model/hotdog')

モデルのコンバージョン

トレーニング済みのpbモデルからtfliteへコンバージョンします。参考:訓練後の量子化
Flutter上でサクサク動いてもらいたいので、重みの量子化も行います。
以下はさっきのjupyter notebookの続きでもいいですし、新たに作ってもらっても構いません。

converter = tf.lite.TFLiteConverter.from_saved_model("saved_model/hotdog")
converter.optimizations = [tf.lite.Optimize.DEFAULT]
tflite_quant_model = converter.convert()
with open('./hotdog.tflite', 'wb') as w:
    w.write(tflite_quant_model)
print("Weight Quantization complete")

これでTFLiteモデル作成は完了です。一旦お疲れ様でした。

Flutterアプリ作成

環境設定

以下のライブラリを使います。

※Google fontsは本家に近づけるために使っているだけなので、使わなくてもいいです。

ディレクトリの構成

FlutterでNew Projectをした状態から、

assets
├── hotdog.tflite
└── labels.txt
...
lib
├── image_input.dart
├── index_screen.dart
└── main.dart
...

assetsとlibの中身がこうなります。他はpubspec.yamlにライブラリ記述するくらいです。
assetsの中身についても、pubspec.yamlに追記することを忘れずに。
hotdog.tfliteはTFLiteのモデルファイル、labels.txtは識別クラス数です。
識別クラス数は今回2つですが、モデルの出力が”ホットドッグか”、”ホットドッグではない”かを識別する確信度のみが出力されるようになっているので、labels.txtにはnot-hotdogのみ書いておいてください。

labels.txt
not-hotdog

余談

出力には2クラス分の確信度が出てくると思っていたため、ラベルに"hotdog", "not-hotdog"と書いていたのが原因で1時間くらいエラーでコケてました…
ちゃんとpython側の出力で確認しないとダメですね。

main.dart & index_screen.dart

main.dart
import 'package:flutter/material.dart';
import 'index_screen.dart';

void main() {
  runApp(MyApp());
}

class MyApp extends StatelessWidget {
  // This widget is the root of your application.
  @override
  Widget build(BuildContext context) {
    return MaterialApp(
      title: 'SEE FOOD',
      theme: ThemeData(
        primarySwatch: Colors.blue,
        primaryColor: Colors.black,
      ),
      home: IndexScreen(),
    );
  }
}

index_screen.dart
import 'dart:io';

import "package:flutter/material.dart";

import "./image_input.dart";

class IndexScreen extends StatelessWidget {
  File _pickedImage;

  void _selectImage(File pickedImage) {
    _pickedImage = pickedImage;
  }

  @override
  Widget build(BuildContext context) {
    return Scaffold(
      appBar: AppBar(
        title: Text('SEE FOOD'),
      ),
      body: ImageInput(_selectImage),
    );
  }
}

この2つは特に特筆することはありません。
アプリ名が”SEE FOOD”なのは、アーリック・バックマンの咄嗟の思い付きでついた名前です。
ここもストーリー見てみてください。くだらなさすぎて僕は大好きです。

image_input.dart

とりあえず全体を記します。

image_input.dart
import 'dart:io';
import 'dart:math';

import 'package:flutter/material.dart';
import 'package:flutter/services.dart';
import 'package:image_picker/image_picker.dart';
import 'package:tflite/tflite.dart';
import 'package:google_fonts/google_fonts.dart';

class ImageInput extends StatefulWidget {
  final Function onSelectImage;

  ImageInput(this.onSelectImage);

  @override
  _ImageInputState createState() => _ImageInputState();
}

class _ImageInputState extends State<ImageInput> {
  File _storedImage;
  final picker = ImagePicker();
  String resultText = '';
  bool isHotdog = false;
  bool isRecognized = false;

  Future<void> _takePicture() async {
    final imageFile = await picker.getImage(
      source: ImageSource.camera,
    );
    if (imageFile == null) {
      return;
    }
    setState(() {
      _storedImage = File(imageFile.path);
    });
    predictHotdog(File(imageFile.path));
  }

  Future<void> _getImageFromGallery() async {
    final imageFile = await picker.getImage(
      source: ImageSource.gallery,
    );
    if (imageFile == null) {
      return;
    }
    setState(() {
      _storedImage = File(imageFile.path);
    });
    predictHotdog(File(imageFile.path));
  }

  static Future loadModel() async {
    Tflite.close();
    try {
      await Tflite.loadModel(
          model: 'assets/hotdog.tflite', labels: 'assets/labels.txt');
    } on PlatformException {
      print("Failed to load the model");
    }
  }

  Future predictHotdog(File image) async {
    var recognition = await Tflite.runModelOnImage(
      path: image.path,
      imageMean: 117, // defaults to 117.0
      imageStd: 117, // defaults to 1.0
      numResults: 2, // defaults to 5
      threshold: 0.2, // defaults to 0.1
      asynch: true,
    );

    if (recognition.isNotEmpty) {
      if (recognition[0]["confidence"] < 0.5) {
        setState(() {
          isRecognized = true;
          isHotdog = true;
          resultText = "Hotdog";
        });
      } else {
        setState(() {
          isRecognized = true;
          isHotdog = false;
          resultText = "Not Hotdog";
        });
      }
    }
  }

  @override
  void initState() {
    super.initState();
    loadModel().then((val) {
      setState(() {});
    });
  }

  @override
  Widget build(BuildContext context) {
    final Size size = MediaQuery.of(context).size;
    return Column(
      children: [
        Stack(
          children: [
            Container(
              width: size.width,
              height: 480,
              alignment: Alignment.center,
              decoration: BoxDecoration(
                border: Border.all(width: 1, color: Colors.grey),
              ),
              child: _storedImage != null
                  ? Image.file(
                      _storedImage,
                      fit: BoxFit.cover,
                      width: double.infinity,
                    )
                  : Text(
                      'No Image Taken',
                      textAlign: TextAlign.center,
                    ),
            ),
            isRecognized
                ? Stack(
                    children: [
                      Container(
                        color: isHotdog ? Colors.green : Colors.red,
                        height: 80,
                        padding: EdgeInsets.all(10),
                        alignment: Alignment.topCenter,
                        child: Text(
                          "$resultText",
                          style: GoogleFonts.bungeeInline(
                            textStyle: TextStyle(
                              fontSize: 40,
                              fontWeight: FontWeight.normal,
                            ),
                          ),
                          textAlign: TextAlign.center,
                        ),
                      ),
                      Container(
                        height: 120,
                        alignment: Alignment.bottomCenter,
                        child: CircleAvatar(
                          maxRadius: 35,
                          backgroundColor: isHotdog ? Colors.green : Colors.red,
                          child: Icon(
                            isHotdog ? Icons.check : Icons.clear,
                            size: 50,
                            color: Colors.white,
                          ),
                        ),
                      ),
                    ],
                  )
                : Container(),
          ],
        ),
        Row(
          mainAxisAlignment: MainAxisAlignment.spaceAround,
          children: [
            Expanded(
              child: FlatButton.icon(
                icon: Icon(Icons.photo_camera),
                label: Text('カメラ'),
                textColor: Theme.of(context).primaryColor,
                onPressed: _takePicture,
              ),
            ),
            Expanded(
              child: FlatButton.icon(
                icon: Icon(Icons.photo_library),
                label: Text('ギャラリー'),
                textColor: Theme.of(context).primaryColor,
                onPressed: _getImageFromGallery,
              ),
            ),
          ],
        ),
      ],
    );
  }
}


TFLite

モデルの読み込み
static Future loadModel() async {
    Tflite.close();
    try {
      await Tflite.loadModel(
          model: 'assets/hotdog.tflite', labels: 'assets/labels.txt');
    } on PlatformException {
      print("Failed to load the model");
    }
  }

これでモデルを読み込むことができます。
この関数をinitStateに記述することで、画面が開いたら読み込むようにしています。

推論部分
Future predictHotdog(File image) async {
    var recognition = await Tflite.runModelOnImage(
      path: image.path,
      imageMean: 117, // defaults to 117.0
      imageStd: 117, // defaults to 1.0
      numResults: 2, // defaults to 5
      threshold: 0.2, // defaults to 0.1
      asynch: true,
    );

    if (recognition.isNotEmpty) {
      if (recognition[0]["confidence"] < 0.5) {
        setState(() {
          isRecognized = true;
          isHotdog = true;
          resultText = "Hotdog";
        });
      } else {
        setState(() {
          isRecognized = true;
          isHotdog = false;
          resultText = "Not Hotdog";
        });
      }
    }
  }

TFlite.runModelOnImageに画像を渡すことで推論を行えます。パラメータについては自分のモデルに合わせてください(正直のところ、私は試行錯誤的に適当に決めてます)。
recognitionにはConfidenceとラベル名が出力されますが、今回はラベル名には意味がありません。
Confidenceが0.5以下だと"Hotdog", 0.5以上だと"Not Hotdog"なのでそれに合わせて結果のテキストを分岐させています。

Confidenceが1を超える…?

今回の出力がちょっとおかしい気がします。TF上では出力結果に対してSigmoidを適用してクラス分類するのですが、Flutter上の出力(ここではConfidence)に対してSigmoidを適用してもPythonと同じ結果になりません。
ただ、”ホットドッグ”か”ホットドッグではない”かによって、Confidenceの値が大きく異なる(大体0.5を境にしている)ので、とりあえずはこの値を使って判別しています。
モデルの出力には各クラスの確信度を出力するようにしないといけないのかもしれません。
また原因がわかったら修正します。

画像の読み込み

画像をカメラorギャラリーから取得する
Future<void> _takePicture() async {
    final imageFile = await picker.getImage(
      source: ImageSource.camera,
    );
    if (imageFile == null) {
      return;
    }
    setState(() {
      _storedImage = File(imageFile.path);
    });
    predictHotdog(File(imageFile.path));
  }

  Future<void> _getImageFromGallery() async {
    final imageFile = await picker.getImage(
      source: ImageSource.gallery,
    );
    if (imageFile == null) {
      return;
    }
    setState(() {
      _storedImage = File(imageFile.path);
    });
    predictHotdog(File(imageFile.path));
  }

画像の読み込みにはimage_pickerを使っています。sourceを指定するだけでカメラ、もしくはギャラリーから画像を取ってこれるので便利です。
また、画像を指定したらすぐにpredictHotdogへ渡すことによって、すぐ推論を回すようにしています。
これでホットドッグかホットドッグではないかを見極めることができます。

モデル読み込み時のエラー

モデルが読み込めない
java.io.FileNotFoundException: This file can not be opened as a file descriptor; it is probably compressed

というエラーが発生しました。.tfliteのファイルを読み込むことができないようです。

./android/app/build.gradleに以下の記述を追加してください。

android/app/build.gradle
...
android{
...
    aaptOptions {
        noCompress "tflite"
    }
}

"It's food for Shazam!"

HotdogApp_final.gif

最後に

最終的にこのエピソードでは、この”ホットドッグ”か”ホットドッグではない”かの認識率を究極に高めたことによって、”あるホットドッグ”の識別に応用できるとの理由からPeriscope社へアプリを事業譲渡することによってイグジット達成。
”あるホットドッグ”については、Qiitaに書いたらアカウント削除されそうなので是非Silicon Valleyを見てください。
ちなみに自分が作った程度の認識率では、”ある識別”に応用できるほどの精度はないので、自由に使ってください。

bigface00
まぁまぁ顔がでかいほうだと思います。
global_walkers
グローバルウォーカーズでは、AIのコンピュータビジョンによる検知を生かし、カメラによる物体検知、損傷検知、文字認識など、様々な分野でAIを活用したソリューションサービスを展開しております。高精度なAI開発には、高品質かつ大量のデータが必要となる中、弊社では教師データを作成する独自のプラットフォームを有しており、データの作成からシステムの構築までをワンストップでできることを強みとしております。
https://www.globalwalkers.co.jp/
Why not register and get more from Qiita?
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away