まとめ
- unittest で discover する対象クラスを限定する手段を書きます。
- 前半で、どのようなときに必要になるか、も書きます。
- 後半で、コードを書きます。
- load_tests() のコード例があります。
はじめに
現実世界で目的地にたどり着く交通手段が徒歩、車、公共交通機関と複数あっても到達点は目的地の1つでありながら、それぞれの手段に料金、時間、事故による影響をどれだけ受けうるか、などメリットやデメリットがあります。
コードを書く際も同様に、同じゴールを達成するために複数の手段があることがあります。
たとえば、ツリーの探索において、深さ優先・幅優先など複数の探索手段があり、メリット・デメリットがあります。ソートのアルゴリズムも同様です。ちょっと言いすぎですが、デザインパターンの Strategy のようなケースにおいては類似のケースが発生しえます。
場合によっては深さ優先、場合によっては幅優先、と選択するようなケースでは、コード内にそれぞれのテストケースを置くことになり、コードの中身を意識しないブラックボックスなテストを行う場合、愚直に実装すると似たテストコードが重複して存在することになります。
このようなときに load_tests() でテストケースを限定するコードを書けるよ、という話を書きます。
コード
例として、整数 x と y の掛け算の実装を「x 回 y を足し算」と「乗算」の2通りで実装し、それぞれのテストを実装することを考えます。
1) 足し算と乗算の実装とテスト
素朴に実装すると次のコードになります。
コメント「冗長」と書いた部分が冗長です。
import unittest
class MultiplyBase(object):
def multiply(self, n1: int, n2: int) -> int:
raise NotImplementedError()
class MultiplyPlus(MultiplyBase):
def multiply(self, n1: int, n2: int) -> int:
ret = 0
for i in range(0, n1):
ret += n2
return ret
class MultiplyAsterisk(MultiplyBase):
def multiply(self, n1: int, n2: int) -> int:
return n1 * n2
class TestMultiplyBase(unittest.TestCase):
def test_multiply(self):
obj: MultiplyBase = MultiplyBase()
with self.assertRaises(NotImplementedError):
obj.multiply(2, 3)
class TestMultiplyPlus(unittest.TestCase):
def test_multiply(self):
obj: MultiplyBase = MultiplyPlus()
self.assertEqual(obj.multiply(2, 3), 6) # 冗長!!
class TestMultiplyAsteriskClass(unittest.TestCase):
def test_multiply(self):
obj: MultiplyBase = MultiplyAsterisk()
self.assertEqual(obj.multiply(2, 3), 6) # 冗長!!
この量ならさほど気にならないけど、多くなると大変。
2) 冗長な箇所の排除を実現した実装とテスト
TestCaseHelper クラスでテストケースを実装し、このテストケースを派生して、具体的なテストケースを作ることにします。
TestCaseHelper を使うことで、各テストケースのクラスはどのクラスのオブジェクトを使うか、だけ実装すればよくなります。これにより冗長なコードを排除できます。
ただし、この場合、 TestCaseHelper クラスが unittest のデフォルトの挙動で discover されてしまいます。このためにテストケースが必ず失敗するので load_tests() により TestCaseHelper をテストケースとみなさないようにしています。
import unittest
class MultiplyBase(object):
def multiply(self, n1: int, n2: int) -> int:
raise NotImplementedError()
class MultiplyPlus(MultiplyBase):
def multiply(self, n1: int, n2: int) -> int:
ret = 0
for i in range(0, n1):
ret += n2
return ret
class MultiplyAsterisk(MultiplyBase):
def multiply(self, n1: int, n2: int) -> int:
return n1 * n2
class TestCaseHelper(unittest.TestCase):
def _get_multiply_obj(self) -> MultiplyBase:
raise NotImplementedError()
def test_multiply(self):
obj: MultiplyBase = self._get_multiply_obj()
self.assertEqual(obj.multiply(2, 3), 6)
class TestMultiplyBase(TestCaseHelper):
def _get_multiply_obj(self) -> MultiplyBase:
return MultiplyBase()
def test_multiply(self):
obj: MultiplyBase = self._get_multiply_obj()
with self.assertRaises(NotImplementedError):
obj.multiply(2, 3)
class TestMultiplyPlus(TestCaseHelper):
def _get_multiply_obj(self) -> MultiplyBase:
return MultiplyPlus()
class TestMultiplyAsterisk(TestCaseHelper):
def _get_multiply_obj(self) -> MultiplyBase:
return MultiplyAsterisk()
def load_tests(loader, tests, patterns):
test_cases = (TestMultiplyBase, TestMultiplyPlus, TestMultiplyAsterisk)
suite = unittest.TestSuite()
for test_class in test_cases:
tests = loader.loadTestsFromTestCase(test_class)
suite.addTests(tests)
return suite
これで、load_tests() で一部のテストケースを除外できました。