テスト駆動でMediaPipeのCalculatorをつくる
はじめに
テスト駆動でMediaPipeのCalculatorをつくる流れについて紹介します。
CalculatorというのはMediaPipeにおける処理の単位で、このCalculatorをノードとした有向グラフを記述することで処理の流れを記述することができます。(MediaPipeフレームワークの各概念に関する詳しい説明は公式ドキュメントの「Framework Concepts」を参照してください。)
MediaPipeで新しい機能やアプリケーションを作りたいとき、既存のCalculatorの組み合わせでは実現できないため新しくCalculatorをつくる必要がある場合があります。
こういった場合で、Calculatorをテスト駆動で新しく作っていく流れについて書いてきます。もし参考になれば幸いです。
対象読者としては下記を念頭に置いています。
- MediaPipeのインストールができる
- 基本的なC++のプログラムが書ける
注:ここでは、MediaPipeのバージョンはv0.8.8としています。MediaPipeはまだアルファ版で今後互換性のない変更が生じる可能性があるので、注意が必要です。
つくるもの
例として検出結果の数を分類ラベルごとにカウントするCalculatorをつくります。
なお、カウントがしたいだけでそのカウント結果をMediaPipeフレームワーク内の他の処理が使うわけではないなら、Calculatorとしてつくらずに、MediaPipeの外側で結果を受け取ってカウント処理を行うというやり方もあると思います。そのため、実際ではMediaPipeのCalculatorとしてつくるべきかどうかも一度検討するとより良いかもしれません。
ここでは、比較的シンプルに説明できるため、検出結果の数をカウントするCalculatorをつくっていきます。
(1) 準備
(1.1) MediaPipeをインストールする
公式ドキュメントの説明にしたがって、MediaPipeをインストールします。
以下ではv0.8.8リリースを解凍してインストールを進めたとしていますが、git clone や git submodule で同バージョンにしてから進めてもとくに問題ないはずです。その場合はMediaPipeのディレクトリ名をmediapipe-0.8.8
からmediapipe
と読み替えてください。
(1.2) 新しくつくるCalculatorを格納するディレクトリをつくる
分かりやすいように、別でディレクトリをつくって、新しくつくるCalculatorはそこに置くことにします。別でつくっておくもう一つの利点としてはMediaPipeのバージョンを上げやすくなることです。
インストールしたMediaPipeのディレクトリと同じ階層にextensions
というディレクトリをつくり、その配下にさらにcalculators
ディレクトリをつくります。名前は適宜変えてもかまいません。
その後、MediaPipeのディレクトリの配下にさっきつくったextensions
ディレクトリへのシンボリックリンクをつくります。
最終的には下記のようなディレクトリ構造になっているはずです。
extensions
calculators
mediapipe-0.8.8
... (※他は略)
docs
extensions (※シンボリックリンク)
mediapipe
third_party
... (※他は略)
Calculatorとそのテストはextensions/calculators
の配下につくっていきます。
(2) テストをつくる
(2.1) 入出力データの仕様を考える/つくる
まずはCalculatorの入力と出力を考えます。MediaPipeでは、Calculatorへの入出力データ仕様は Protocol Buffers で定義することができます。
いまは検出結果のカウントをするCalculatorをつくろうとしているので、入力は検出結果ということになるでしょう。MediaPipeですでにDetection
というものが定義されているので、入力データにはこれを使うことにします。
出力データは検出結果の各分類のカウント数が含まれるようにしましょう。こちらは新しく定義することにします。
extensions/calculators
配下にdetections_counts.proto
というファイルをつくります。そして、下記の内容を書きます。
syntax = "proto2";
package mediapipe;
message DetectionsCounts {
// 検出結果の分類別のカウント。ラベルテキストからカウント数へのマッピング。
map<string, int64> counts = 1;
}
Protocol Buffers の言語仕様についてはここでは説明しません。詳しくは Protocol Buffers の言語ガイドを参照してください。
さらに、このprotoファイルのビルドを定義します。MediaPipeではビルドにBazelを使っています。そのため、Bazel用にBUILDファイルをつくって、そこにprotoファイルのビルドを記述します。具体的には次のようにつくります。
protoファイルと同じ階層にBUILD
という名前のファイルをつくり、下記の内容を書きます。
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library")
mediapipe_proto_library(
name = "detections_counts_proto",
srcs = ["detections_counts.proto"],
visibility = ["//visibility:public"],
)
BUILDファイルが書けたら、新しくつくったprotoファイルをビルドしてみましょう。MediaPipeのディレクトリに移動して、下記のコマンドでビルドできます。
bazel build //extensions/calculators:detections_counts_proto
ビルドが成功するのを確認します。
(2.2) テストの実装
Calculatorのテストのソースファイルは、後ろに_test
をつける慣習になっているようなので(例:top_k_scores_calculator_test.cc)、これにしたがいます。
今回つくるCalculatorはDetectionsToCountsCalculator
という名前にしましょう。そのため、そのテストはdetections_to_counts_calculator_test.cc
という名前でファイルを作ります。
ここから、テストを実装していくわけですが、その前に、MediaPipeのCalculatorではどのようにテストが書かれているか、説明します。
MediaPipeのCalculatorでは、テストにgtest (Google Test) が使われているようです(includeしている箇所の例)。
CalculatorのテストをするためにはそのCalculatorだけが含まれたGraphが走らされます。ParseTextProtoOrDie
でテキストを読み取ってCalculatorRunner
に渡されます。そして、CalculatorRunner
のMutableInputs()
から入力し、その後CalculatorRunner
を走らせ、Outputs()
から出力を取り出しています。この流れの例としては、MediaPipeのソースコードのtop_k_scores_calculator_test.cc
の123-158行目(v0.8.8)あたりを参照してください。
似たような書き方で、こちらのdetections_to_counts_calculator_test.cc
を実装しましょう。下記のように書きます。
#include "absl/strings/str_format.h"
#include "mediapipe/framework/calculator_runner.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/port/status_matchers.h"
#include "mediapipe/framework/formats/detection.pb.h"
#include "extensions/calculators/detections_counts.pb.h"
namespace mediapipe {
namespace {
// テストデータをつくるためのヘルパー関数
Detection CreateTestDectionData(std::string label) {
Detection detection;
detection.add_label(label);
LocationData* location_data = detection.mutable_location_data();
location_data->mutable_bounding_box()->set_xmin(0.1);
location_data->mutable_bounding_box()->set_width(0.2);
location_data->mutable_bounding_box()->set_ymin(0.3);
location_data->mutable_bounding_box()->set_height(0.4);
return detection;
}
}
// テスト
TEST(DetectionsToCountsCalculatorTest, TestOnePacket) {
// テストするCalculatorだけのGraphを読み込む
CalculatorRunner runner(ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"(
calculator: "DetectionsToCountsCalculator"
input_stream: "DETECTIONS:detections"
output_stream: "DETECTIONS_COUNTS:detections_counts"
)"));
// 入力データを用意する
std::vector<Detection> input_stream_detections;
for (int i=0; i<3; ++i) {
// 同じ分類が1つのみの検出結果を入力データに含める
Detection detection = CreateTestDectionData(absl::StrFormat("test_%d_counts_1",i));
input_stream_detections.push_back(detection);
}
// 同じ分類が2つある検出結果も入力データに含める
input_stream_detections.push_back(CreateTestDectionData("test_3_counts_2"));
input_stream_detections.push_back(CreateTestDectionData("test_3_counts_2"));
// 入力データをCalculatorRunnerに取り付ける
runner.MutableInputs()->Tag("DETECTIONS").packets.push_back(
MakePacket<std::vector<Detection>>(input_stream_detections).At(Timestamp(0)));
// CalculatorRunnerを走らせる
MP_ASSERT_OK(runner.Run());
// 出力結果を取り出す
const std::vector<Packet>& output_detections_counts_packets =
runner.Outputs().Tag("DETECTIONS_COUNTS").packets;
ASSERT_EQ(1, output_detections_counts_packets.size());
const DetectionsCounts& output_detections_counts =
output_detections_counts_packets[0].Get<DetectionsCounts>();
// 出力結果が正しいか確認する
EXPECT_EQ(output_detections_counts.counts().at("test_0_counts_1"), 1);
EXPECT_EQ(output_detections_counts.counts().at("test_1_counts_1"), 1);
EXPECT_EQ(output_detections_counts.counts().at("test_2_counts_1"), 1);
EXPECT_EQ(output_detections_counts.counts().at("test_3_counts_2"), 2);
}
} // namespace mediapipe
細かい説明は省略します。
(3) 処理が未実装のCalculatorを用意する
次は処理が未実装のCalculatorを用意します。
detections_to_counts_calculator.cc
という名前のファイルをつくり、下記の内容を書きます。
#include "mediapipe/framework/calculator_framework.h"
namespace mediapipe {
// 検出結果の数を分類ラベルごとにカウントするCalculator。
// Example Config (設定の例):
// node {
// calculator: "DetectionsToCountsCalculator"
// input_stream: "DETECTIONS:detections"
// output_stream: "DETECTIONS_COUNTS:detections_counts"
// }
class DetectionsToCountsCalculator : public CalculatorBase {
public:
DetectionsToCountsCalculator() {};
~DetectionsToCountsCalculator() override {};
DetectionsToCountsCalculator(const DetectionsToCountsCalculator &) = delete;
DetectionsToCountsCalculator& operator=(const DetectionsToCountsCalculator &) = delete;
static ::mediapipe::Status GetContract(CalculatorContract* cc);
::mediapipe::Status Open(CalculatorContext* cc) override;
::mediapipe::Status Process(CalculatorContext* cc) override;
};
REGISTER_CALCULATOR(DetectionsToCountsCalculator);
::mediapipe::Status DetectionsToCountsCalculator::GetContract(CalculatorContract* cc) {
// 処理はまだ未実装
return ::mediapipe::OkStatus();
}
::mediapipe::Status DetectionsToCountsCalculator::Open(CalculatorContext* cc) {
// 処理はまだ未実装
return ::mediapipe::OkStatus();
}
::mediapipe::Status DetectionsToCountsCalculator::Process(CalculatorContext* cc) {
// 処理はまだ未実装
return ::mediapipe::OkStatus();
}
} // namespace mediapipe
Calculatorをつくるときは、CalculatorBase
というクラスを継承してつくります。実際どんな挙動をしてほしいかは必要に応じてGetContract
、Open
、Process
、Close
に処理を書きます。これらの意味はここでは詳しくは説明しないので、詳細は公式ドキュメントのCalculatorsの説明を参照してください。
いまは処理が未実装のCalculatorを用意したいので、まだ現時点では、何もしないでOkStatus()
を返すだけになっています。
(4) BUILDファイルにCalculatorとテストのビルドターゲットを記述する
同じディレクトリ階層のBUILD
ファイルに、新しく書いたCalculatorとテストのソースコードをビルドするためのビルドターゲットを書き加えます。
まず、C++のビルドルールをロードするために、先頭に下記を書き加えます。
load("@rules_cc//cc:defs.bzl", "cc_library", "cc_test")
そして、末尾にCalculatorとテストのビルドターゲットを書き加えます。
cc_library(
name = "detections_to_counts_calculator",
srcs = ["detections_to_counts_calculator.cc"],
visibility = ["//visibility:public"],
deps = [
"//mediapipe/framework:calculator_framework",
],
alwayslink = 1,
)
cc_test(
name = "detections_to_counts_calculator_test",
size = "small",
srcs = ["detections_to_counts_calculator_test.cc"],
deps = [
":detections_to_counts_calculator",
":detections_counts_cc_proto",
"//mediapipe/framework:calculator_runner",
"//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:parse_text_proto",
"//mediapipe/framework/formats:detection_cc_proto",
],
)
cc_library
やcc_test
に関してはBazel公式ドキュメントの C/C++ Rules ページを参照してください。
(5) テストが失敗するのを確認する
MediaPipeのディレクトリ(同じ階層にWORKSPACE
ファイルがある階層)に移動して、下記のコマンドでテストを実行します。
bazel run //extensions/calculators:detections_to_counts_calculator_test
まだ処理を実装していないので、テストは失敗するはずです。
ただし、ビルドはエラーが発生しないはずです。もしビルドのエラーが発生した場合は、エラー文を頼りに修正する必要があります。例えば、「includeしようとしているヘッダーファイルが見つからない」といったようなエラーのときは、C++のコードだけでなく、BUILDファイルも見てビルドターゲットの依存先を表すdeps
の一覧が正しいか確認してみるといいかもしれません。
ビルドは失敗せず、テストが失敗したのを確認して、次へ進みます。
(6) Calculatorの処理の実装
いよいよCalculatorの処理を実装します。
detections_to_counts_calculator.cc
ファイルを開いて、さきほど「処理はまだ未実装」としていた箇所に処理を書いていきます。必要に応じて依存先が増える場合は、ソースコードのほうでincludeするだけでなく、BUILDファイル内のビルドターゲットのdeps
の一覧に依存先を書き足す必要もあります。
処理を実装してdetections_to_counts_calculator.cc
ファイルの内容が下記になるようにします。(先頭あたりに新しく追加されたinclude
やconstexpr char
等を忘れないように注意してください)
#include <vector>
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/detection.pb.h"
#include "extensions/calculators/detections_counts.pb.h"
namespace mediapipe {
namespace {
constexpr char kDetectionsTag[] = "DETECTIONS";
constexpr char kDetectionsCountsTag[] = "DETECTIONS_COUNTS";
}
// 検出結果の数を分類ラベルごとにカウントするCalculator。
// Example Config (設定の例):
// node {
// calculator: "DetectionsToCountsCalculator"
// input_stream: "DETECTIONS:detections"
// output_stream: "DETECTIONS_COUNTS:detections_counts"
// }
class DetectionsToCountsCalculator : public CalculatorBase {
public:
DetectionsToCountsCalculator() {};
~DetectionsToCountsCalculator() override {};
DetectionsToCountsCalculator(const DetectionsToCountsCalculator &) = delete;
DetectionsToCountsCalculator& operator=(const DetectionsToCountsCalculator &) = delete;
static ::mediapipe::Status GetContract(CalculatorContract* cc);
::mediapipe::Status Open(CalculatorContext* cc) override;
::mediapipe::Status Process(CalculatorContext* cc) override;
};
REGISTER_CALCULATOR(DetectionsToCountsCalculator);
::mediapipe::Status DetectionsToCountsCalculator::GetContract(CalculatorContract* cc) {
RET_CHECK(cc->Inputs().HasTag(kDetectionsTag))
<< "Input stream <" << kDetectionsTag << "> is not provided";
RET_CHECK(cc->Outputs().HasTag(kDetectionsCountsTag))
<< "Output stream <" << kDetectionsCountsTag << "> is not provided";
cc->Inputs().Tag(kDetectionsTag).Set<std::vector<Detection>>();
cc->Outputs().Tag(kDetectionsCountsTag).Set<DetectionsCounts>();
return ::mediapipe::OkStatus();
}
::mediapipe::Status DetectionsToCountsCalculator::Open(CalculatorContext* cc) {
cc->SetOffset(TimestampDiff(0));
return ::mediapipe::OkStatus();
}
::mediapipe::Status DetectionsToCountsCalculator::Process(CalculatorContext* cc) {
// 空の入力が来た場合はすぐにOkStatusを返す
if (cc->Inputs().Tag(kDetectionsTag).IsEmpty()) {
return ::mediapipe::OkStatus();
}
auto detections_counts = absl::make_unique<DetectionsCounts>();
// 入力から検出結果のデータを取り出す
auto detections = cc->Inputs().Tag(kDetectionsTag).Get<std::vector<Detection>>();
// 検出結果を分類ラベル別にカウントする
for (const auto& detection : detections) {
std::string label = detection.label()[0];
if (detections_counts->counts().contains(label)) {
(*detections_counts->mutable_counts())[label] += 1;
}
else {
(*detections_counts->mutable_counts())[label] = 1;
}
}
// 出力
cc->Outputs()
.Tag(kDetectionsCountsTag)
.Add(detections_counts.release(), cc->InputTimestamp());
return ::mediapipe::OkStatus();
}
} // namespace mediapipe
また、BUILD
ファイル内のdetections_to_counts_calculator
ビルドターゲットも下記のようになります。(deps
の一覧が書き足されています。)
cc_library(
name = "detections_to_counts_calculator",
srcs = ["detections_to_counts_calculator.cc"],
visibility = ["//visibility:public"],
deps = [
":detections_counts_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/formats:detection_cc_proto",
],
alwayslink = 1,
)
(7) テストが成功するのを確認する
さきほど処理の実装の前にテストをしたときと同じようにテストします。ディレクトリをWORKSPACE
ファイルがある同じ階層に移動してからbazelのコマンドを実行するのを忘れないでください。
bazel run //extensions/calculators:detections_to_counts_calculator_test
テストが成功すれば完了です。
もしエラーやテスト失敗が出るようであればエラーメッセージを頼りにテストが成功するまで修正しましょう。
まとめ
テスト駆動でMediaPipeのCalculatorをつくる流れは以上です。
実際にいろいろと開発するときは、MediaPipeの仕様等をもう少し細かく把握する必要が出てくるかもしれません。ここではそこまで細かく説明しなかったので、その場合はMediaPipeのソースコードや公式ドキュメント等を読むと参考になるかもしれません。また、Protocol BuffersやBazelの知識も必要になってくると思うので、そちらも必要に応じてそれぞれの公式ドキュメントを参照するとよいかもしれません。