LoginSignup
4
1

More than 3 years have passed since last update.

PytestでPublic関数をMockする方法

Last updated at Posted at 2020-02-09

こんにちは、セールで買った某M○Proteinの「ほうじ茶ラテ味」プロテインが海老の出汁みたいな味で飲めないこの頃です(なに

さて弊社のDjangoサーバーのAPIはテストで120%カバレッジするというという方針のもと、日々テストコードを実装コードの倍ぐらい書いていく作業をしています。そんな中で本日は「Helperクラスで実装したPublic関数をどうやってAPI経由でテストするか(そしてどこで詰まったか」についてざっくりと紹介します。

テストしたい実装内容

実装した関数は以下の2つです(投稿用に関数名などはダミーにしています、またsetting.pyやurls.pyは別途設定済みです)

app/views.py
from rest_framework.views import APIView
from app.helper import helper_foo
from app.models import HogehogeSerializer


class HogehogeListAPIView(APIView):
    """ Djangoフレームワークを使ったAPI ViewでPost機能を実装
    """
    permission_classes = (permissions.IsAuthenticated,)

    def post(self, request, format=None):
    """ API Post/hogehoge
       """
        serializer = HogehogeSerializer(data=request.data)
        if serializer.is_valid(): #Requestデータのチェック
            try:
                with transaction.atomic():
                    serializer.save()
                    helper_foo(serializer.data) # Helperの関数を呼び出し
            except Exception as error: 
                return Response(status=status.HTTP_400_BAD_REQUEST)
            return Response(serializer.data, status=status.HTTP_201_CREATED)
        return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
app/helper.py
import boto3

def helper_foo(data):
    """Helper関数
    データを元にS3Bucketを作成する
    """
    client = boto3.client("s")

    """
    バケットを作成(失敗したらエラーを投げる
    """
  client.create_bucket(...

そしてpytestの内容がこちら

app/tests.py
from django.test import TestCase
from django.conf import settings

from rest_framework.test import APIClient

class HogehogeAPITest(TestCase):
    def setUp(self):
        self.client = APIClient()
        self.data = #テスト用データを作成

    def test_hogehoge_post(self):
        response = self.client.post('/hogehoge',
                                    self.data},
                                    format='json')
        self.assertEqual(response.status_code, 201)
        # 諸々Assertでチェック

テストの問題点

テスト内で実際にS3Bucketを作成するコードが実行されてしまい、テストを走らせるたびに余計なS3Bucketが増えていく。

対応策(Patch)でHelper関数をモック化する

UnittestのMockモジュールにあるpatchで以下のように関数をモック化できるということで実行

Targetはapp.helper.foo_helperを指定。

app/tests.py
from unittest.mock import patch
from django.test import TestCase
from django.conf import settings

from rest_framework.test import APIClient

class HogehogeAPITest(TestCase):
    def setUp(self):
        self.client = APIClient()

    @patch('app.helper.helper_foo')
    def test_hogehoge_post(self, mock_function):
        response = self.client.post('/hogehoge',
                                    self.data},
                                    format='json')
        self.assertEqual(response.status_code, 201)
        # 諸々Assertでチェック
        self.assertEqual(mock_function.call_count, 1) # Mock関数が一度呼ばれたか確認

実行するとテスト結果が失敗となり以下のメッセージが

self.assertEqual(mock_function.call_count, 1)
AssertionError: 0 != 1

あれ、関数がMockできていない?AWSのコンソールも調べるとS3Bucketも(残念ながら)ちゃんとできている。

どうしたか

色々調べてみるとPatchの参照ロジックは以下のようになっているとのこと

Now we want to test some_function but we want to mock out SomeClass using patch(). The problem is that when we import module b, which we will have to do then it imports SomeClass from module a. If we use patch() to mock out a.SomeClass then it will have no effect on our test; module b already has a reference to the real SomeClass and it looks like our patching had no effect. (引用元)

つまり
module bをmodule aにインポートした場合、module bはmodule aから参照されるため b.someclass(ここでいうhelper.helper_foo)で参照してもテストに影響しない。とのこと

ということで、先程のMockのTargetをapp.helper.foo_helperからapp.views.foo_helperに変更。

app/tests.py
from unittest.mock import patch
from django.test import TestCase
from django.conf import settings

from rest_framework.test import APIClient

class HogehogeAPITest(TestCase):
    def setUp(self):
        self.client = APIClient()

    @patch('app.views.helper_foo')
    def test_hogehoge_post(self, mock_function):
        response = self.client.post('/hogehoge',
                                    self.data},
                                    format='json')
        self.assertEqual(response.status_code, 201)
        # 諸々Assertでチェック
        self.assertEqual(mock_function.call_count, 1) # Mock関数が一度呼ばれたか確認

これでテストは無事動き、S3バケットも生成されないことが確認できた。

10年近くPythonを触っていたけれど、初めてImport時のModule参照の仕組みを理解できました。

以上です。Testコードって色々な発見がある分野で面白いのでぜひQiitaにもっとPytestの記事が増えればなぁと思っている今日この頃です。

4
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
4
1