LoginSignup
10
4

More than 5 years have passed since last update.

Google Test勉強録 (2) アルゴリズムのテストを書く

Last updated at Posted at 2018-02-18

1. はじめに

前回 の続きで、Google Testを使ってテストを書く勉強録です。

Google Testを使う個人的な動機として、機械学習などのアルゴリズムを実装するときに気軽にテストを書けたらいいなと思います。
そこで、今回はそれっぽい背景を例題として用意し、テスト駆動でアルゴリズムの実装をしてみようと思います。

2. それっぽい例: Bell数計算

ベル数 というものがあります。

ベル数とは次のようなものです。有限集合を空でない集合に分割する方法を列挙することを考えます。例えば、3点集合 {a, b, c} の分割は

{a, b, c}
{a, b}, {c}
{a}, {b, c}
{a, c}, {b}
{a}, {b}, {c}

の5通りです。$n$点集合の分割の総数をベル数といい、$B_n$と書きます。ベル数を小さい方から並べると

1, 1, 2, 5, 15, 52, 203, 877, 4140, 21147, 115975, 678570, 4213597, 27644437, 190899322, 1382958545, 10480142147, ...

となり、それなりに速く増大します。($O((n / \log(n))^n)$ くらい)

ベル数を計算する方法はいくつかあります。

2.1. 漸化式とBell三角形

$B_0 = 1$, および
$$
B_{n + 1} = \sum_{k = 0}^n \binom{n}{k} B_k
$$
で計算できます。$\binom{n}{m}$は二項係数です。

よって、二項係数を計算する機能と再帰っぽいことをする機能があれば計算できるはずです。これを実装した場合のテストとしては、とりあえずn = 10くらいまでで既知の値と一致するか確認すればよさそうです。

また、パスカルの三角形 と似たような手順で、右側にベル数が現れる三角列が構成できることが知られています (Bell triangle)。

1
1, 2
2, 3, 5
5, 7, 10, 15
15, 20, 27, 37, 52
52, 67, 87, 114, 151, 203
203, 255, 322, 409, 523, 674, 877

これを使えば少なくとも $O(n^2)$ で計算できるはずです。

2.2. Dobiński's formula

実は、平均1のPoisson分布の$n$次モーメントがベル数になります。なので、このような公式があります。
$$
B_n = \frac{1}{e} \sum_{i=0}^{\infty} \frac{k^n}{k!}
$$
無限和を有限の適当なところで打ち切れば、そこそこ良いベル数の近似が得られると思われます。

困ったことに、こういうのを実装すると結果は浮動小数点型になるはずです。なので、「だいたい近かったら通る」というテストを書きたいです。

3. とりあえずテストを書く

3.1. 共通の設定はフィクスチャに書く

次のような関数を書きたいとします。

// ベル三角形を使って B_n を計算して返す関数
int bell_trianble(int n);

// 漸化式を使って B_n を計算して返す関数
int bell_recursive(int n);

n=10くらいまでは既知の値が Wikipedia に載っていますので、それを信じてテストを書けばよいでしょう。

TEST(BellTriangleTest, Equal) {
  EXPECT_EQ(1, bell_triangle(1));
  EXPECT_EQ(2, bell_triangle(2));
  EXPECT_EQ(5, bell_triangle(3));
  EXPECT_EQ(15, bell_triangle(4));
  EXPECT_EQ(52, bell_triangle(5));
  EXPECT_EQ(203, bell_triangle(6));
  EXPECT_EQ(877, bell_triangle(7));
  EXPECT_EQ(4140, bell_triangle(8));
  EXPECT_EQ(21147, bell_triangle(9));
}

さて、定義上 bell_trianglebell_recursive の結果は同じになってほしいので、次のようなテストを実行したいです。

TEST(BellRecursiveTest, Compare) {
  EXPECT_EQ(bell_triangle(1), bell_recursive(1));
  EXPECT_EQ(bell_triangle(2), bell_recursive(2));
  EXPECT_EQ(bell_triangle(3), bell_recursive(3));
  EXPECT_EQ(bell_triangle(4), bell_recursive(4));
  EXPECT_EQ(bell_triangle(5), bell_recursive(5));
  EXPECT_EQ(bell_triangle(6), bell_recursive(6));
  EXPECT_EQ(bell_triangle(7), bell_recursive(7));
  EXPECT_EQ(bell_triangle(8), bell_recursive(8));
  EXPECT_EQ(bell_triangle(9), bell_recursive(9));
}

しかし、bell_triangle(n) (n = 1, ... 9) の値は別のテスト BellTriangleTest.Equal でも利用するため、同じ値を複数回計算することになってしまいます。bell_triangle(n) の計算量は$O(n^2)$ ですが、実行にもっと時間がかかる関数のテストを行いたいときは、計算済みの値を使いまわせた方がよいでしょう。

というわけで下のようなテストフィクスチャを書いて、bell_triangle(n)の値をいくつか事前計算することにします。

class BellComparisonTest: public testing::Test {
protected:
  constexpr static int n_max_ = 10;
  std::vector<int> bell_triangle_precompute_;

  // 開始時処理
  virtual void SetUp() {
    bell_triangle_precompute_.reserve(n_max_ + 1);

    for (int i = 0; i <= n_max_; ++i) {
      bell_triangle_precompute_.push_back(bell_triangle(i));
    }
  }

  // 終了時処理 (今回は不要)
  //virtual boid TearDown() {}
};

フィクスチャを利用したテストには TEST_F(フィクスチャクラス名、テスト名) を使います。

TEST_F(BellComparisonTest, Recursive) {
  for (int i = 0; i <= n_max_; ++i) {
    EXPECT_EQ(bell_triangle_precompute_[i], bell_recursive(i));
  }
}

3.2. 浮動小数点のテスト

また、Dobinski公式を打ち切って計算する関数も書いてみます。

// 無限和をtruncまでで打ち切ったDobinski公式
double bell_dobinski(int n, int trunc);

truncが十分大きければ、bellとだいたい同じ値が返ってきてほしいです。しかし、double型の近似値なのでイコールにはなってくれません。機械学習のアルゴリズムなんかをテストしたいときは、こういう状況が頻出するかと思います。

浮動小数点数どうしの比較は 色々な論争 があるらしいです。こわい。
ひとまず、Google Testでは誤差範囲を適当に見繕って比較してくれる機能があります。

TEST_F(BellComparisonTest, Dobinski) {
  int trunc = 50;
  for (int i = 0; i <= n_max_; ++i) {
    EXPECT_DOUBLE_EQ((double) bell_triangle_precompute_[i],
                     bell_dobinski(i, trunc));
  }
}

絶対誤差の程度を明示的に指定することもできます。今回は、そもそも収束列を途中で打ち切る実装になっていて、数学的にも=でないことがわかっているため、意味合い的にはこちらのほうがいいかもしれません。

TEST_F(BellComparisonTest, DobinskiWithTol) {
  int trunc = 50;
  double tol = 1e-8;
  for (int i = 0; i <= n_max_; ++i) {
    EXPECT_NEAR((double) bell_triangle_precompute_[i],
                bell_dobinski(i, trunc), tol);
  }
}

3.3. 例外のテスト

$n < 0$ に対して $B_n$ は未定義なので、定義域エラーを投げるようにしたとします。

int bell_triangle(int n) {
  if (n < 0) {
    throw std::domain_error("定義域エラー");
  }

  // --- do something --- //

  return something;
}

何かしらエラーを投げることが期待されるとき、それをテストする機能もあります。

TEST(BellTriangleTest, WrongDomain) {
  // n < 0 だと定義域エラー
  EXPECT_ANY_THROW(bell_triangle(-1));
  EXPECT_ANY_THROW(bell_triangle(-5));
}

エラーの型を明示的に指定することもできます。

TEST(BellTriangleTest, WrongDomainExplicit) {
  // n < 0 だと定義域エラー
  EXPECT_THROW(bell_triangle(-1), std::domain_error);
  EXPECT_THROW(bell_triangle(-5), std::domain_error);
}

4. テストの通りにがんばって実装する

というわけでやります。

今回はこのような構成になりました。bell/hoge.h に対応するテストを test/test_hoge.cpp に書いている感じです。

.
├── CMakeLists.txt
├── bell
│   ├── CMakeLists.txt
│   ├── bell.cpp
│   ├── bell.h
│   └── common.h //これはヘッダオンリー
└── test
    ├── CMakeLists.txt
    ├── test_bell.cpp
    └── test_common.cpp

5. 結果

5.1. トライアル1回目

まずテストを実行してみます。

Running tests...
Test project ${project_dir}/build
      Start  1: BellTriangleTest.WrongDomain
 1/20 Test  #1: BellTriangleTest.WrongDomain ...........   Passed    0.03 sec
      Start  2: BellTriangleTest.WrongDomainExplicit
 2/20 Test  #2: BellTriangleTest.WrongDomainExplicit ...   Passed    0.01 sec
      Start  3: BellTriangleTest.Zero
 3/20 Test  #3: BellTriangleTest.Zero ..................   Passed    0.01 sec
      Start  4: BellTriangleTest.Equal

-- (中略) --

19/20 Test #19: FactorialTest.Zero .....................   Passed    0.01 sec
      Start 20: FactorialTest.Equal
20/20 Test #20: FactorialTest.Equal ....................   Passed    0.01 sec

85% tests passed, 3 tests failed out of 20

Total Test time (real) =   0.30 sec

The following tests FAILED:
      9 - BellDobinskiTest.Zero (Failed)
     10 - BellComparisonTest.Dobinski (Failed)
     11 - BellComparisonTest.DobinskiWithTol (Failed)
Errors while running CTest

あっ。
浮動小数点数の比較テストが全部落ちました。

何がいけなかったのか…。

CMake (CTest) はGoogle Testのエラーメッセージをデフォルトで隠してしまうようなので、以下を実行します。

ctest --verbose

結果の一部:

// -- (中略) -- //

10: Test timeout computed to be: 10000000
10: Running main() from gtest_main.cc
10: Note: Google Test filter = BellComparisonTest.Dobinski
10: [==========] Running 1 test from 1 test case.
10: [----------] Global test environment set-up.
10: [----------] 1 test from BellComparisonTest
10: [ RUN      ] BellComparisonTest.Dobinski
10: ${project_dir}/test/test_bell.cpp:85: Failure
10: Expected equality of these values:
10:   (double) bell_triangle_precompute_[i]
10:     Which is: 1
10:   bell_dobinski(i, trunc)
10:     Which is: inf   // (あっ!!)

// -- (中略) -- //

bell_dobinski がinfを返してますね。これはどこかで0で割ってる気がする…
実装をみてみます。

double bell_dobinski(int n, int trunc) {
  if (n < 0) {
    throw std::domain_error("Invalid input: bell_dobinski(n, trunc) is defined for n >= 0");
  }
  if (trunc < 1) {
    throw std::domain_error("Invalid input: bell_dobinski(n, trunc) is defined for trunc >= 1");
  }

  double out = 0;
  for (int i = 0; i <= trunc; ++i) {
    out += std::pow((double)i, n) / factorial(i); // <-- ここが怪しい
  }
  return out / math_e;
}

一箇所怪しい割り算があります。
int factorial(int n) というのは自分で適当に実装した階乗なんですが、truncは50くらいまで動くのでオーバーフローしている気がしてきました。

試しに std::tgamma(n + 1) と比較するテストを書いてみると

test/test_common.cpp

TEST(FactorialTest, Overflow) {
  EXPECT_DOUBLE_EQ(std::tgamma(10 + 1), (double) factorial(10));
  EXPECT_DOUBLE_EQ(std::tgamma(20 + 1), (double) factorial(20));
  EXPECT_DOUBLE_EQ(std::tgamma(40 + 1), (double) factorial(40));
}
21: Test command: ${project_dir}/build/test/TestCommon "--gtest_filter=FactorialTest.Overflow"
21: Test timeout computed to be: 10000000
21: Running main() from gtest_main.cc
21: Note: Google Test filter = FactorialTest.Overflow
21: [==========] Running 1 test from 1 test case.
21: [----------] Global test environment set-up.
21: [----------] 1 test from FactorialTest
21: [ RUN      ] FactorialTest.Overflow
21: ${project_dir}/test/test_common.cpp:72: Failure
21: Expected equality of these values:
21:   std::tgamma(20 + 1)
21:     Which is: 2.43290200817664e+18
21:   (double) factorial(20)
21:     Which is: -2102132736 // <-- オーバーフローしてる
21: ${project_dir}/test/test_common.cpp:73: Failure
21: Expected equality of these values:
21:   std::tgamma(40 + 1)
21:     Which is: 8.1591528324789411e+47
21:   (double) factorial(40)
21:     Which is: 0 // <-- オーバーフローして0になってる
21: [  FAILED  ] FactorialTest.Overflow (1 ms)
21: [----------] 1 test from FactorialTest (1 ms total)

オーバーフローしてましたね。わざとじゃないです。わざとじゃなかったです…。

5.2. トライアル2回目

$k^n / k!$ の分母と分子をオレオレ実装で別々に評価してから割ると、打ち切り次数20程度でもオーバーフローしていることがわかりました。

本当は、このようにしなければなりませんでした。

double bell_dobinski(int n, int trunc) {
  if (n < 0) {
    throw std::domain_error("Invalid input: bell_dobinski(n, trunc) is defined for n >= 0");
  }
  if (trunc < 1) {
    throw std::domain_error("Invalid input: bell_dobinski(n, trunc) is defined for trunc >= 1");
  }

  if (n == 0) {
    return 1.0;
  }

  double out = 0;
  for (int i = 1; i <= trunc; ++i) {
    out += std::exp(n * std::log(i) - std::lgamma(i + 1));
  }
  return out / math_e;
}

$n \geq 1, k \geq 1$ のときは、
$$
\frac{k^n}{k!} = \exp(n \log k - \log \Gamma(k + 1))
$$
で計算することにして、オーバーフローをなるべく避けます。$n = 0$ のときだけ例外対応すればOKです。

というわけで、再チャレンジします。

Test project ${project_dir}/build
      Start  1: BellTriangleTest.WrongDomain
 1/20 Test  #1: BellTriangleTest.WrongDomain ...........   Passed    0.01 sec
      Start  2: BellTriangleTest.WrongDomainExplicit
 2/20 Test  #2: BellTriangleTest.WrongDomainExplicit ...   Passed    0.01 sec

// --(中略)-- //

 9/20 Test  #9: BellDobinskiTest.Zero ..................   Passed    0.01 sec
      Start 10: BellComparisonTest.Dobinski
10/20 Test #10: BellComparisonTest.Dobinski ............   Passed    0.01 sec
      Start 11: BellComparisonTest.DobinskiWithTol
11/20 Test #11: BellComparisonTest.DobinskiWithTol .....   Passed    0.02 sec

// --(中略)-- //

100% tests passed, 0 tests failed out of 20

Total Test time (real) =   0.22 sec

できました。浮動小数点数どうしの比較もよしなにやってくれているみたいですね。

6. まとめ

Google Testを使ってテストを書いてからアルゴリズムを書く練習をしたり、(わざとではないが)テストに失敗して実装を直したりしました。

今回のコードは下にあります (sample2)。
https://github.com/ktrmnm/gtest-learn

10
4
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
10
4