この記事は全部俺 Advent Calendar 2018の16日目の記事です。
やること
Flutterを使って、Firebase MLKitのカスタムモデルを動かします。
Firebase MLKitでは、TensorflowLiteのモデルをサポートしています。
動作しているイメージはこんな感じです。
結果画面が早くてみにくいですが、ちゃんとtoilet tissueが一番スコアが高くなっていることがわかります。
(前回の記事でAndroidのgifを載せたので今回はiPhone XRでの動作例を載せています。)
注意
mlkitを使用するために使うライブラリは、公式のfirebase_ml_visionではなく、flutter_mlkitにしています。
これは、公式が未だにカスタムモデルをサポートしていないためです。
公式にカスタムモデルがサポートされた場合は、そちらを使用したほうが良いと思いますし、むしろカスタムモデルを使用しない前提なら、公式のfirebase_ml_visionを使用したほうが良いと思います。(実は公式リポジトリもあんまり更新されていないのですが。。)
flutter_mlkitの作者様も、「公式のfirebase_ml_visionを使用することを検討してください」と言っています。
公式がカスタムモデルをサポートしたら、また記事を上げるつもりです。
おそらくFirebase MLKitがβじゃなくなってからサポートされるのでは?と思っています。
カスタムモデルを使用せずに組み込みのモデルを使用する場合は、こちらの記事を参考にしてください。(公式リポジトリではなくflutter_mlkitを使用しています。)
組み込みモデルだけでも、以下のことができます。
- 画像からの文字起こし
- 画像に写っているもののラベリング
- バーコードスキャン
- 顔画像取り出し
カスタムモデルのダウンロードとFirebase MLKitの準備
ここから、mobilenet_v1_1.0_224_quant.tgz
をダウンロードしてきて解凍します。
このモデルは、Inputとなる画像をここに記載されているラベルで分類するモデルとなります。
ダウンロードしたあとは、Flutterのプロジェクトを作成しておきます。(すでに作成してあるものを使う場合はスルーでOKです。)
名前は何でもいいのですが、今回はProject名をmlkit_sample
、Company Domainをmlkit
として、パッケージ名がmlkit.mlkitsample
という名前になるように設定しました。
以下、この名称を使用しますが、変更する場合は適宜読み替えてください。
Firebaseコンソールから、MLKitを選択し、「カスタム」タブから「カスタムモデルを追加」を選択します。
カスタムモデルの名前を入力することになるので、好きな名前を入れます。ここでは、先ほどのモデル名称をもとにmobilenet_v1_224_quant
という名前にしました。
この名前が後でソースコードからアクセスする際に使用する名前になるので、わかりやすい名前をつけておくと良いです。
その後、実際に使用するカスタムモデルをアップロードするので、先ほど解凍したモデルから、mobilenet_v1_1.0_224_quant.tflite
を選んでアップロードしておきます。
モデルのアップロードが完了したら、「公開」を押します。これでFirebase上での作業は完了です。
必要であれば、google-services.json
とGoogleService-Info.plist
をダウンロードしてプロジェクトに配置しておきます。
具体的な手順はこちらを参照してください。
Flutterソースコード
labels.txtの作成とpubspeck.yamlの記載
まず、今回のカスタムモデルが使用するlabels.txt
を作成します。
ここを参考に、プロジェクト配下にassets/label.txt
を作成しましょう。
ソースコードからアクセスできるようにpubspeck.yaml
に- assets/labels.txt
を記載しておきます。
pubspeck.yaml
の全量は以下のようになります。
name: mlkit_sample
description: mlkit_sample application with custom model.
version: 1.0.0+1
environment:
sdk: ">=2.0.0-dev.68.0 <3.0.0"
dependencies:
flutter:
sdk: flutter
cupertino_icons: ^0.1.2
image_picker:
mlkit:
image:
dev_dependencies:
flutter_test:
sdk: flutter
flutter:
uses-material-design: true
assets:
- assets/labels.txt
lib以下のソースコード
以下の2ファイルを作成します。
ソースコードは、こちらからお借りしたものをもとに編集しました。
※長くなったので折りたたみました。
lib/main.dart
import 'package:mlkit_sample/ml_detail.dart';
import 'package:flutter/material.dart';
import 'package:image_picker/image_picker.dart';
import 'dart:io';
const String TEXT_SCANNER = 'TEXT_SCANNER';
const String BARCODE_SCANNER = 'BARCODE_SCANNER';
const String LABEL_SCANNER = 'LABEL_SCANNER';
const String FACE_SCANNER = 'FACE_SCANNER';
const String CUSTOM_MODEL = 'CUSTOM_MODEL';
void main() => runApp(new MyApp());
class MyApp extends StatelessWidget {
@override
Widget build(BuildContext context) {
return new MaterialApp(
debugShowCheckedModeBanner: false,
title: 'Flutter Demo',
theme: new ThemeData(
primarySwatch: Colors.blue,
),
home: new MLHome(),
);
}
}
class MLHome extends StatefulWidget {
MLHome({Key key}) : super(key: key);
@override
State<StatefulWidget> createState() => _MLHomeState();
}
class _MLHomeState extends State<MLHome> {
static const String CAMERA_SOURCE = 'CAMERA_SOURCE';
static const String GALLERY_SOURCE = 'GALLERY_SOURCE';
final GlobalKey<ScaffoldState> _scaffoldKey = GlobalKey<ScaffoldState>();
File _file;
String _selectedScanner = TEXT_SCANNER;
@override
Widget build(BuildContext context) {
final columns = List<Widget>();
columns.add(buildRowTitle(context, 'Select Scanner Type'));
columns.add(buildSelectScannerRowWidget(context));
columns.add(buildRowTitle(context, 'Pick Image'));
columns.add(buildSelectImageRowWidget(context));
return Scaffold(
key: _scaffoldKey,
appBar: AppBar(
centerTitle: true,
title: Text('MLKit Demo'),
),
body: SingleChildScrollView(
child: Column(
children: columns,
),
));
}
Widget buildRowTitle(BuildContext context, String title) {
return Center(
child: Padding(
padding: EdgeInsets.symmetric(horizontal: 8.0, vertical: 16.0),
child: Text(
title,
style: Theme.of(context).textTheme.headline,
),
));
}
Widget buildSelectImageRowWidget(BuildContext context) {
return Row(
children: <Widget>[
Expanded(
child: Padding(
padding: EdgeInsets.symmetric(horizontal: 8.0),
child: RaisedButton(
color: Colors.blue,
textColor: Colors.white,
splashColor: Colors.blueGrey,
onPressed: () {
onPickImageSelected(CAMERA_SOURCE);
},
child: const Text('Camera')),
)),
Expanded(
child: Padding(
padding: EdgeInsets.symmetric(horizontal: 8.0),
child: RaisedButton(
color: Colors.blue,
textColor: Colors.white,
splashColor: Colors.blueGrey,
onPressed: () {
onPickImageSelected(GALLERY_SOURCE);
},
child: const Text('Gallery')),
))
],
);
}
Widget buildSelectScannerRowWidget(BuildContext context) {
return Wrap(
children: <Widget>[
RadioListTile<String>(
title: Text('Text Recognition'),
groupValue: _selectedScanner,
value: TEXT_SCANNER,
onChanged: onScannerSelected,
),
RadioListTile<String>(
title: Text('Barcode Scanner'),
groupValue: _selectedScanner,
value: BARCODE_SCANNER,
onChanged: onScannerSelected,
),
RadioListTile<String>(
title: Text('Label Scanner'),
groupValue: _selectedScanner,
value: LABEL_SCANNER,
onChanged: onScannerSelected,
),
RadioListTile<String>(
title: Text('Face Scanner'),
groupValue: _selectedScanner,
value: FACE_SCANNER,
onChanged: onScannerSelected,
),
RadioListTile<String>(
title: Text('Custom Model'),
groupValue: _selectedScanner,
value: CUSTOM_MODEL,
onChanged: onScannerSelected,
)
],
);
}
Widget buildImageRow(BuildContext context, File file) {
return SizedBox(
height: 500.0,
child: Image.file(
file,
fit: BoxFit.fitWidth,
));
}
Widget buildDeleteRow(BuildContext context) {
return Center(
child: Padding(
padding: EdgeInsets.symmetric(horizontal: 8.0, vertical: 8.0),
child: RaisedButton(
color: Colors.red,
textColor: Colors.white,
splashColor: Colors.blueGrey,
onPressed: () {
setState(() {
_file = null;
});
;
},
child: const Text('Delete Image')),
),
);
}
void onScannerSelected(String scanner) {
setState(() {
_selectedScanner = scanner;
});
}
void onPickImageSelected(String source) async {
var imageSource;
if (source == CAMERA_SOURCE) {
imageSource = ImageSource.camera;
} else {
imageSource = ImageSource.gallery;
}
final scaffold = _scaffoldKey.currentState;
try {
final file = await ImagePicker.pickImage(source: imageSource);
if (file == null) {
throw Exception('File is not available');
}
Navigator.push(
context,
new MaterialPageRoute(
builder: (context) => MLDetail(file, _selectedScanner)),
);
} catch (e) {
scaffold.showSnackBar(SnackBar(
content: Text(e.toString()),
));
}
}
}
lib/ml_detail.dart
import 'dart:typed_data';
import 'package:flutter/material.dart';
import 'package:flutter/services.dart';
import 'dart:io';
import 'dart:async';
import 'package:mlkit/mlkit.dart';
import 'package:mlkit_sample/main.dart';
import 'package:image/image.dart' as img;
class MLDetail extends StatefulWidget {
final File _file;
final String _scannerType;
MLDetail(this._file, this._scannerType);
@override
State<StatefulWidget> createState() {
return _MLDetailState();
}
}
class _MLDetailState extends State<MLDetail> {
FirebaseVisionTextDetector textDetector = FirebaseVisionTextDetector.instance;
FirebaseVisionBarcodeDetector barcodeDetector =
FirebaseVisionBarcodeDetector.instance;
FirebaseVisionLabelDetector labelDetector =
FirebaseVisionLabelDetector.instance;
FirebaseVisionFaceDetector faceDetector = FirebaseVisionFaceDetector.instance;
List<VisionText> _currentTextLabels = <VisionText>[];
List<VisionBarcode> _currentBarcodeLabels = <VisionBarcode>[];
List<VisionLabel> _currentLabelLabels = <VisionLabel>[];
List<VisionFace> _currentFaceLabels = <VisionFace>[];
List<_CustomLabelScore> _currentCustomLabels = <_CustomLabelScore>[];
// For firebase MLKit custom model
FirebaseModelInterpreter interpreter = FirebaseModelInterpreter.instance;
FirebaseModelManager manager = FirebaseModelManager.instance;
Stream sub;
StreamSubscription<dynamic> subscription;
@override
void initState() {
super.initState();
sub = new Stream.empty();
subscription = sub.listen((_) => _getImageSize)..onDone(analyzeLabels);
manager.registerCloudModelSource(
FirebaseCloudModelSource(modelName: "mobilenet_v1_224_quant"));
}
void analyzeLabels() async {
try {
var currentLabels;
if (widget._scannerType == TEXT_SCANNER) {
currentLabels = await textDetector.detectFromPath(widget._file.path);
if (this.mounted) {
setState(() {
_currentTextLabels = currentLabels;
});
}
} else if (widget._scannerType == BARCODE_SCANNER) {
currentLabels = await barcodeDetector.detectFromPath(widget._file.path);
if (this.mounted) {
setState(() {
_currentBarcodeLabels = currentLabels;
});
}
} else if (widget._scannerType == LABEL_SCANNER) {
currentLabels = await labelDetector.detectFromPath(widget._file.path);
if (this.mounted) {
setState(() {
_currentLabelLabels = currentLabels;
});
}
} else if (widget._scannerType == FACE_SCANNER) {
currentLabels = await faceDetector.detectFromPath(widget._file.path);
if (this.mounted) {
setState(() {
_currentFaceLabels = currentLabels;
});
}
} else if (widget._scannerType == CUSTOM_MODEL) {
var customLabels = (await rootBundle.loadString('assets/labels.txt')).split("\n");
var imageBytes = (await rootBundle.load(widget._file.path)).buffer;
img.Image image = img.decodeJpg(imageBytes.asUint8List());
image = img.copyResize(image, 224, 224);
var results = await interpreter.run(
"mobilenet_v1_224_quant",
FirebaseModelInputOutputOptions(
0,
FirebaseModelDataType.BYTE,
[1, 224, 224, 3],
0,
FirebaseModelDataType.BYTE,
[1, 1001]),
imageToByteList(image));
var labelScores = <_CustomLabelScore>[];
for (int i = 0; i < customLabels.length; i++) {
if (results[i] > 0) {
labelScores.add(_CustomLabelScore(customLabels[i], results[i]));
}
}
labelScores.sort((a, b) => b.score.compareTo(a.score));
print(labelScores[0].label + ", " + labelScores[1].label);
if (this.mounted) {
setState(() {
_currentCustomLabels = labelScores;
});
}
}
} catch (e) {
print("MyEx: " + e.toString());
}
}
// int model
Uint8List imageToByteList(img.Image image) {
var _inputSize = 224;
var convertedBytes = new Uint8List(1 * _inputSize * _inputSize * 3);
var buffer = new ByteData.view(convertedBytes.buffer);
int pixelIndex = 0;
for (var i = 0; i < _inputSize; i++) {
for (var j = 0; j < _inputSize; j++) {
var pixel = image.getPixel(i, j);
buffer.setUint8(pixelIndex, (pixel >> 16) & 0xFF);
pixelIndex++;
buffer.setUint8(pixelIndex, (pixel >> 8) & 0xFF);
pixelIndex++;
buffer.setUint8(pixelIndex, (pixel) & 0xFF);
pixelIndex++;
}
}
return convertedBytes;
}
@override
void dispose() {
// TODO: implement dispose
super.dispose();
subscription?.cancel();
}
@override
Widget build(BuildContext context) {
return Scaffold(
appBar: AppBar(
centerTitle: true,
title: Text(widget._scannerType),
),
body: Column(
children: <Widget>[
buildImage(context),
widget._scannerType == TEXT_SCANNER
? buildTextList(_currentTextLabels)
: widget._scannerType == BARCODE_SCANNER
? buildBarcodeList<VisionBarcode>(_currentBarcodeLabels)
: widget._scannerType == FACE_SCANNER
? buildBarcodeList<VisionFace>(_currentFaceLabels)
: widget._scannerType == LABEL_SCANNER
? buildBarcodeList<VisionLabel>(_currentLabelLabels)
: buildBarcodeList<_CustomLabelScore>(_currentCustomLabels)
],
));
}
Widget buildImage(BuildContext context) {
return Expanded(
flex: 2,
child: Container(
decoration: BoxDecoration(color: Colors.black),
child: Center(
child: widget._file == null
? Text('No Image')
: FutureBuilder<Size>(
future: _getImageSize(
Image.file(widget._file, fit: BoxFit.fitWidth)),
builder:
(BuildContext context, AsyncSnapshot<Size> snapshot) {
if (snapshot.hasData) {
return Container(
foregroundDecoration: (widget._scannerType ==
TEXT_SCANNER)
? TextDetectDecoration(
_currentTextLabels, snapshot.data)
: (widget._scannerType == FACE_SCANNER)
? FaceDetectDecoration(
_currentFaceLabels, snapshot.data)
: (widget._scannerType == BARCODE_SCANNER)
? BarcodeDetectDecoration(
_currentBarcodeLabels,
snapshot.data)
: LabelDetectDecoration(
_currentLabelLabels, snapshot.data),
child:
Image.file(widget._file, fit: BoxFit.fitWidth));
} else {
return CircularProgressIndicator();
}
},
),
)),
);
}
Widget buildBarcodeList<T>(List<T> barcodes) {
if (barcodes.length == 0) {
return Expanded(
flex: 1,
child: Center(
child: Text('Nothing detected',
style: Theme.of(context).textTheme.subhead),
),
);
}
return Expanded(
flex: 1,
child: Container(
child: ListView.builder(
padding: const EdgeInsets.all(1.0),
itemCount: barcodes.length,
itemBuilder: (context, i) {
var text;
final barcode = barcodes[i];
switch (widget._scannerType) {
case BARCODE_SCANNER:
VisionBarcode res = barcode as VisionBarcode;
text = "Raw Value: ${res.rawValue}";
break;
case FACE_SCANNER:
VisionFace res = barcode as VisionFace;
text =
"Raw Value: ${res.smilingProbability},${res.trackingID}";
break;
case LABEL_SCANNER:
VisionLabel res = barcode as VisionLabel;
text = "Raw Value: ${res.label}";
break;
case CUSTOM_MODEL:
_CustomLabelScore res = barcode as _CustomLabelScore;
text = "Raw Value: ${res.label}, Score: ${res.score}";
break;
}
return _buildTextRow(text);
}),
),
);
}
Widget buildTextList(List<VisionText> texts) {
if (texts.length == 0) {
return Expanded(
flex: 1,
child: Center(
child: Text('No text detected',
style: Theme.of(context).textTheme.subhead),
));
}
return Expanded(
flex: 1,
child: Container(
child: ListView.builder(
padding: const EdgeInsets.all(1.0),
itemCount: texts.length,
itemBuilder: (context, i) {
return _buildTextRow(texts[i].text);
}),
),
);
}
Widget _buildTextRow(text) {
return ListTile(
title: Text(
"$text",
),
dense: true,
);
}
Future<Size> _getImageSize(Image image) {
Completer<Size> completer = Completer<Size>();
image.image.resolve(ImageConfiguration()).addListener(
(ImageInfo info, bool _) => completer.complete(
Size(info.image.width.toDouble(), info.image.height.toDouble())));
return completer.future;
}
}
/*
This code uses the example from azihsoyn/flutter_mlkit
https://github.com/azihsoyn/flutter_mlkit/blob/master/example/lib/main.dart
*/
class BarcodeDetectDecoration extends Decoration {
final Size _originalImageSize;
final List<VisionBarcode> _barcodes;
BarcodeDetectDecoration(List<VisionBarcode> barcodes, Size originalImageSize)
: _barcodes = barcodes,
_originalImageSize = originalImageSize;
@override
BoxPainter createBoxPainter([VoidCallback onChanged]) {
return _BarcodeDetectPainter(_barcodes, _originalImageSize);
}
}
class _BarcodeDetectPainter extends BoxPainter {
final List<VisionBarcode> _barcodes;
final Size _originalImageSize;
_BarcodeDetectPainter(barcodes, originalImageSize)
: _barcodes = barcodes,
_originalImageSize = originalImageSize;
@override
void paint(Canvas canvas, Offset offset, ImageConfiguration configuration) {
final paint = Paint()
..strokeWidth = 2.0
..color = Colors.red
..style = PaintingStyle.stroke;
final _heightRatio = _originalImageSize.height / configuration.size.height;
final _widthRatio = _originalImageSize.width / configuration.size.width;
for (var barcode in _barcodes) {
final _rect = Rect.fromLTRB(
offset.dx + barcode.rect.left / _widthRatio,
offset.dy + barcode.rect.top / _heightRatio,
offset.dx + barcode.rect.right / _widthRatio,
offset.dy + barcode.rect.bottom / _heightRatio);
canvas.drawRect(_rect, paint);
}
canvas.restore();
}
}
class TextDetectDecoration extends Decoration {
final Size _originalImageSize;
final List<VisionText> _texts;
TextDetectDecoration(List<VisionText> texts, Size originalImageSize)
: _texts = texts,
_originalImageSize = originalImageSize;
@override
BoxPainter createBoxPainter([VoidCallback onChanged]) {
return _TextDetectPainter(_texts, _originalImageSize);
}
}
class _TextDetectPainter extends BoxPainter {
final List<VisionText> _texts;
final Size _originalImageSize;
_TextDetectPainter(texts, originalImageSize)
: _texts = texts,
_originalImageSize = originalImageSize;
@override
void paint(Canvas canvas, Offset offset, ImageConfiguration configuration) {
final paint = Paint()
..strokeWidth = 2.0
..color = Colors.red
..style = PaintingStyle.stroke;
final _heightRatio = _originalImageSize.height / configuration.size.height;
final _widthRatio = _originalImageSize.width / configuration.size.width;
for (var text in _texts) {
final _rect = Rect.fromLTRB(
offset.dx + text.rect.left / _widthRatio,
offset.dy + text.rect.top / _heightRatio,
offset.dx + text.rect.right / _widthRatio,
offset.dy + text.rect.bottom / _heightRatio);
canvas.drawRect(_rect, paint);
}
canvas.restore();
}
}
class FaceDetectDecoration extends Decoration {
final Size _originalImageSize;
final List<VisionFace> _faces;
FaceDetectDecoration(List<VisionFace> faces, Size originalImageSize)
: _faces = faces,
_originalImageSize = originalImageSize;
@override
BoxPainter createBoxPainter([VoidCallback onChanged]) {
return _FaceDetectPainter(_faces, _originalImageSize);
}
}
class _FaceDetectPainter extends BoxPainter {
final List<VisionFace> _faces;
final Size _originalImageSize;
_FaceDetectPainter(faces, originalImageSize)
: _faces = faces,
_originalImageSize = originalImageSize;
@override
void paint(Canvas canvas, Offset offset, ImageConfiguration configuration) {
final paint = Paint()
..strokeWidth = 2.0
..color = Colors.red
..style = PaintingStyle.stroke;
final _heightRatio = _originalImageSize.height / configuration.size.height;
final _widthRatio = _originalImageSize.width / configuration.size.width;
for (var face in _faces) {
final _rect = Rect.fromLTRB(
offset.dx + face.rect.left / _widthRatio,
offset.dy + face.rect.top / _heightRatio,
offset.dx + face.rect.right / _widthRatio,
offset.dy + face.rect.bottom / _heightRatio);
canvas.drawRect(_rect, paint);
}
canvas.restore();
}
}
class LabelDetectDecoration extends Decoration {
final Size _originalImageSize;
final List<VisionLabel> _labels;
LabelDetectDecoration(List<VisionLabel> labels, Size originalImageSize)
: _labels = labels,
_originalImageSize = originalImageSize;
@override
BoxPainter createBoxPainter([VoidCallback onChanged]) {
return _LabelDetectPainter(_labels, _originalImageSize);
}
}
class _LabelDetectPainter extends BoxPainter {
final List<VisionLabel> _labels;
final Size _originalImageSize;
_LabelDetectPainter(labels, originalImageSize)
: _labels = labels,
_originalImageSize = originalImageSize;
@override
void paint(Canvas canvas, Offset offset, ImageConfiguration configuration) {
final paint = Paint()
..strokeWidth = 2.0
..color = Colors.red
..style = PaintingStyle.stroke;
final _heightRatio = _originalImageSize.height / configuration.size.height;
final _widthRatio = _originalImageSize.width / configuration.size.width;
for (var label in _labels) {
final _rect = Rect.fromLTRB(
offset.dx + label.rect.left / _widthRatio,
offset.dy + label.rect.top / _heightRatio,
offset.dx + label.rect.right / _widthRatio,
offset.dy + label.rect.bottom / _heightRatio);
canvas.drawRect(_rect, paint);
}
canvas.restore();
}
}
class _CustomLabelScore {
String label;
int score;
_CustomLabelScore(this.label, this.score);
}
これをビルドすると、冒頭の動画のようなアプリケーションが動作しているはずです!
まとめ
Firebase MLKit経由でTensorflow Liteモデルをモバイル端末上で動作させることができました。
Firebase MLKitを使用すると、モデルのバージョン管理やリリースなどが簡単にできるようになる上、それがクロスプラットフォームで実現できるのでとても良いですね!
ただ、Tensorflow Lite自体の制約やFirebase MLKitがまだβであること、Flutter公式リポジトリがカスタムモデルに対応していないことを考えると、まだまだ一般利用には程遠い現状だと思います。
今後も引き続きウォッチしていこうと思っているので、なにか動きがあったらまた記事にしようと思います!