FlutterとTFLiteシリーズ第二弾!
今回はPoseNetを使った姿勢推定をやります!
ソースはこちら:https://github.com/bigface0202/pose_estimation_flutter
PoseNetとは
PoseNetは姿勢推定のモデルの1つです。姿勢推定はOpenPoseを皮切りに様々なモデルが存在しております。
今回用いたPoseNetはGoogleが作ったモデルになります。
内部構造に関する説明は他の記事を参考にして頂くとして、簡単にまとめれば画像を姿勢推定モデルに入力することで、画像内に写っている人物の関節17点を推定することができます。
使ったライブラリ
image_pickerを使って画像をギャラリーやカメラから取ってきて、tfliteを使って推論するという流れです。
姿勢推定モデルとディレクトリの配置
モデルはこちらからダウンロードしてください。
というのも、最初はPoseNetで公開されているモデルを使っていたのですが、出力の後処理部分がうまくいかず(画像の縮尺が合わない?)諦めました。
FlutterでNew Projectをした状態から、
assets
└── posenet_mv1_075_float_from_checkpoints.tflite.tflite
...
lib
├── image_input.dart
├── index_screen.dart
└── main.dart
...
assetsとlibの中身がこうなります。他はpubspec.yamlにライブラリ記述するくらいです。
コードの中身
main.dartとindex_screen.dart
import 'package:flutter/material.dart';
import 'index_screen.dart';
void main() {
runApp(MyApp());
}
class MyApp extends StatelessWidget {
// This widget is the root of your application.
@override
Widget build(BuildContext context) {
return MaterialApp(
title: 'SEE FOOD',
theme: ThemeData(
primarySwatch: Colors.blue,
primaryColor: Colors.black,
),
home: IndexScreen(),
);
}
}
import 'dart:io';
import "package:flutter/material.dart";
import "./image_input.dart";
class IndexScreen extends StatelessWidget {
File _pickedImage;
void _selectImage(File pickedImage) {
_pickedImage = pickedImage;
}
@override
Widget build(BuildContext context) {
return Scaffold(
appBar: AppBar(
title: Text('POSE ESTIMATION'),
),
body: ImageInput(_selectImage),
);
}
}
この2つは特に特筆することはありません。
image_input.dart
import 'dart:io';
import 'dart:ui' as ui;
import 'package:flutter/material.dart';
import 'package:flutter/services.dart';
import 'package:image_picker/image_picker.dart';
import 'package:tflite/tflite.dart';
class ImageInput extends StatefulWidget {
final Function onSelectImage;
ImageInput(this.onSelectImage);
@override
_ImageInputState createState() => _ImageInputState();
}
class _ImageInputState extends State<ImageInput> {
// File _storedImage;
final picker = ImagePicker();
bool loading = true;
Map<int, dynamic> keyPoints;
ui.Image image;
Future<void> _takePicture() async {
setState(() {
loading = true;
});
final imageFile = await picker.getImage(
source: ImageSource.camera,
);
if (imageFile == null) {
return;
}
poseEstimation(File(imageFile.path));
}
Future<void> _getImageFromGallery() async {
setState(() {
loading = true;
});
final imageFile = await picker.getImage(
source: ImageSource.gallery,
);
if (imageFile == null) {
return;
}
poseEstimation(File(imageFile.path));
}
static Future loadModel() async {
Tflite.close();
try {
await Tflite.loadModel(
model: 'assets/posenet_mv1_075_float_from_checkpoints.tflite',
);
} on PlatformException {
print("Failed to load the model");
}
}
Future poseEstimation(File imageFile) async {
final imageByte = await imageFile.readAsBytes();
image = await decodeImageFromList(imageByte);
// Prediction
List recognition = await Tflite.runPoseNetOnImage(
path: imageFile.path,
imageMean: 125.0, // defaults to 117.0
imageStd: 125.0, // defaults to 1.0
numResults: 2, // defaults to 5
threshold: 0.7, // defaults to 0.1
nmsRadius: 10,
asynch: true,
);
// Extract keypoints from recognition
if (recognition.length > 0) {
setState(() {
keyPoints = new Map<int, dynamic>.from(recognition[0]['keypoints']);
});
} else {
keyPoints = {};
}
setState(() {
loading = false;
});
}
@override
void initState() {
super.initState();
loadModel().then((val) {
setState(() {});
});
}
@override
Widget build(BuildContext context) {
return SingleChildScrollView(
child: Container(
padding: EdgeInsets.all(10),
child: Column(
children: [
loading
? Container(
width: 380,
height: 500,
alignment: Alignment.center,
decoration: BoxDecoration(
border: Border.all(width: 1, color: Colors.grey),
),
child: Text(
'No Image Taken',
textAlign: TextAlign.center,
),
)
: FittedBox(
child: SizedBox(
width: image.width.toDouble(),
height: image.height.toDouble(),
child: CustomPaint(
painter: CirclePainter(keyPoints, image),
),
),
),
Row(
mainAxisAlignment: MainAxisAlignment.spaceAround,
children: [
Expanded(
child: FlatButton.icon(
icon: Icon(Icons.photo_camera),
label: Text('カメラ'),
textColor: Theme.of(context).primaryColor,
onPressed: _takePicture,
),
),
Expanded(
child: FlatButton.icon(
icon: Icon(Icons.photo_library),
label: Text('ギャラリー'),
textColor: Theme.of(context).primaryColor,
onPressed: _getImageFromGallery,
),
),
],
),
],
),
),
);
}
}
class CirclePainter extends CustomPainter {
final Map params;
final ui.Image image;
CirclePainter(this.params, this.image);
@override
void paint(ui.Canvas canvas, Size size) {
final paint = Paint();
if (image != null) {
canvas.drawImage(image, Offset(0, 0), paint);
}
paint.color = Colors.red;
if (params.isNotEmpty) {
params.forEach((index, param) {
canvas.drawCircle(
Offset(size.width * param['x'], size.height * param['y']),
10,
paint);
});
print("Done!");
}
}
@override
bool shouldRepaint(covariant CirclePainter oldDelegate) => false;
// image != oldDelegate.image || params != oldDelegate.params;
}
こちらのコードで画像の取り出し・姿勢の推定を行っております。順を追って説明していきましょう。
モデルの読み込み
まずは、姿勢推定のモデルをロードします。
static Future loadModel() async {
Tflite.close();
try {
await Tflite.loadModel(
model: 'assets/posenet_mv1_075_float_from_checkpoints.tflite',
);
} on PlatformException {
print("Failed to load the model");
}
}
こちらはinitState部分に記述することでアプリを開いたタイミングでモデルをロードできるようにします。
@override
void initState() {
super.initState();
loadModel().then((val) {
setState(() {});
});
}
画像の取り出し
次に、image_pickerを使って、_takePicture
と_getImageFromGallery
で画像をカメラ、もしくはギャラリーから取ってきています。
Future<void> _takePicture() async {
setState(() {
loading = true;
});
final imageFile = await picker.getImage(
source: ImageSource.camera,
);
if (imageFile == null) {
return;
}
poseEstimation(File(imageFile.path));
}
Future<void> _getImageFromGallery() async {
setState(() {
loading = true;
});
final imageFile = await picker.getImage(
source: ImageSource.gallery,
);
if (imageFile == null) {
return;
}
poseEstimation(File(imageFile.path));
}
カメラorギャラリーの違いはsource
の違いで、使い方は以下のサイトがわかりやすいと思います。
【Flutter】【Dart】Image Pickerで画像を選択する
取り出した画像はFile
としてposeEstimation
に渡します。
姿勢推定
Future poseEstimation(File imageFile) async {
final imageByte = await imageFile.readAsBytes();
image = await decodeImageFromList(imageByte);
// Prediction
List recognition = await Tflite.runPoseNetOnImage(
path: imageFile.path,
imageMean: 125.0, // defaults to 117.0
imageStd: 125.0, // defaults to 1.0
numResults: 2, // defaults to 5
threshold: 0.7, // defaults to 0.1
nmsRadius: 10,
asynch: true,
);
// Extract keypoints from recognition
if (recognition.length > 0) {
setState(() {
keyPoints = new Map<int, dynamic>.from(recognition[0]['keypoints']);
});
} else {
keyPoints = {};
}
setState(() {
loading = false;
});
}
TFliteには標準でrunPoseNetOnImage
が搭載されておりますが、事前にモデルをロードしていないとエラーになります。
推論は画像のファイルパスから、画像のバイナリデータから、フレームからなどを選択できますが、今回はファイルパスから推論を行うようにしています。画像の画素平均値や標準偏差等のパラメータは初期値のままです。ちなみにnumResultsを設定することで出力できる数、今回で言えば姿勢推定をする人数を設定することができます。
推論後、出力結果をkeyPoints
に格納します。
推論結果の描画
loading
? Container(
width: 380,
height: 500,
alignment: Alignment.center,
decoration: BoxDecoration(
border: Border.all(width: 1, color: Colors.grey),
),
child: Text(
'No Image Taken',
textAlign: TextAlign.center,
),
)
: FittedBox(
child: SizedBox(
width: image.width.toDouble(),
height: image.height.toDouble(),
child: CustomPaint(
painter: CirclePainter(keyPoints, image),
),
),
),
推論が終わったかどうかの判断をloading
に委ねています。
推論が終わり次第、CustomPaint
を使って画像の描画を行います。
CustomPaint
自体は大きさを持たないので、Container
やSizedBox
で大きさを指定してあげる必要があります。
class CirclePainter extends CustomPainter {
final Map params;
final ui.Image image;
CirclePainter(this.params, this.image);
@override
void paint(ui.Canvas canvas, Size size) {
final paint = Paint();
if (image != null) {
canvas.drawImage(image, Offset(0, 0), paint);
}
paint.color = Colors.red;
if (params.isNotEmpty) {
params.forEach((index, param) {
canvas.drawCircle(
Offset(size.width * param['x'], size.height * param['y']),
10,
paint);
});
print("Done!");
}
}
@override
bool shouldRepaint(covariant CirclePainter oldDelegate) => false;
}
推論に用いた画像に対して、関節部位に赤丸を付与するためにCustomPainter
を用いました。
paint()
を用意してcannvas
にcanvas.drawImage
で画像を描画後、その上からcanvas.drawCircle
で関節位置に赤丸を付与します。
推論後に得られるxやyは正規化された値になっているので、画像のサイズをかけ合わせてあげる必要があります。
shouldRepaint
は一度描画した後、再描画する必要がある場合(お絵描きアプリなど)は設定する必要がありますが、今回の場合は再描画することはありませんね。動画使う場合は必要になるのでしょうかね、そこらへんはまた検証が必要そうです。
完成品
いい感じですね。姿勢を推定できる写真では関節の描画を、推定できない写真(ホットドッグ)では描画しないようにできています。!
2枚目のクリスティアーノ・ロナウドの写真では後ろの人の足を関節として捉えちゃっていますね。
複数人数の場合は関節と関節のつなぎ合わせ部分で整合性とれるように調整するのですが、そこらへんTFLiteのモデルだとどうなっているんでしょうか…
また詳しく見てみようと思います。