Hot dog or Not hot dogとは
"What would you say if I told you there is an app on the market . . ."
Silicon Valley Season4 Epi.4で飲食向けのShazamを目指すべくチェン・イェンがアプリを作ったが、"ホットドック"か"ホットドックではない"しか見極められないクソアプリを作り上げる。
環境
- Ubuntu20.04
- Python3.7.9(ホットドッグ識別モデル作成用)
- Flutter(アプリ作成用)
ソースコード全体はこちら
https://github.com/bigface0202/Hotdog_or_NotHotdog
モデルの準備
画像の収集
画像はChrome拡張機能のImage Downloaderで集めます。
学習用のデータセットの中にhotdogかhotdog以外かのフォルダを作り、そこにダウンロードした画像を入れます。
hotdogにはもちろんホットドッグの写真、not_hotdogには寿司や味噌汁、ピザなどホットドッグ以外を入れましょう。
今回は転移学習の力に期待してhotdog、hotdog以外で各130枚程度画像を用意しましたが、本当は数千枚単位あったほうがちゃんと識別できるようになります。
また、バリデーション用にもデータセットを用意します。なるべく、トレーニングで使った画像と被らないようにしたほうが良いです(トレーニング用の画像の検索フレーズに日本語を使っていたら、バリデーション用の画像検索フレーズを英語にするなど)。
モデルの学習
2クラスのみ分類なので転移学習と微調整を参考にMobileNetv2をベースに学習させていきます。
データセットと同じフォルダの階層にjupyter notebookを作成しましょう。
基本はコピペでいいですが、トレーニング用のデータセットパスとバリデーション用のデータセットパスは自分で設定した名前に変更しましょう。
# 自分で作ったデータセットを使う場合は、ここは要らない
#_URL = 'https://storage.googleapis.com/mledu-datasets/cats_and_dogs_filtered.zip'
#path_to_zip = tf.keras.utils.get_file('cats_and_dogs.zip', origin=_URL, extract=True)
#PATH = os.path.join(os.path.dirname(path_to_zip), 'cats_and_dogs_filtered')
# trainとvalidationに自分で作ったデータセットのパスを指定する
train_dir = "./dataset_hg"
validation_dir = "./valdataset_hg"
BATCH_SIZE = 32
IMG_SIZE = (160, 160)
train_dataset = image_dataset_from_directory(train_dir,
shuffle=True,
batch_size=BATCH_SIZE,
image_size=IMG_SIZE)
validation_dataset = image_dataset_from_directory(validation_dir,
shuffle=True,
batch_size=BATCH_SIZE,
image_size=IMG_SIZE)
あとはひたすら、サンプルのjupyter notebookをコピペしていきます。
もし3クラス以上の分類であれば、分類ヘッドを追加するの部分を変更する必要があります。
ちょっと怪しい画像も入っていますが、今回はヨシとします。
学習はとりあえず100エポック回しました。画像枚数が少ないこともあって、RTX2070の環境で大体5分くらいで終わりました。
その結果がこちらです。
Accuracyも0.9に到達してるので、これを使ってテストデータで検証してみましょう。
結構うまく識別できてますね、良い感じです。
最後に学習済みのモデルを保存しましよう
model.save('saved_model/hotdog')
モデルのコンバージョン
トレーニング済みのpbモデルからtfliteへコンバージョンします。参考:訓練後の量子化
Flutter上でサクサク動いてもらいたいので、重みの量子化も行います。
以下はさっきのjupyter notebookの続きでもいいですし、新たに作ってもらっても構いません。
converter = tf.lite.TFLiteConverter.from_saved_model("saved_model/hotdog")
converter.optimizations = [tf.lite.Optimize.DEFAULT]
tflite_quant_model = converter.convert()
with open('./hotdog.tflite', 'wb') as w:
w.write(tflite_quant_model)
print("Weight Quantization complete")
これでTFLiteモデル作成は完了です。一旦お疲れ様でした。
Flutterアプリ作成
環境設定
以下のライブラリを使います。
※Google fontsは本家に近づけるために使っているだけなので、使わなくてもいいです。
ディレクトリの構成
FlutterでNew Projectをした状態から、
assets
├── hotdog.tflite
└── labels.txt
...
lib
├── image_input.dart
├── index_screen.dart
└── main.dart
...
assetsとlibの中身がこうなります。他はpubspec.yamlにライブラリ記述するくらいです。
assetsの中身についても、pubspec.yamlに追記することを忘れずに。
hotdog.tfliteはTFLiteのモデルファイル、labels.txtは識別クラス数です。
識別クラス数は今回2つですが、モデルの出力が”ホットドッグか”、”ホットドッグではない”かを識別する確信度のみが出力されるようになっているので、labels.txtにはnot-hotdogのみ書いておいてください。
not-hotdog
余談
出力には2クラス分の確信度が出てくると思っていたため、ラベルに"hotdog", "not-hotdog"と書いていたのが原因で1時間くらいエラーでコケてました…
ちゃんとpython側の出力で確認しないとダメですね。
main.dart & index_screen.dart
import 'package:flutter/material.dart';
import 'index_screen.dart';
void main() {
runApp(MyApp());
}
class MyApp extends StatelessWidget {
// This widget is the root of your application.
@override
Widget build(BuildContext context) {
return MaterialApp(
title: 'SEE FOOD',
theme: ThemeData(
primarySwatch: Colors.blue,
primaryColor: Colors.black,
),
home: IndexScreen(),
);
}
}
import 'dart:io';
import "package:flutter/material.dart";
import "./image_input.dart";
class IndexScreen extends StatelessWidget {
File _pickedImage;
void _selectImage(File pickedImage) {
_pickedImage = pickedImage;
}
@override
Widget build(BuildContext context) {
return Scaffold(
appBar: AppBar(
title: Text('SEE FOOD'),
),
body: ImageInput(_selectImage),
);
}
}
この2つは特に特筆することはありません。
アプリ名が”SEE FOOD”なのは、アーリック・バックマンの咄嗟の思い付きでついた名前です。
ここもストーリー見てみてください。くだらなさすぎて僕は大好きです。
image_input.dart
とりあえず全体を記します。
import 'dart:io';
import 'dart:math';
import 'package:flutter/material.dart';
import 'package:flutter/services.dart';
import 'package:image_picker/image_picker.dart';
import 'package:tflite/tflite.dart';
import 'package:google_fonts/google_fonts.dart';
class ImageInput extends StatefulWidget {
final Function onSelectImage;
ImageInput(this.onSelectImage);
@override
_ImageInputState createState() => _ImageInputState();
}
class _ImageInputState extends State<ImageInput> {
File _storedImage;
final picker = ImagePicker();
String resultText = '';
bool isHotdog = false;
bool isRecognized = false;
Future<void> _takePicture() async {
final imageFile = await picker.getImage(
source: ImageSource.camera,
);
if (imageFile == null) {
return;
}
setState(() {
_storedImage = File(imageFile.path);
});
predictHotdog(File(imageFile.path));
}
Future<void> _getImageFromGallery() async {
final imageFile = await picker.getImage(
source: ImageSource.gallery,
);
if (imageFile == null) {
return;
}
setState(() {
_storedImage = File(imageFile.path);
});
predictHotdog(File(imageFile.path));
}
static Future loadModel() async {
Tflite.close();
try {
await Tflite.loadModel(
model: 'assets/hotdog.tflite', labels: 'assets/labels.txt');
} on PlatformException {
print("Failed to load the model");
}
}
Future predictHotdog(File image) async {
var recognition = await Tflite.runModelOnImage(
path: image.path,
imageMean: 117, // defaults to 117.0
imageStd: 117, // defaults to 1.0
numResults: 2, // defaults to 5
threshold: 0.2, // defaults to 0.1
asynch: true,
);
if (recognition.isNotEmpty) {
if (recognition[0]["confidence"] < 0.5) {
setState(() {
isRecognized = true;
isHotdog = true;
resultText = "Hotdog";
});
} else {
setState(() {
isRecognized = true;
isHotdog = false;
resultText = "Not Hotdog";
});
}
}
}
@override
void initState() {
super.initState();
loadModel().then((val) {
setState(() {});
});
}
@override
Widget build(BuildContext context) {
final Size size = MediaQuery.of(context).size;
return Column(
children: [
Stack(
children: [
Container(
width: size.width,
height: 480,
alignment: Alignment.center,
decoration: BoxDecoration(
border: Border.all(width: 1, color: Colors.grey),
),
child: _storedImage != null
? Image.file(
_storedImage,
fit: BoxFit.cover,
width: double.infinity,
)
: Text(
'No Image Taken',
textAlign: TextAlign.center,
),
),
isRecognized
? Stack(
children: [
Container(
color: isHotdog ? Colors.green : Colors.red,
height: 80,
padding: EdgeInsets.all(10),
alignment: Alignment.topCenter,
child: Text(
"$resultText",
style: GoogleFonts.bungeeInline(
textStyle: TextStyle(
fontSize: 40,
fontWeight: FontWeight.normal,
),
),
textAlign: TextAlign.center,
),
),
Container(
height: 120,
alignment: Alignment.bottomCenter,
child: CircleAvatar(
maxRadius: 35,
backgroundColor: isHotdog ? Colors.green : Colors.red,
child: Icon(
isHotdog ? Icons.check : Icons.clear,
size: 50,
color: Colors.white,
),
),
),
],
)
: Container(),
],
),
Row(
mainAxisAlignment: MainAxisAlignment.spaceAround,
children: [
Expanded(
child: FlatButton.icon(
icon: Icon(Icons.photo_camera),
label: Text('カメラ'),
textColor: Theme.of(context).primaryColor,
onPressed: _takePicture,
),
),
Expanded(
child: FlatButton.icon(
icon: Icon(Icons.photo_library),
label: Text('ギャラリー'),
textColor: Theme.of(context).primaryColor,
onPressed: _getImageFromGallery,
),
),
],
),
],
);
}
}
TFLite
static Future loadModel() async {
Tflite.close();
try {
await Tflite.loadModel(
model: 'assets/hotdog.tflite', labels: 'assets/labels.txt');
} on PlatformException {
print("Failed to load the model");
}
}
これでモデルを読み込むことができます。
この関数をinitStateに記述することで、画面が開いたら読み込むようにしています。
Future predictHotdog(File image) async {
var recognition = await Tflite.runModelOnImage(
path: image.path,
imageMean: 117, // defaults to 117.0
imageStd: 117, // defaults to 1.0
numResults: 2, // defaults to 5
threshold: 0.2, // defaults to 0.1
asynch: true,
);
if (recognition.isNotEmpty) {
if (recognition[0]["confidence"] < 0.5) {
setState(() {
isRecognized = true;
isHotdog = true;
resultText = "Hotdog";
});
} else {
setState(() {
isRecognized = true;
isHotdog = false;
resultText = "Not Hotdog";
});
}
}
}
TFlite.runModelOnImageに画像を渡すことで推論を行えます。パラメータについては自分のモデルに合わせてください(正直のところ、私は試行錯誤的に適当に決めてます)。
recognitionにはConfidenceとラベル名が出力されますが、今回はラベル名には意味がありません。
Confidenceが0.5以下だと"Hotdog", 0.5以上だと"Not Hotdog"なのでそれに合わせて結果のテキストを分岐させています。
Confidenceが1を超える…?
今回の出力がちょっとおかしい気がします。TF上では出力結果に対してSigmoidを適用してクラス分類するのですが、Flutter上の出力(ここではConfidence)に対してSigmoidを適用してもPythonと同じ結果になりません。
ただ、”ホットドッグ”か”ホットドッグではない”かによって、Confidenceの値が大きく異なる(大体0.5を境にしている)ので、とりあえずはこの値を使って判別しています。
モデルの出力には各クラスの確信度を出力するようにしないといけないのかもしれません。
また原因がわかったら修正します。
画像の読み込み
Future<void> _takePicture() async {
final imageFile = await picker.getImage(
source: ImageSource.camera,
);
if (imageFile == null) {
return;
}
setState(() {
_storedImage = File(imageFile.path);
});
predictHotdog(File(imageFile.path));
}
Future<void> _getImageFromGallery() async {
final imageFile = await picker.getImage(
source: ImageSource.gallery,
);
if (imageFile == null) {
return;
}
setState(() {
_storedImage = File(imageFile.path);
});
predictHotdog(File(imageFile.path));
}
画像の読み込みにはimage_pickerを使っています。sourceを指定するだけでカメラ、もしくはギャラリーから画像を取ってこれるので便利です。
また、画像を指定したらすぐにpredictHotdogへ渡すことによって、すぐ推論を回すようにしています。
これでホットドッグかホットドッグではないかを見極めることができます。
モデル読み込み時のエラー
java.io.FileNotFoundException: This file can not be opened as a file descriptor; it is probably compressed
というエラーが発生しました。.tfliteのファイルを読み込むことができないようです。
./android/app/build.gradleに以下の記述を追加してください。
...
android{
...
aaptOptions {
noCompress "tflite"
}
}
"It's food for Shazam!"
最後に
最終的にこのエピソードでは、この”ホットドッグ”か”ホットドッグではない”かの認識率を究極に高めたことによって、”あるホットドッグ”の識別に応用できるとの理由からPeriscope社へアプリを事業譲渡することによってイグジット達成。
”あるホットドッグ”については、Qiitaに書いたらアカウント削除されそうなので是非Silicon Valleyを見てください。
ちなみに自分が作った程度の認識率では、”ある識別”に応用できるほどの精度はないので、自由に使ってください。