search
LoginSignup
5

More than 1 year has passed since last update.

posted at

updated at

[Python, unittest] 標準入出力のテストをする [競プロとかコーディングテストとか]

この記事はPython Advent Calendar 2020の10日目の記事です。

標準入出力でのテストケースってどう書くの?

コーディングテストでアルゴリズムを解く問題に出会ったんですけど,やっぱunittest使ってしゅばっとテストしたいじゃないですか.10行とかある入力を打ちこんで出力n行見比べるとか地獄ですし.けどunittestで標準入力->標準出力のテストの書き方が調べても出てこなかったので作りました.また,標準エラー出力も拾います.

では,みていきましょう!
急いでいる方は一番最後へどうぞ.
また,ツッコミや修正などありましたら,コメントいただければと思います.

想定

同一ディレクトリ内に,
- main.py: 実行したいPythonファイル
- test.py: テストのためのPythonファイル
があるものとします.

また,テストには標準ライブラリであるunittestを使います.実行環境はPython3.7ですがおそらくだいたい動くでしょう.あまりに古いとsubprocessの中身が違って subprocess.Popen がないとかはあるかもしれません.

やりかた

結局はいかに標準入力をして,標準出力・標準エラー出力を回収するか,だと思います.しかも,競プロとかコーディングテストだとやっぱり

python main.py

って呼び出した結果をテストしたいですよね.そこで,subprocessの出番です.子プロセスを作ってしまいます.

そのために使うのが p = subprocess.Popenp.communicateです.

1つ目のポイントはプロセスを作るときに標準入力・標準出力・標準エラー出力をパイプを用いてこのtest.pyとつなげてしまうことです.['python', 'main.py']みたいな代わりにshell=Trueを使うかとかはご自身の方針や用途に合わせてご決定なさってください.

2つ目のポイントはo, e = p.communicate(input=param[0].encode())の部分でしょう.標準入力をテストのパラメータを利用し(input=param[0].encode()),標準出力(o: outputのつもり)・標準エラー出力(e: errorのつもり)を回収します.標準入力・標準出力・標準エラー出力に関してはすべてエンコードやデコードが必要なことに注意が必要です.つまりこの段階のo, eはデコードする必要があります.

テストしよう

個人的にはこれでいいかな,となりました.

self.assertEqual(e.decode(), "")
self.assertEqual(o.decode().strip(), str(param[1]))

つまり,
- エラーがない,つまり標準エラー出力が空っぽ: self.assertEqual(e.decode(), "")
- 答えが正しい: self.assertEqual(o.decode().strip(), str(param[1]))
標準出力・標準エラー出力はデコードが必要なことと,標準出力は問答無用でstrなのでそこに気を付ければいいかなと思います.

完成

最終的にはこんな感じになりました.
こうしてやれば,入出力が複数行あっても楽ちんです.改行文字いれてやればいいですもんね.

test.py
import subprocess
from unittest import TestCase, main


class MyTestClass(TestCase):
    def _std_test(self, param):
        p = subprocess.Popen(
            ['python', 'main.py'],
            stdin=subprocess.PIPE,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE
        )
        o, e = p.communicate(input=param[0].encode())
        if o.decode().strip() != str(param[1]):
            print(o.decode())
        self.assertEqual(e.decode(), "")
        self.assertEqual(o.decode().strip(), str(param[1]))

    def test_case(self):
        """
        標準入力するものや期待される標準出力はダミーデータです
        """
        params = (
             ("input_1", "answer_1"),
             ("input_2\ninput_2_2", "answer_2\nanswer_2_second"),
        )

        for param in params:
            with self.subTest(param=param):
                self._std_test(param)

参考になれば幸いです!

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
What you can do with signing up
5