LoginSignup
6
0

More than 1 year has passed since last update.

テスト駆動でMediaPipeのCalculatorをつくる

Posted at

テスト駆動でMediaPipeのCalculatorをつくる

はじめに

テスト駆動でMediaPipeのCalculatorをつくる流れについて紹介します。

CalculatorというのはMediaPipeにおける処理の単位で、このCalculatorをノードとした有向グラフを記述することで処理の流れを記述することができます。(MediaPipeフレームワークの各概念に関する詳しい説明は公式ドキュメントの「Framework Concepts」を参照してください。)

MediaPipeで新しい機能やアプリケーションを作りたいとき、既存のCalculatorの組み合わせでは実現できないため新しくCalculatorをつくる必要がある場合があります。

こういった場合で、Calculatorをテスト駆動で新しく作っていく流れについて書いてきます。もし参考になれば幸いです。

対象読者としては下記を念頭に置いています。

  • MediaPipeのインストールができる
  • 基本的なC++のプログラムが書ける

注:ここでは、MediaPipeのバージョンはv0.8.8としています。MediaPipeはまだアルファ版で今後互換性のない変更が生じる可能性があるので、注意が必要です。

つくるもの

例として検出結果の数を分類ラベルごとにカウントするCalculatorをつくります。

なお、カウントがしたいだけでそのカウント結果をMediaPipeフレームワーク内の他の処理が使うわけではないなら、Calculatorとしてつくらずに、MediaPipeの外側で結果を受け取ってカウント処理を行うというやり方もあると思います。そのため、実際ではMediaPipeのCalculatorとしてつくるべきかどうかも一度検討するとより良いかもしれません。

ここでは、比較的シンプルに説明できるため、検出結果の数をカウントするCalculatorをつくっていきます。

(1) 準備

(1.1) MediaPipeをインストールする

公式ドキュメントの説明にしたがって、MediaPipeをインストールします。

以下ではv0.8.8リリースを解凍してインストールを進めたとしていますが、git clone や git submodule で同バージョンにしてから進めてもとくに問題ないはずです。その場合はMediaPipeのディレクトリ名をmediapipe-0.8.8からmediapipeと読み替えてください。

(1.2) 新しくつくるCalculatorを格納するディレクトリをつくる

分かりやすいように、別でディレクトリをつくって、新しくつくるCalculatorはそこに置くことにします。別でつくっておくもう一つの利点としてはMediaPipeのバージョンを上げやすくなることです。

インストールしたMediaPipeのディレクトリと同じ階層にextensionsというディレクトリをつくり、その配下にさらにcalculatorsディレクトリをつくります。名前は適宜変えてもかまいません。

その後、MediaPipeのディレクトリの配下にさっきつくったextensionsディレクトリへのシンボリックリンクをつくります。

最終的には下記のようなディレクトリ構造になっているはずです。

extensions
    calculators
mediapipe-0.8.8
    ... (※他は略)
    docs
    extensions  (※シンボリックリンク)
    mediapipe
    third_party
    ... (※他は略)

Calculatorとそのテストはextensions/calculatorsの配下につくっていきます。

(2) テストをつくる

(2.1) 入出力データの仕様を考える/つくる

まずはCalculatorの入力と出力を考えます。MediaPipeでは、Calculatorへの入出力データ仕様は Protocol Buffers で定義することができます。

いまは検出結果のカウントをするCalculatorをつくろうとしているので、入力は検出結果ということになるでしょう。MediaPipeですでにDetectionというものが定義されているので、入力データにはこれを使うことにします。

出力データは検出結果の各分類のカウント数が含まれるようにしましょう。こちらは新しく定義することにします。

extensions/calculators配下にdetections_counts.protoというファイルをつくります。そして、下記の内容を書きます。

detections_counts.proto
syntax = "proto2";

package mediapipe;

message DetectionsCounts {
    // 検出結果の分類別のカウント。ラベルテキストからカウント数へのマッピング。
    map<string, int64> counts = 1;
}

Protocol Buffers の言語仕様についてはここでは説明しません。詳しくは Protocol Buffers の言語ガイドを参照してください。

さらに、このprotoファイルのビルドを定義します。MediaPipeではビルドにBazelを使っています。そのため、Bazel用にBUILDファイルをつくって、そこにprotoファイルのビルドを記述します。具体的には次のようにつくります。

protoファイルと同じ階層にBUILDという名前のファイルをつくり、下記の内容を書きます。

BUILD
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library")

mediapipe_proto_library(
    name = "detections_counts_proto",
    srcs = ["detections_counts.proto"],
    visibility = ["//visibility:public"],
)

BUILDファイルが書けたら、新しくつくったprotoファイルをビルドしてみましょう。MediaPipeのディレクトリに移動して、下記のコマンドでビルドできます。

bazel build //extensions/calculators:detections_counts_proto

ビルドが成功するのを確認します。

(2.2) テストの実装

Calculatorのテストのソースファイルは、後ろに_testをつける慣習になっているようなので(例:top_k_scores_calculator_test.cc)、これにしたがいます。

今回つくるCalculatorはDetectionsToCountsCalculatorという名前にしましょう。そのため、そのテストはdetections_to_counts_calculator_test.ccという名前でファイルを作ります。

ここから、テストを実装していくわけですが、その前に、MediaPipeのCalculatorではどのようにテストが書かれているか、説明します。

MediaPipeのCalculatorでは、テストにgtest (Google Test) が使われているようです(includeしている箇所の例)。

CalculatorのテストをするためにはそのCalculatorだけが含まれたGraphが走らされます。ParseTextProtoOrDieでテキストを読み取ってCalculatorRunnerに渡されます。そして、CalculatorRunnerMutableInputs()から入力し、その後CalculatorRunnerを走らせ、Outputs()から出力を取り出しています。この流れの例としては、MediaPipeのソースコードのtop_k_scores_calculator_test.cc123-158行目(v0.8.8)あたりを参照してください。

似たような書き方で、こちらのdetections_to_counts_calculator_test.ccを実装しましょう。下記のように書きます。

detections_to_counts_calculator_test.cc
#include "absl/strings/str_format.h"
#include "mediapipe/framework/calculator_runner.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/port/status_matchers.h"
#include "mediapipe/framework/formats/detection.pb.h"
#include "extensions/calculators/detections_counts.pb.h"

namespace mediapipe {

namespace {

// テストデータをつくるためのヘルパー関数
Detection CreateTestDectionData(std::string label) {
  Detection detection;
  detection.add_label(label);
  LocationData* location_data = detection.mutable_location_data();
  location_data->mutable_bounding_box()->set_xmin(0.1);
  location_data->mutable_bounding_box()->set_width(0.2);
  location_data->mutable_bounding_box()->set_ymin(0.3);
  location_data->mutable_bounding_box()->set_height(0.4);
  return detection;
}

}

// テスト
TEST(DetectionsToCountsCalculatorTest, TestOnePacket) {

  // テストするCalculatorだけのGraphを読み込む
  CalculatorRunner runner(ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"(
    calculator: "DetectionsToCountsCalculator"
    input_stream: "DETECTIONS:detections"
    output_stream: "DETECTIONS_COUNTS:detections_counts"
  )"));

  // 入力データを用意する
  std::vector<Detection> input_stream_detections;
  for (int i=0; i<3; ++i) {
    // 同じ分類が1つのみの検出結果を入力データに含める
    Detection detection = CreateTestDectionData(absl::StrFormat("test_%d_counts_1",i));
    input_stream_detections.push_back(detection);
  }
  // 同じ分類が2つある検出結果も入力データに含める
  input_stream_detections.push_back(CreateTestDectionData("test_3_counts_2"));
  input_stream_detections.push_back(CreateTestDectionData("test_3_counts_2"));

  // 入力データをCalculatorRunnerに取り付ける
  runner.MutableInputs()->Tag("DETECTIONS").packets.push_back(
    MakePacket<std::vector<Detection>>(input_stream_detections).At(Timestamp(0)));

  // CalculatorRunnerを走らせる
  MP_ASSERT_OK(runner.Run());

  // 出力結果を取り出す
  const std::vector<Packet>& output_detections_counts_packets = 
    runner.Outputs().Tag("DETECTIONS_COUNTS").packets;
  ASSERT_EQ(1, output_detections_counts_packets.size());
  const DetectionsCounts& output_detections_counts = 
    output_detections_counts_packets[0].Get<DetectionsCounts>();

  // 出力結果が正しいか確認する
  EXPECT_EQ(output_detections_counts.counts().at("test_0_counts_1"), 1);
  EXPECT_EQ(output_detections_counts.counts().at("test_1_counts_1"), 1);
  EXPECT_EQ(output_detections_counts.counts().at("test_2_counts_1"), 1);
  EXPECT_EQ(output_detections_counts.counts().at("test_3_counts_2"), 2);

}

} // namespace mediapipe

細かい説明は省略します。

(3) 処理が未実装のCalculatorを用意する

次は処理が未実装のCalculatorを用意します。

detections_to_counts_calculator.ccという名前のファイルをつくり、下記の内容を書きます。

detections_to_counts_calculator.cc
#include "mediapipe/framework/calculator_framework.h"

namespace mediapipe {

// 検出結果の数を分類ラベルごとにカウントするCalculator。
// Example Config (設定の例):
//   node {
//     calculator: "DetectionsToCountsCalculator"
//     input_stream: "DETECTIONS:detections"
//     output_stream: "DETECTIONS_COUNTS:detections_counts"
//   }
class DetectionsToCountsCalculator : public CalculatorBase {

 public:
  DetectionsToCountsCalculator() {};
  ~DetectionsToCountsCalculator() override {};
  DetectionsToCountsCalculator(const DetectionsToCountsCalculator &) = delete;
  DetectionsToCountsCalculator& operator=(const DetectionsToCountsCalculator &) = delete;

  static ::mediapipe::Status GetContract(CalculatorContract* cc);
  ::mediapipe::Status Open(CalculatorContext* cc) override;
  ::mediapipe::Status Process(CalculatorContext* cc) override;

};

REGISTER_CALCULATOR(DetectionsToCountsCalculator);

::mediapipe::Status DetectionsToCountsCalculator::GetContract(CalculatorContract* cc) {
  // 処理はまだ未実装
  return ::mediapipe::OkStatus();
}

::mediapipe::Status DetectionsToCountsCalculator::Open(CalculatorContext* cc) {
  // 処理はまだ未実装
  return ::mediapipe::OkStatus();
}

::mediapipe::Status DetectionsToCountsCalculator::Process(CalculatorContext* cc) {
  // 処理はまだ未実装
  return ::mediapipe::OkStatus();
}

} // namespace mediapipe

Calculatorをつくるときは、CalculatorBaseというクラスを継承してつくります。実際どんな挙動をしてほしいかは必要に応じてGetContractOpenProcessCloseに処理を書きます。これらの意味はここでは詳しくは説明しないので、詳細は公式ドキュメントのCalculatorsの説明を参照してください。

いまは処理が未実装のCalculatorを用意したいので、まだ現時点では、何もしないでOkStatus()を返すだけになっています。

(4) BUILDファイルにCalculatorとテストのビルドターゲットを記述する

同じディレクトリ階層のBUILDファイルに、新しく書いたCalculatorとテストのソースコードをビルドするためのビルドターゲットを書き加えます。

まず、C++のビルドルールをロードするために、先頭に下記を書き加えます。

load("@rules_cc//cc:defs.bzl", "cc_library", "cc_test")

そして、末尾にCalculatorとテストのビルドターゲットを書き加えます。

cc_library(
    name = "detections_to_counts_calculator",
    srcs = ["detections_to_counts_calculator.cc"],
    visibility = ["//visibility:public"],
    deps = [
        "//mediapipe/framework:calculator_framework",
    ],
    alwayslink = 1,
)

cc_test(
    name = "detections_to_counts_calculator_test",
    size = "small",
    srcs = ["detections_to_counts_calculator_test.cc"],
    deps = [
        ":detections_to_counts_calculator",
        ":detections_counts_cc_proto",
        "//mediapipe/framework:calculator_runner",
        "//mediapipe/framework/port:gtest_main",
        "//mediapipe/framework/port:parse_text_proto",
        "//mediapipe/framework/formats:detection_cc_proto",
    ],
)

cc_librarycc_testに関してはBazel公式ドキュメントの C/C++ Rules ページを参照してください。

(5) テストが失敗するのを確認する

MediaPipeのディレクトリ(同じ階層にWORKSPACEファイルがある階層)に移動して、下記のコマンドでテストを実行します。

bazel run //extensions/calculators:detections_to_counts_calculator_test

まだ処理を実装していないので、テストは失敗するはずです。

ただし、ビルドはエラーが発生しないはずです。もしビルドのエラーが発生した場合は、エラー文を頼りに修正する必要があります。例えば、「includeしようとしているヘッダーファイルが見つからない」といったようなエラーのときは、C++のコードだけでなく、BUILDファイルも見てビルドターゲットの依存先を表すdepsの一覧が正しいか確認してみるといいかもしれません。

ビルドは失敗せず、テストが失敗したのを確認して、次へ進みます。

(6) Calculatorの処理の実装

いよいよCalculatorの処理を実装します。

detections_to_counts_calculator.ccファイルを開いて、さきほど「処理はまだ未実装」としていた箇所に処理を書いていきます。必要に応じて依存先が増える場合は、ソースコードのほうでincludeするだけでなく、BUILDファイル内のビルドターゲットのdepsの一覧に依存先を書き足す必要もあります。

処理を実装してdetections_to_counts_calculator.ccファイルの内容が下記になるようにします。(先頭あたりに新しく追加されたincludeconstexpr char等を忘れないように注意してください)

detections_to_counts_calculator.cc
#include <vector>
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/detection.pb.h"
#include "extensions/calculators/detections_counts.pb.h"

namespace mediapipe {

namespace {

constexpr char kDetectionsTag[] = "DETECTIONS";
constexpr char kDetectionsCountsTag[] = "DETECTIONS_COUNTS";

}

// 検出結果の数を分類ラベルごとにカウントするCalculator。
// Example Config (設定の例):
//   node {
//     calculator: "DetectionsToCountsCalculator"
//     input_stream: "DETECTIONS:detections"
//     output_stream: "DETECTIONS_COUNTS:detections_counts"
//   }
class DetectionsToCountsCalculator : public CalculatorBase {

 public:
  DetectionsToCountsCalculator() {};
  ~DetectionsToCountsCalculator() override {};
  DetectionsToCountsCalculator(const DetectionsToCountsCalculator &) = delete;
  DetectionsToCountsCalculator& operator=(const DetectionsToCountsCalculator &) = delete;

  static ::mediapipe::Status GetContract(CalculatorContract* cc);
  ::mediapipe::Status Open(CalculatorContext* cc) override;
  ::mediapipe::Status Process(CalculatorContext* cc) override;

};

REGISTER_CALCULATOR(DetectionsToCountsCalculator);

::mediapipe::Status DetectionsToCountsCalculator::GetContract(CalculatorContract* cc) {

  RET_CHECK(cc->Inputs().HasTag(kDetectionsTag))
    << "Input stream <" << kDetectionsTag << "> is not provided";

  RET_CHECK(cc->Outputs().HasTag(kDetectionsCountsTag)) 
    << "Output stream <" << kDetectionsCountsTag << "> is not provided";

  cc->Inputs().Tag(kDetectionsTag).Set<std::vector<Detection>>();
  cc->Outputs().Tag(kDetectionsCountsTag).Set<DetectionsCounts>();

  return ::mediapipe::OkStatus();

}

::mediapipe::Status DetectionsToCountsCalculator::Open(CalculatorContext* cc) {

  cc->SetOffset(TimestampDiff(0));
  return ::mediapipe::OkStatus();

}

::mediapipe::Status DetectionsToCountsCalculator::Process(CalculatorContext* cc) {

  // 空の入力が来た場合はすぐにOkStatusを返す
  if (cc->Inputs().Tag(kDetectionsTag).IsEmpty()) {
    return ::mediapipe::OkStatus();
  }

  auto detections_counts = absl::make_unique<DetectionsCounts>();

  // 入力から検出結果のデータを取り出す
  auto detections = cc->Inputs().Tag(kDetectionsTag).Get<std::vector<Detection>>();

  // 検出結果を分類ラベル別にカウントする
  for (const auto& detection : detections) {
    std::string label = detection.label()[0];
    if (detections_counts->counts().contains(label)) {
      (*detections_counts->mutable_counts())[label] += 1;
    }
    else {
      (*detections_counts->mutable_counts())[label] = 1;
    }
  }

  // 出力
  cc->Outputs()
    .Tag(kDetectionsCountsTag)
    .Add(detections_counts.release(), cc->InputTimestamp());

  return ::mediapipe::OkStatus();

}

} // namespace mediapipe

また、BUILDファイル内のdetections_to_counts_calculatorビルドターゲットも下記のようになります。(depsの一覧が書き足されています。)

cc_library(
    name = "detections_to_counts_calculator",
    srcs = ["detections_to_counts_calculator.cc"],
    visibility = ["//visibility:public"],
    deps = [
        ":detections_counts_cc_proto",
        "//mediapipe/framework:calculator_framework",
        "//mediapipe/framework/formats:detection_cc_proto",
    ],
    alwayslink = 1,
)

(7) テストが成功するのを確認する

さきほど処理の実装の前にテストをしたときと同じようにテストします。ディレクトリをWORKSPACEファイルがある同じ階層に移動してからbazelのコマンドを実行するのを忘れないでください。

bazel run //extensions/calculators:detections_to_counts_calculator_test

テストが成功すれば完了です。

もしエラーやテスト失敗が出るようであればエラーメッセージを頼りにテストが成功するまで修正しましょう。

まとめ

テスト駆動でMediaPipeのCalculatorをつくる流れは以上です。

実際にいろいろと開発するときは、MediaPipeの仕様等をもう少し細かく把握する必要が出てくるかもしれません。ここではそこまで細かく説明しなかったので、その場合はMediaPipeのソースコードや公式ドキュメント等を読むと参考になるかもしれません。また、Protocol BuffersやBazelの知識も必要になってくると思うので、そちらも必要に応じてそれぞれの公式ドキュメントを参照するとよいかもしれません。

参考

6
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
6
0