LoginSignup
21
5

More than 1 year has passed since last update.

Flutter Webで画像分類を行う(AutoML Vision, TensorFlow.js)

Last updated at Posted at 2021-12-01

はじめに

スプラトゥーン2のプレイヤー向けに、シーン抽出機能付きの動画プレイヤーアプリをFlutter Webで作成しました。プレイの録画ファイルをブラウザにドラッグアンドドロップすると、試合の開始と終了、試合中のたおしたシーンとやられたシーンを頭出しします。

使ってみたい方はこちら。

画像の引用

この記事では任天堂株式会社のゲーム、スプラトゥーン2のスクリーンショットを引用しています。

頭出しの方法

各シーンの頭出しは、動画からフレーム画像を抽出して、Google AutoML Visionで作成した学習モデルを使って予測することで行っています。学習モデルはフレーム画像を以下の5クラスに分類するものです。

 試合の開始 試合の終了
たおした  やられた その他

学習モデルの作成

学習モデルは以下の枚数のラベル付けしたフレーム画像をAutoML Visionにインポートして作成しています。

シーン 枚数
試合の開始 5537
試合の終了 23494
たおした 1810
やられた 1532
その他 19950
全体 52323

AutoML Visionの画面はこのようになります。

スクリーンショット 2021-11-28 2.05.47.png

ラベル付け作業について

約5万枚の画像を手作業で分類するのは大変ですが、Google Cloud Vision APIのテキスト検出を使用して一部自動化しています。その様子は、やられたシーンだけですが、こちらの記事で紹介しています。

スプラトゥーン2のプレイ動画から、やられたシーンだけをディープラーニングで自動抽出する

また、手作業をやりやすくするために、Flutter Desktopで簡易的なアノテーションツールも作成していて、こちらの記事で紹介しています。

Flutter Desktopで簡易アノテーションツールを作る

学習モデルのエクスポート

今回はブラウザ内で予測したかったので、TensorFlow.js形式でモデルを出力しました。

スクリーンショット 2021-11-28 2.32.06.png

ブラウザのJavaScriptからAutoML Visionで作成したモデルを使って、画像分類を行う

ブラウザのJavaScriptからAutoML Visionで作成したモデルを使って、画像分類を行う方法は、公式で紹介されています。

Edge TensorFlow.js チュートリアル

しかし今回はFlutter Webで作成したアプリなので、こちらで紹介されているJavaScriptのコードをFlutterから呼び出す方法を探す必要があります。

pub.devからパッケージを探す

2021年3月4日に行われたFlutter Engageモバイルアプリからウェブアプリへセッションでは、既存のJavaScriptライブラリにはプラグインからアクセスできて、pub.devでは対応プラットフォームがラベルとして表示されるとありました。しかし探してみたところ、Webラベルが付いているTensorFlowのプラグインは見つけられませんでした。

スクリーンショット 2021-11-28 3.23.19.png

FlutterからJavaScriptメソッドを直接呼ぶ

プラグインはないので公式で紹介されているTensorFlow.jsを使うJavaScript関数を直接呼ぶことにします。

学習モデルとjsファイルの設置

プロジェクト直下のwebディレクトリに学習モデルと、TensorFlow.jsを使うjsファイルであるpredict.jsを設置します。学習モデルはmodelディレクトリにまとめました。

スクリーンショット 2021-11-28 3.30.24.png

jsファイルの取り込み

web/index.htmlファイルの所定の位置にjsファイルを取り込むためのscriptタグを挿入します。

index.html
<!DOCTYPE html>
<html>

<head>
  <!-- 略 -->
</head>

<body>
  <!-- 以下3行を挿入する -->
  <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@2.8.4/dist/tf.min.js"></script>
  <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs-automl@1.0.0/dist/tf-automl.min.js"></script>
  <script src="predict.js"></script>
  <!-- This script installs service_worker.js to provide PWA functionality to
       application. For more information, see:
       https://developers.google.com/web/fundamentals/primers/service-workers -->
  <script>
    if ('serviceWorker' in navigator) {
      window.addEventListener('flutter-first-frame', function () {
        navigator.serviceWorker.register('flutter_service_worker.js');
      });
    }
  </script>
  <script src="main.dart.js" type="application/javascript"></script>
  <!-- 略 -->
</body>

</html>

jsファイルの中身

モデルの読み込みは1回やればよく、画像の分類は動画の0.5秒おき等たくさん行うので、別の関数にしました。

predict.js
// 分類モデル
var iKutModel = null;
// モデル読込済みフラグ
var iKutLoaded = false;

// モデルを読み込む
async function loadImageClassification() {
    if (!iKutLoaded) {
        // 試合の開始モデル読込
        iKutModel = await tf.automl.loadImageClassification('model/model.json');
        // 読込完了フラグ
        iKutLoaded = true;
    }
    return 0;
}


// 分類する
async function classify(image) {
    const result = await iKutModel.classify(image);
    return JSON.stringify(result, null, 2);
}

jsパッケージを使い、JavaScript関数を直接呼び出す

Dartはjsパッケージを使うことで、JavaScript関数を呼び出すことができます。

参考 JavaScript interoperability

API Referenceを参考に、dartファイルで関数とアノテーションを設定しました。
asyncキーワードで宣言された非同期関数はPromiseを返却しますが、ここでは戻り値をObjectにします。その理由は追って説明します。
ImageElementはHTMLのimgタグになります。HTMLタグはdart:htmlライブラリに定義されています。

predict.dart
@JS()
library predict;

import 'dart:html';

import 'package:js/js.dart';

@JS('loadImageClassification')
external Object loadImageClassification();

@JS('classify')
external Object classify(ImageElement img);

定義した関数はそのまま呼ぶことが可能です。JavaScriptのPromiseはjs_utilライブラリpromiseToFuture関数Futureに変換できます。戻り値をObjectにした理由は、promiseToFuture関数の引数がObjectだからです。

await promiseToFuture(loadImageClassification());
String result = await promiseToFuture(classify(imageElement));

JavaScript側のnullに注意

Dartにはnull safetyの仕組みがありますが、Dart側でnon-nullableになっている変数にJavaScriptから返却されたnullを代入すると、実行時にUncaught Errorになってしまいます。ご注意ください。

JavaScript関数呼び出しの前後について

動画のフレーム画像をimg要素にする

画像分類の入力はHTMLのimg要素なので、動画のフレームをimg要素にします。

まずvideo要素とcanvas要素を作成します。canvas要素のサイズは学習モデルの入力層の形に合わせました。

final VideoElement _videoElement = VideoElement();
final CanvasElement _canvasElement = CanvasElement(width: 224, height: 224);

このアプリではドラッグアンドドロップで動画ファイルを読み込みます。ドラッグアンドドロップはflutter_dropzoneプラグインのDropzoneViewで使うことができます。

DropzoneView(
  operation: DragOperation.copy,
  cursor: CursorType.Default,
  onLoaded: () {},
  onError: (ev) {},
  onLeave: () {},
  onHover: () {},
  onDrop: (ev) {
    // ファイルがドラッグアンドドロップされた
    // ファイルのオブジェクトURLを取得できる
    final file = ev as File;
    final url = Url.createObjectUrl(file);
  }
});

取得したオブジェクトURLをvideo要素に読み込ませます。

_videoElement.src = url;

動画の長さが確定するとloadedmetadataイベントが呼ばれます。そうすると次に再生時間を更新できるようになるので、欲しいフレームの再生時間に更新します。

_videoElement.addEventListener("loadedmetadata", (event) {
  // 再生時間を更新
  _videoElement.currentTime = 0.0;
});

次にtimeupdateイベントが呼ばれます。そこでvideo要素をcanvas要素のデータURLを経由してimg要素に変換します。

_videoElement.addEventListener("timeupdate", (event) {
  // video要素をcanvas要素に描き込む
  final context = _canvasElement.context2D;
  context.drawImageScaled(_videoElement, 0, 0, 224, 224);
  // canvas要素をデータURLに変換する
  final dataUrl = _canvasElement.toDataUrl();
  // img要素を作成する
  final imageElement = new ImageElement();
  // 画像の読み込み完了イベント
  imageElement.addEventListener("load", (event) async {
    // フレーム画像を分類する
    String result = await promiseToFuture(classify(imageElement));
    // 次節に続く
  });
  // データURLをimg要素に渡す。
  imageElement.src = dataUrl;
});

JSON文字列から分類結果を取得する

分類結果のJSON文字列はこのようになっています。各ラベルの確率をprobフィールドから得ることができます。probが一番大きいラベルを取得すれば良いです。

[
  {
    "label": "start",
    "prob": 0.012603357434272766
  },
  {
    "label": "other",
    "prob": 0.009589558467268944
  },
  {
    "label": "end",
    "prob": 0.010704034939408302
  },
  {
    "label": "death",
    "prob": 0.9580927491188049
  },
  {
    "label": "kill",
    "prob": 0.009010315872728825
  }
]

ラベル取得の処理はこちらになります。

String _getLabel(String jsonString) {
  // JSON文字列をデコードする
  List<dynamic> dynamicClasses = jsonDecode(jsonString);
  // label - prob のMapに変換する
  Map<String, double> classes = Map();
  dynamicClasses.forEach((e) {
    classes[e['label']] = e['prob'];
  });
  // Mapから値が一番大きいキーを得る
  String maxLabel = "";
  double maxProb = -1.0;
  classes.forEach((label, prob) {
    if (prob > maxProb) {
      maxLabel = label;
      maxProb = prob;
    }
  });
  return maxLabel;
}

まとめ

  • プラグインのFlutter Webへの対応状況はpub.devのWebラベルの有無で確認できます。
  • jsパッケージを使うことによって、TensorFlow.jsのようなプラグインが提供されていないJavaScriptの処理も、Flutterから呼び出すことが可能です。
  • JavaScript側でPromiseになっているオブジェクトはjs_utilライブラリpromiseToFuture関数を使うことで、Futureのオブジェクトに変換することができます。
21
5
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
21
5