0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

Django Rest Frameworkで、テストからエンドポイントごとにアクセスのあるテーブル一覧を出力する

Posted at

はじめに

APIのテストを流すだけで、エンドポイントごとにアクセスのあるテーブルを一覧で出力できたら影響確認のための資料作成が楽になりそうですよね。
Django Rest Frameworkで作成されたAPIをpytestでテストするとき、view層のテストで、Django ORMを通してDBアクセスのあったテーブル一覧を取得できたので方法を記事として残します。

環境

  • Python 3.13
  • Django 5.2
  • Django Rest Framework 3.16
  • pytest 8.4
  • MySQL 8.4

requirements.txt

依存ライブラリをまとめます。
私はMySQLで動かしたので、MySQL依存のものを記載します。適宜書き換えてください。
テスト系もまとめて記載します。

Django
djangorestframework
dj-database-url
django-mysql
mysqlclient
pytest
pytest-django
factory_boy
sqlglot

テーブル依存関係

テスト対象のコード

長いので折りたたみます。

GetCreateUserAPIView
class GetCreateUserAPIView(APIView):
    def get(self, request):
        group_with_users = Group.objects.prefetch_related(
            Prefetch(
                "groupuserrelation_set",
                queryset=GroupUserRelation.objects.select_related("user"),
                to_attr="group_user_relation",
            )
        )

        for g in group_with_users:
            g.users = [rel.user for rel in g.group_user_relation]

        return Response(
            UserListResponseSerializer(group_with_users, many=True).data,
            status=status.HTTP_200_OK,
        )

    def post(self, request):
        serializer = UserRegisterRequestSerializer(data=request.data)
        if not serializer.is_valid():
            return Response(status=status.HTTP_400_BAD_REQUEST)

        data = serializer.validated_data
        if User.objects.filter(code=data["code"]).first():
            return Response(status=status.HTTP_409_CONFLICT)

        group = Group.objects.filter(id=data["group_id"]).first()
        if not group:
            return Response(status=status.HTTP_404_NOT_FOUND)

        with transaction.atomic():
            user = User.objects.create(name=data["name"], code=data["code"])

            GroupUserRelation.objects.create(
                user=user,
                group=group,
            )

        return Response(status=status.HTTP_201_CREATED)
GetUserDetailAPIView
class GetUserDetailAPIView(APIView):
    def get(self, request, user_code: str):
        with_tag = unicodedata.normalize(
            "NFKC",
            request.query_params.get("with_tag", "false").lower(),
        ) in ("true", "y", "yes", "1")

        qs = User.objects.prefetch_related(
            Prefetch(
                "groupuserrelation_set",
                queryset=GroupUserRelation.objects.select_related("group"),
                to_attr="group_user_relation",
            )
        )

        if with_tag:
            qs = qs.prefetch_related(
                Prefetch(
                    "userfollowtag_set",
                    queryset=UserFollowTag.objects.select_related("tag"),
                    to_attr="user_follow_tag",
                )
            )

        user = qs.filter(code=user_code).first()

        if not user:
            return Response(status=status.HTTP_404_NOT_FOUND)

        user.groups = [rel.group for rel in user.group_user_relation]

        if with_tag:
            user.tags = [rel.tag for rel in user.user_follow_tag]

        return Response(
            UserDetailResponseSerializer(user).data,
            status=status.HTTP_200_OK,
        )
urls
urlpatterns = [
    path("users", GetCreateUserAPIView.as_view(), name="get_create_user"),
    path("users/<str:user_code>", GetUserDetailAPIView.as_view(), name="get_user_detail"),
]

テストコード

こちらも長いので折りたたみます。

TestGetCreateUserAPIView
import pytest
from rest_framework.test import APIClient


@pytest.mark.django_db
class TestGetCreateUserAPIView:
    def test_get(self):
        """取得処理正常系"""
        from .factories import GroupFactory, GroupUserRelationFactory, UserFactory

        user1 = UserFactory()
        user2 = UserFactory()
        user3 = UserFactory()

        group1 = GroupFactory()
        group2 = GroupFactory()

        GroupUserRelationFactory(user=user1, group=group1)
        GroupUserRelationFactory(user=user2, group=group1)
        GroupUserRelationFactory(user=user2, group=group2)
        GroupUserRelationFactory(user=user3, group=group2)

        res = APIClient().get("/users")
        assert res.status_code == 200
        assert res.json() == [
            {
                "id": group1.id,
                "name": group1.name,
                "users": [
                    {
                        "id": user1.id,
                        "name": user1.name,
                        "code": user1.code,
                    },
                    {
                        "id": user2.id,
                        "name": user2.name,
                        "code": user2.code,
                    },
                ],
            },
            {
                "id": group2.id,
                "name": group2.name,
                "users": [
                    {
                        "id": user2.id,
                        "name": user2.name,
                        "code": user2.code,
                    },
                    {
                        "id": user3.id,
                        "name": user3.name,
                        "code": user3.code,
                    },
                ],
            },
        ]

    def test_post(self):
        """登録処理正常系"""
        from .factories import GroupFactory

        group = GroupFactory()

        res = APIClient().post(
            "/users",
            {"name": "名前", "code": "user_code", "group_id": group.id},
            format="json",
        )
        assert res.status_code == 201

    def test_post_failed_by_request_invalid(self):
        """リクエスト形式が不正"""
        res = APIClient().post(
            "/users",
            {"name": "名前", "code": "user_code"},
            format="json",
        )
        assert res.status_code == 400

    def test_post_failed_by_duplicate_user_code(self):
        """ユーザ.コードがすでに存在する"""
        from .factories import UserFactory

        UserFactory(code="user_code")

        res = APIClient().post(
            "/users",
            {"name": "名前", "code": "user_code", "group_id": 1},
            format="json",
        )
        assert res.status_code == 409

    def test_post_failed_by_not_found_group(self):
        """リクエスト.グループIDが存在しない"""
        res = APIClient().post(
            "/users",
            {"name": "名前", "code": "user_code", "group_id": 1},
            format="json",
        )
        assert res.status_code == 404
TestGetUserDetailAPIView
@pytest.mark.django_db
class TestGetUserDetailAPIView:
    def test_get(self):
        """取得タグなし正常系"""
        from .factories import GroupFactory, GroupUserRelationFactory, UserFactory

        user1 = UserFactory()

        group1 = GroupFactory()
        group2 = GroupFactory()

        GroupUserRelationFactory(user=user1, group=group1)
        GroupUserRelationFactory(user=user1, group=group2)

        res = APIClient().get(f"/users/{user1.code}")
        assert res.status_code == 200
        assert res.json() == {
            "id": user1.id,
            "name": user1.name,
            "code": user1.code,
            "groups": [
                {
                    "id": group1.id,
                    "name": group1.name,
                },
                {
                    "id": group2.id,
                    "name": group2.name,
                },
            ],
        }

    def test_get_with_tag(self):
        """取得タグあり正常系"""
        from .factories import GroupFactory, GroupUserRelationFactory, TagFactory, UserFactory, UserFollowTagFactory

        user1 = UserFactory()

        group1 = GroupFactory()
        group2 = GroupFactory()

        GroupUserRelationFactory(user=user1, group=group1)
        GroupUserRelationFactory(user=user1, group=group2)

        tag1 = TagFactory()
        tag2 = TagFactory()
        tag3 = TagFactory()

        UserFollowTagFactory(user=user1, tag=tag1)
        UserFollowTagFactory(user=user1, tag=tag2)
        UserFollowTagFactory(user=user1, tag=tag3)

        res = APIClient().get(f"/users/{user1.code}?with_tag=true")
        assert res.status_code == 200
        assert res.json() == {
            "id": user1.id,
            "name": user1.name,
            "code": user1.code,
            "groups": [
                {
                    "id": group1.id,
                    "name": group1.name,
                },
                {
                    "id": group2.id,
                    "name": group2.name,
                },
            ],
            "tags": [
                {
                    "id": tag1.id,
                    "name": tag1.name,
                },
                {
                    "id": tag2.id,
                    "name": tag2.name,
                },
                {
                    "id": tag3.id,
                    "name": tag3.name,
                },
            ],
        }

エンドポイントごとのDBアクセスと操作

上記コードをまとめると下記の通りに操作しています。

  • GET /users
テーブル 操作
group READ
user READ
group_user_relation READ
  • POST /users
テーブル 操作
group READ
user CREATE, READ
group_user_relation CREATE
  • GET /users/{user_code}?with_tag=(true/false)

tag, user_follow_tagはクエリパラメータに応じて取得する/しないが変わります。

テーブル 操作 require
group READ true
user READ true
group_user_relation READ true
tag READ false
user_follow_tag READ false

方針

  • テストからエンドポイント/HTTPメソッド単位でのDBアクセスがあるテーブルを取得する
    • views層のテストで、DBアクセスをモンキーパッチしてクエリを取得
    • 取得したクエリからアクセスされたテーブルを抜き出す
  • views層以外の層に対してもテストがあるものとする
  • views層のテストは rest_framework.test.APIClient を用いる

実装

conftest でDBアクセスのモンキーパッチとクエリ解析、結果の出力まで行います。

クエリの取得

DBアクセスのモンキーパッチと、クエリを取得する処理をまとめます。
DBアクセスのモンキーパッチは データベースの計測 を参考に行います。

QueryLogger

クエリはプレースホルダー化されたSQLと与えられるパラメータが与えられるので、そこからクエリを生成して queries に格納します。
パラメータがない場合、Noneが与えられるためガードを挟みます。
クエリはプレースホルダー化されているため、sql引数をそのままSQLパーサーに通すとエラーになってしまうため、プレースホルダーに実レコードを入れます。
クエリのプレースホルダーは ?:param 形式ではなく、 %s の形式で入ってきますので、 sql % param でそのままぶちこめます。

django.db.transaction.atomic では SAVEPOINT を含むクエリが入るため、それら関係ないクエリは EXCLUDED_PREFIXES で明示的に除いています。

pytestはwith句内で処理を行いたい場合、yieldで実行できます。
Use fixtures in classes and modules withにあります。

class QueryLogger:
    EXCLUDED_PREFIXES = ("BEGIN", "SAVEPOINT", "RELEASE SAVEPOINT", "ROLLBACK TO SAVEPOINT", "COMMIT")

    def __init__(self):
        self.queries = []

    def __call__(self, execute, sql, params, many, context):
        if sql.strip().upper().startswith(self.EXCLUDED_PREFIXES):
            return execute(sql, params, many, context)

        format_query = sql
        if params:
            format_query = sql % tuple(repr(p) for p in params)
        self.queries.append(format_query)

        return execute(sql, params, many, context)

    def reset(self):
        self.queries.clear()

@pytest.fixture(autouse=True)
def collect_query_by_endpoint():
    ql = QueryLogger()
    with connection.execute_wrapper(ql):
        yield

これでテストごとに発行されるクエリが取れるようになりました。

views層のテストのみに絞る

@pytest.fixture(autouse=True)
def collect_query_by_endpoint(request, monkeypatch):
    ql = QueryLogger()
    orig_generic = APIClient.generic

    def endpoint_key_from(method: str, path: str) -> tuple[str, str]:
        try:
            m = resolve(path)
            return (method, m.url_name)
        except Resolver404:  # エンドポイント未解決404ケア
            pass
        return (method, path)

    def wrapped_generic(self, method, path, *args, **kwargs):
        ql.reset()  # 明示的にリセットする
        resp = orig_generic(self, method, path, *args, **kwargs)
        req = getattr(resp, "wsgi_request", None)
        if req and getattr(req, "resolver_match", None):
            url_name = req.resolver_match.url_name or req.path_info
            ep_key = (method.upper(), url_name)
        else:
            ep_key = endpoint_key_from(method.upper(), path)
        # TODO: 後ほどまとめる
        (ep_key, ql)  # ((HTTPメソッド, URL名), クエリーロガー)
        return resp

    monkeypatch.setattr(APIClient, "generic", wrapped_generic, raising=True)

    with connection.execute_wrapper(ql):
        yield

APIClient.generic

下記より APIClient.generic をモンキーパッチしてあげればエンドポイントごとのクエリーロガーをまとめられそうです。

monckypatch

下記を読んでください。

モンキーパッチで呼び出すメソッドではAPIClientも想定通り振る舞ってほしいため、内部でオリジナルを呼び出してミドルウェアやデコレータのような振る舞いをするようにします。

orig_generic = APIClient.generic
resp = orig_generic(self, method, path, *args, **kwargs)

def wrapped_generic(self, method, path, *args, **kwargs):
  resp = orig_generic(self, method, path, *args, **kwargs)
  ...
  return resp

monkeypatch.setattr(APIClient, "generic", wrapped_generic, raising=True)

endpoint_key_from

DRFのresolveにパスを渡すと、解決できればdjango.urls.ResolverMatchが手に入ります。
ここからメソッド名とURL名を取得してタプルにします。
パスそのままではなくURL名にしている理由は、パスパラメータやクエリパラメータケアです。
タプルである理由は後述。

def endpoint_key_from(method: str, path: str) -> tuple[str, str]:
    try:
        m = resolve(path)
        return (method, m.url_name)
    except Resolver404:  # エンドポイント未解決404ケア
        pass
    return (method, path)

取得したクエリからエンドポイントごとのテーブルを取得

class Aggregator:
    def __init__(self, outdir: pathlib.Path):
        self.data: dict[tuple[str, str], set[str]] = {}
        self.outdir = outdir
        self.url_name_pattern_map = {}

        targets = [url for url in get_resolver().url_patterns if url.app_name not in ["admin"]]
        self.dfs_urls(targets)

    def merge(self, endpoint_key: tuple[str, str], ql: QueryLogger):
        for query in ql.queries:
            ast = sqlglot.parse_one(query, read=settings.DB_ENGINE)
            self.data.setdefault(endpoint_key, set()).update({table.name for table in ast.find_all(exp.Table)})

    def write_file(self):
        self.outdir.mkdir(parents=True, exist_ok=True)
        result = []
        for ep, table_ops in sorted(self.data.items()):
            method, url_name = ep
            url_pattern = self.url_name_pattern_map.get(url_name, url_name)
            result.append(
                {
                    "method": method,
                    "url_pattern": url_pattern,
                    "tables": list(table_ops),
                }
            )
        with open(self.outdir / "data.json", "w") as f:
            f.write(
                json.dumps(
                    result,
                    indent=2,
                    ensure_ascii=False,
                )
            )

    def dfs_urls(self, url_patterns, prefix=""):
        for url in url_patterns:
            if isinstance(url, URLResolver):
                # ネストしたResolverのprefixも含める
                self.dfs_urls(url.url_patterns, prefix + str(url.pattern))
            elif isinstance(url, URLPattern):
                if url.name:  # 名前付きのものだけ
                    self.url_name_pattern_map[url.name] = prefix + str(url.pattern)


HERE = pathlib.Path(__file__).parent


@pytest.fixture(scope="session", autouse=True)
def write_endpoint_reports_after_session(request):
    outdir = HERE / "docs"
    shutil.rmtree(outdir, ignore_errors=True)
    aggregator = Aggregator(outdir=outdir)
    request.config._aggregator = aggregator
    yield
    aggregator.write_file()


@pytest.fixture(autouse=True)
def collect_query_by_endpoint(request, monkeypatch):
    aggregator: Aggregator = request.config._aggregator
    ql = QueryLogger()
    orig_generic = APIClient.generic

    def endpoint_key_from(method: str, path: str) -> tuple[str, str]:
        ...

    def wrapped_generic(self, method, path, *args, **kwargs):
        ql.reset()
        ...
        if req and getattr(req, "resolver_match", None):
            url_name = req.resolver_match.url_name or req.path_info
            ep_key = (method.upper(), url_name)
        else:
            ep_key = endpoint_key_from(method.upper(), path)
        aggregator.merge(ep_key, ql)  # 変更箇所
        return resp

    monkeypatch.setattr(APIClient, "generic", wrapped_generic, raising=True)

    with connection.execute_wrapper(ql):
        yield

sqlglot.parse

SQLGlotでクエリをパースして、テーブルを取得しています。
処理については下記参照。

for query in ql.queries:
  ast = sqlglot.parse_one(query, read=settings.DB_ENGINE)

Aggregator.merge

Aggregatorを作成して、エンドポイント/HTTPメソッドごとにテーブルをまとめます。
(HTTPメソッド, URL名)ごとに辞書を作成したいため、タプルをキーにしています。タプルである理由はこれです。

def merge(self, endpoint_key: tuple[str, str], ql: QueryLogger):
  for query in ql.queries:
    ...
    self.data.setdefault(endpoint_key, set()).update({table})

aggregator.merge(ep_key, ql)

write_endpoint_reports_after_session

Aggregatorはテストが行われる前にインスタンス化し、すべてのテストが完了したあとにファイル書き込みを行います。
そのため、 @pytest.fixture(scope="session", autouse=True) としています。
シングルトン実現のために、request.config._aggregator にインスタンスを格納しています。これは雑です。

@pytest.fixture(scope="session", autouse=True)
def write_endpoint_reports_after_session(request):
    aggregator = Aggregator()
    request.config._aggregator = aggregator
    yield
    aggregator.write_file()


@pytest.fixture(autouse=True)
def collect_query_by_endpoint(request, monkeypatch):
    aggregator: Aggregator = request.config._aggregator
    ...

dfs_urls

URL名でタプルを作成しているため、URL名と実エンドポイントのマップを作成します。
下記記事参照。

実装全体

ここまでのコードはわかりやすさ優先でコードを省略していますので、今一度コード全体を記載します。

query_logger.py
import json
import pathlib

import sqlglot
from django.conf import settings
from django.urls import URLPattern, URLResolver, get_resolver
from sqlglot import exp


class QueryLogger:
    EXCLUDED_PREFIXES = ("BEGIN", "SAVEPOINT", "RELEASE SAVEPOINT", "ROLLBACK TO SAVEPOINT", "COMMIT")

    def __init__(self):
        self.queries = []

    def __call__(self, execute, sql, params, many, context):
        if sql.strip().upper().startswith(self.EXCLUDED_PREFIXES):
            return execute(sql, params, many, context)

        format_query = sql
        if params:
            format_query = sql % tuple(repr(p) for p in params)
        self.queries.append(format_query)

        return execute(sql, params, many, context)

    def reset(self):
        self.queries.clear()


class Aggregator:
    def __init__(self, outdir: pathlib.Path):
        self.data: dict[tuple[str, str], set[str]] = {}
        self.outdir = outdir
        self.url_name_pattern_map = {}

        targets = [url for url in get_resolver().url_patterns if url.app_name not in ["admin"]]
        self.dfs_urls(targets)

    def merge(self, endpoint_key: tuple[str, str], ql: QueryLogger):
        for query in ql.queries:
            ast = sqlglot.parse_one(query, read=settings.DB_ENGINE)
            self.data.setdefault(endpoint_key, set()).update({table.name for table in ast.find_all(exp.Table)})

    def write_file(self):
        self.outdir.mkdir(parents=True, exist_ok=True)
        result = []
        for ep, table_ops in sorted(self.data.items()):
            method, url_name = ep
            url_pattern = self.url_name_pattern_map.get(url_name, url_name)
            result.append(
                {
                    "method": method,
                    "url_pattern": url_pattern,
                    "tables": list(table_ops),
                }
            )
        with open(self.outdir / "data.json", "w") as f:
            f.write(
                json.dumps(
                    result,
                    indent=2,
                    ensure_ascii=False,
                )
            )

    def dfs_urls(self, url_patterns, prefix=""):
        for url in url_patterns:
            if isinstance(url, URLResolver):
                # ネストしたResolverのprefixも含める
                self.dfs_urls(url.url_patterns, prefix + str(url.pattern))
            elif isinstance(url, URLPattern):
                if url.name:  # 名前付きのものだけ
                    self.url_name_pattern_map[url.name] = prefix + str(url.pattern)
conftest.py
import pathlib
import shutil

import pytest
from django.db import connection
from django.urls import Resolver404, resolve
from rest_framework.test import APIClient

from tests.test_tool.query_logger import Aggregator, QueryLogger

HERE = pathlib.Path(__file__).parent


@pytest.fixture(scope="session", autouse=True)
def write_endpoint_reports_after_session(request):
    outdir = HERE / "docs"
    shutil.rmtree(outdir, ignore_errors=True)
    aggregator = Aggregator(outdir=outdir)
    request.config._aggregator = aggregator
    yield
    aggregator.write_file()


@pytest.fixture(autouse=True)
def collect_query_by_endpoint(request, monkeypatch):
    aggregator: Aggregator = request.config._aggregator
    ql = QueryLogger()
    orig_generic = APIClient.generic

    def endpoint_key_from(method: str, path: str) -> tuple[str, str]:
        try:
            m = resolve(path)
            return (method, m.url_name)
        except Resolver404:  # エンドポイント未解決404ケア
            pass
        return (method, path)

    def wrapped_generic(self, method, path, *args, **kwargs):
        ql.reset()
        resp = orig_generic(self, method, path, *args, **kwargs)
        req = getattr(resp, "wsgi_request", None)
        if req and getattr(req, "resolver_match", None):
            url_name = req.resolver_match.url_name or req.path_info
            ep_key = (method.upper(), url_name)
        else:
            ep_key = endpoint_key_from(method.upper(), path)
        aggregator.merge(ep_key, ql)
        return resp

    monkeypatch.setattr(APIClient, "generic", wrapped_generic, raising=True)

    with connection.execute_wrapper(ql):
        yield

動かしてみる

テストを流してみます。

pytest -vv

すると、 conftest.py のある tests/ 配下に docs/data.json が出来上がっているはずです。
みてみると......

data.json
[
  {
    "method": "GET",
    "url_pattern": "users",
    "tables": [
      "user",
      "group",
      "group_user_relation"
    ]
  },
  {
    "method": "GET",
    "url_pattern": "users/<str:user_code>",
    "tables": [
      "group",
      "group_user_relation",
      "user",
      "tag",
      "user_follow_tag"
    ]
  },
  {
    "method": "POST",
    "url_pattern": "users",
    "tables": [
      "user",
      "group",
      "group_user_relation"
    ]
  }
]

できてますね!

TestGetUserDetailAPIView の下記部分をコメントアウトしてみます。
タグ取得処理がテスト上通らなくなるので、 users/<str:user_code> からは taguser_follow_tag が消えるはずです。

@pytest.mark.django_db
class TestGetUserDetailAPIView:
    def test_get(self):
        """取得タグなし正常系"""
        ...

#    def test_get_with_tag(self):
#        """取得タグあり正常系"""
#        ...

これで実行すると...

data.json
[
  {
    "method": "GET",
    "url_pattern": "users",
    "tables": [
      "user",
      "group_user_relation",
      "group"
    ]
  },
  {
    "method": "GET",
    "url_pattern": "users/<str:user_code>",
    "tables": [
      "group",
      "group_user_relation",
      "user"
    ]
  },
  {
    "method": "POST",
    "url_pattern": "users",
    "tables": [
      "group",
      "group_user_relation",
      "user"
    ]
  }
]

想定通り消えました!

終わりに

DRFでviews層のテストを実行すればエンドポイントごとにアクセスのあるテーブル一覧を取得することができました!
これで資料作成と影響確認が少し楽になりそうです!

しかし、これには罠があり、views層のテストを内部的な処理を意識して書く必要が出てきます。
views層のテストは可能ならブラックボックステストの観点で行いたいですが、条件に応じてテーブルの結合が変わる場合などでは取得ができなくなってしまうため、ホワイトボックステストの観点でviews層を書く必要が出てきます。
それだけviews層のテストは分厚くもなります。
果たしてこれが嬉しいのか。
実際に運用してみての所感は、テストの堅牢性が増すし資料作成も楽になるしでviews層をホワイトボックステスト観点で書く方に軍配が上がっています。

発展として、クエリが取れていてクエリパースもできているので、CRUD図なんかも作れそうな気がします。
ただそれを記事にするには、記事が長くなりすぎてしまうので今回はここまで。

image.png

以上、Django Rest Frameworkで、テストからエンドポイントごとにアクセスのあるテーブル一覧を出力するための実現案でした。
おわりだよー(o・∇・o)

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?