9
4

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

FlutterとTFLiteでPoseEstimation

Last updated at Posted at 2020-12-09

FlutterとTFLiteシリーズ第二弾!
今回はPoseNetを使った姿勢推定をやります!
ソースはこちら:https://github.com/bigface0202/pose_estimation_flutter

PoseNetとは

PoseNetは姿勢推定のモデルの1つです。姿勢推定はOpenPoseを皮切りに様々なモデルが存在しております。
今回用いたPoseNetはGoogleが作ったモデルになります。
内部構造に関する説明は他の記事を参考にして頂くとして、簡単にまとめれば画像を姿勢推定モデルに入力することで、画像内に写っている人物の関節17点を推定することができます。

使ったライブラリ

image_pickerを使って画像をギャラリーやカメラから取ってきて、tfliteを使って推論するという流れです。

姿勢推定モデルとディレクトリの配置

モデルはこちらからダウンロードしてください。
というのも、最初はPoseNetで公開されているモデルを使っていたのですが、出力の後処理部分がうまくいかず(画像の縮尺が合わない?)諦めました。

FlutterでNew Projectをした状態から、

assets
└── posenet_mv1_075_float_from_checkpoints.tflite.tflite
...
lib
├── image_input.dart
├── index_screen.dart
└── main.dart
...

assetsとlibの中身がこうなります。他はpubspec.yamlにライブラリ記述するくらいです。

コードの中身

main.dartとindex_screen.dart

main.dart
import 'package:flutter/material.dart';
import 'index_screen.dart';

void main() {
  runApp(MyApp());
}

class MyApp extends StatelessWidget {
  // This widget is the root of your application.
  @override
  Widget build(BuildContext context) {
    return MaterialApp(
      title: 'SEE FOOD',
      theme: ThemeData(
        primarySwatch: Colors.blue,
        primaryColor: Colors.black,
      ),
      home: IndexScreen(),
    );
  }
}
index_screen.dart
import 'dart:io';

import "package:flutter/material.dart";

import "./image_input.dart";

class IndexScreen extends StatelessWidget {
  File _pickedImage;

  void _selectImage(File pickedImage) {
    _pickedImage = pickedImage;
  }

  @override
  Widget build(BuildContext context) {
    return Scaffold(
      appBar: AppBar(
        title: Text('POSE ESTIMATION'),
      ),
      body: ImageInput(_selectImage),
    );
  }
}

この2つは特に特筆することはありません。

image_input.dart

image_input.dart
import 'dart:io';
import 'dart:ui' as ui;

import 'package:flutter/material.dart';
import 'package:flutter/services.dart';
import 'package:image_picker/image_picker.dart';
import 'package:tflite/tflite.dart';

class ImageInput extends StatefulWidget {
  final Function onSelectImage;

  ImageInput(this.onSelectImage);

  @override
  _ImageInputState createState() => _ImageInputState();
}

class _ImageInputState extends State<ImageInput> {
  // File _storedImage;
  final picker = ImagePicker();
  bool loading = true;
  Map<int, dynamic> keyPoints;
  ui.Image image;

  Future<void> _takePicture() async {
    setState(() {
      loading = true;
    });
    final imageFile = await picker.getImage(
      source: ImageSource.camera,
    );
    if (imageFile == null) {
      return;
    }
    poseEstimation(File(imageFile.path));
  }

  Future<void> _getImageFromGallery() async {
    setState(() {
      loading = true;
    });
    final imageFile = await picker.getImage(
      source: ImageSource.gallery,
    );
    if (imageFile == null) {
      return;
    }
    poseEstimation(File(imageFile.path));
  }

  static Future loadModel() async {
    Tflite.close();
    try {
      await Tflite.loadModel(
        model: 'assets/posenet_mv1_075_float_from_checkpoints.tflite',
      );
    } on PlatformException {
      print("Failed to load the model");
    }
  }

  Future poseEstimation(File imageFile) async {
    final imageByte = await imageFile.readAsBytes();
    image = await decodeImageFromList(imageByte);
    // Prediction
    List recognition = await Tflite.runPoseNetOnImage(
      path: imageFile.path,
      imageMean: 125.0, // defaults to 117.0
      imageStd: 125.0, // defaults to 1.0
      numResults: 2, // defaults to 5
      threshold: 0.7, // defaults to 0.1
      nmsRadius: 10,
      asynch: true,
    );
    // Extract keypoints from recognition
    if (recognition.length > 0) {
      setState(() {
        keyPoints = new Map<int, dynamic>.from(recognition[0]['keypoints']);
      });
    } else {
      keyPoints = {};
    }
    setState(() {
      loading = false;
    });
  }

  @override
  void initState() {
    super.initState();
    loadModel().then((val) {
      setState(() {});
    });
  }

  @override
  Widget build(BuildContext context) {
    return SingleChildScrollView(
      child: Container(
        padding: EdgeInsets.all(10),
        child: Column(
          children: [
            loading
                ? Container(
                    width: 380,
                    height: 500,
                    alignment: Alignment.center,
                    decoration: BoxDecoration(
                      border: Border.all(width: 1, color: Colors.grey),
                    ),
                    child: Text(
                      'No Image Taken',
                      textAlign: TextAlign.center,
                    ),
                  )
                : FittedBox(
                    child: SizedBox(
                      width: image.width.toDouble(),
                      height: image.height.toDouble(),
                      child: CustomPaint(
                        painter: CirclePainter(keyPoints, image),
                      ),
                    ),
                  ),
            Row(
              mainAxisAlignment: MainAxisAlignment.spaceAround,
              children: [
                Expanded(
                  child: FlatButton.icon(
                    icon: Icon(Icons.photo_camera),
                    label: Text('カメラ'),
                    textColor: Theme.of(context).primaryColor,
                    onPressed: _takePicture,
                  ),
                ),
                Expanded(
                  child: FlatButton.icon(
                    icon: Icon(Icons.photo_library),
                    label: Text('ギャラリー'),
                    textColor: Theme.of(context).primaryColor,
                    onPressed: _getImageFromGallery,
                  ),
                ),
              ],
            ),
          ],
        ),
      ),
    );
  }
}

class CirclePainter extends CustomPainter {
  final Map params;
  final ui.Image image;
  CirclePainter(this.params, this.image);

  @override
  void paint(ui.Canvas canvas, Size size) {
    final paint = Paint();
    if (image != null) {
      canvas.drawImage(image, Offset(0, 0), paint);
    }
    paint.color = Colors.red;
    if (params.isNotEmpty) {
      params.forEach((index, param) {
        canvas.drawCircle(
            Offset(size.width * param['x'], size.height * param['y']),
            10,
            paint);
      });
      print("Done!");
    }
  }

  @override
  bool shouldRepaint(covariant CirclePainter oldDelegate) => false;
  // image != oldDelegate.image || params != oldDelegate.params;
}

こちらのコードで画像の取り出し・姿勢の推定を行っております。順を追って説明していきましょう。

モデルの読み込み

まずは、姿勢推定のモデルをロードします。

モデルの読み込み
static Future loadModel() async {
    Tflite.close();
    try {
      await Tflite.loadModel(
        model: 'assets/posenet_mv1_075_float_from_checkpoints.tflite',
      );
    } on PlatformException {
      print("Failed to load the model");
    }
  }

こちらはinitState部分に記述することでアプリを開いたタイミングでモデルをロードできるようにします。

initState
@override
  void initState() {
    super.initState();
    loadModel().then((val) {
      setState(() {});
    });
  }

画像の取り出し

次に、image_pickerを使って、_takePicture_getImageFromGalleryで画像をカメラ、もしくはギャラリーから取ってきています。

画像取り出し部分
Future<void> _takePicture() async {
    setState(() {
      loading = true;
    });
    final imageFile = await picker.getImage(
      source: ImageSource.camera,
    );
    if (imageFile == null) {
      return;
    }
    poseEstimation(File(imageFile.path));
  }

  Future<void> _getImageFromGallery() async {
    setState(() {
      loading = true;
    });
    final imageFile = await picker.getImage(
      source: ImageSource.gallery,
    );
    if (imageFile == null) {
      return;
    }
    poseEstimation(File(imageFile.path));
  }

カメラorギャラリーの違いはsourceの違いで、使い方は以下のサイトがわかりやすいと思います。
【Flutter】【Dart】Image Pickerで画像を選択する
取り出した画像はFileとしてposeEstimationに渡します。

姿勢推定

poseEstimation
 Future poseEstimation(File imageFile) async {
    final imageByte = await imageFile.readAsBytes();
    image = await decodeImageFromList(imageByte);
    // Prediction
    List recognition = await Tflite.runPoseNetOnImage(
      path: imageFile.path,
      imageMean: 125.0, // defaults to 117.0
      imageStd: 125.0, // defaults to 1.0
      numResults: 2, // defaults to 5
      threshold: 0.7, // defaults to 0.1
      nmsRadius: 10,
      asynch: true,
    );
    // Extract keypoints from recognition
    if (recognition.length > 0) {
      setState(() {
        keyPoints = new Map<int, dynamic>.from(recognition[0]['keypoints']);
      });
    } else {
      keyPoints = {};
    }
    setState(() {
      loading = false;
    });
  }

TFliteには標準でrunPoseNetOnImageが搭載されておりますが、事前にモデルをロードしていないとエラーになります。
推論は画像のファイルパスから、画像のバイナリデータから、フレームからなどを選択できますが、今回はファイルパスから推論を行うようにしています。画像の画素平均値や標準偏差等のパラメータは初期値のままです。ちなみにnumResultsを設定することで出力できる数、今回で言えば姿勢推定をする人数を設定することができます。
推論後、出力結果をkeyPointsに格納します。

推論結果の描画

描画部分
loading
    ? Container(
        width: 380,
        height: 500,
        alignment: Alignment.center,
        decoration: BoxDecoration(
          border: Border.all(width: 1, color: Colors.grey),
        ),
        child: Text(
          'No Image Taken',
          textAlign: TextAlign.center,
        ),
      )
    : FittedBox(
        child: SizedBox(
          width: image.width.toDouble(),
          height: image.height.toDouble(),
          child: CustomPaint(
            painter: CirclePainter(keyPoints, image),
          ),
        ),
      ),

推論が終わったかどうかの判断をloadingに委ねています。
推論が終わり次第、CustomPaintを使って画像の描画を行います。
CustomPaint自体は大きさを持たないので、ContainerSizedBoxで大きさを指定してあげる必要があります。

CustomPainter
class CirclePainter extends CustomPainter {
  final Map params;
  final ui.Image image;
  CirclePainter(this.params, this.image);

  @override
  void paint(ui.Canvas canvas, Size size) {
    final paint = Paint();
    if (image != null) {
      canvas.drawImage(image, Offset(0, 0), paint);
    }
    paint.color = Colors.red;
    if (params.isNotEmpty) {
      params.forEach((index, param) {
        canvas.drawCircle(
            Offset(size.width * param['x'], size.height * param['y']),
            10,
            paint);
      });
      print("Done!");
    }
  }

  @override
  bool shouldRepaint(covariant CirclePainter oldDelegate) => false;
}

推論に用いた画像に対して、関節部位に赤丸を付与するためにCustomPainterを用いました。
paint()を用意してcannvascanvas.drawImageで画像を描画後、その上からcanvas.drawCircleで関節位置に赤丸を付与します。
推論後に得られるxやyは正規化された値になっているので、画像のサイズをかけ合わせてあげる必要があります。
shouldRepaintは一度描画した後、再描画する必要がある場合(お絵描きアプリなど)は設定する必要がありますが、今回の場合は再描画することはありませんね。動画使う場合は必要になるのでしょうかね、そこらへんはまた検証が必要そうです。

完成品

PoseEstimation_final.gif
いい感じですね。姿勢を推定できる写真では関節の描画を、推定できない写真(ホットドッグ)では描画しないようにできています。!
2枚目のクリスティアーノ・ロナウドの写真では後ろの人の足を関節として捉えちゃっていますね。
複数人数の場合は関節と関節のつなぎ合わせ部分で整合性とれるように調整するのですが、そこらへんTFLiteのモデルだとどうなっているんでしょうか…
また詳しく見てみようと思います。

9
4
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
9
4

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?