はじめに
APIのテストを流すだけで、エンドポイントごとにアクセスのあるテーブルを一覧で出力できたら影響確認のための資料作成が楽になりそうですよね。
Django Rest Frameworkで作成されたAPIをpytestでテストするとき、view層のテストで、Django ORMを通してDBアクセスのあったテーブル一覧を取得できたので方法を記事として残します。
環境
- Python 3.13
- Django 5.2
- Django Rest Framework 3.16
- pytest 8.4
- MySQL 8.4
requirements.txt
依存ライブラリをまとめます。
私はMySQLで動かしたので、MySQL依存のものを記載します。適宜書き換えてください。
テスト系もまとめて記載します。
Django
djangorestframework
dj-database-url
django-mysql
mysqlclient
pytest
pytest-django
factory_boy
sqlglot
テーブル依存関係
テスト対象のコード
長いので折りたたみます。
GetCreateUserAPIView
class GetCreateUserAPIView(APIView):
def get(self, request):
group_with_users = Group.objects.prefetch_related(
Prefetch(
"groupuserrelation_set",
queryset=GroupUserRelation.objects.select_related("user"),
to_attr="group_user_relation",
)
)
for g in group_with_users:
g.users = [rel.user for rel in g.group_user_relation]
return Response(
UserListResponseSerializer(group_with_users, many=True).data,
status=status.HTTP_200_OK,
)
def post(self, request):
serializer = UserRegisterRequestSerializer(data=request.data)
if not serializer.is_valid():
return Response(status=status.HTTP_400_BAD_REQUEST)
data = serializer.validated_data
if User.objects.filter(code=data["code"]).first():
return Response(status=status.HTTP_409_CONFLICT)
group = Group.objects.filter(id=data["group_id"]).first()
if not group:
return Response(status=status.HTTP_404_NOT_FOUND)
with transaction.atomic():
user = User.objects.create(name=data["name"], code=data["code"])
GroupUserRelation.objects.create(
user=user,
group=group,
)
return Response(status=status.HTTP_201_CREATED)
GetUserDetailAPIView
class GetUserDetailAPIView(APIView):
def get(self, request, user_code: str):
with_tag = unicodedata.normalize(
"NFKC",
request.query_params.get("with_tag", "false").lower(),
) in ("true", "y", "yes", "1")
qs = User.objects.prefetch_related(
Prefetch(
"groupuserrelation_set",
queryset=GroupUserRelation.objects.select_related("group"),
to_attr="group_user_relation",
)
)
if with_tag:
qs = qs.prefetch_related(
Prefetch(
"userfollowtag_set",
queryset=UserFollowTag.objects.select_related("tag"),
to_attr="user_follow_tag",
)
)
user = qs.filter(code=user_code).first()
if not user:
return Response(status=status.HTTP_404_NOT_FOUND)
user.groups = [rel.group for rel in user.group_user_relation]
if with_tag:
user.tags = [rel.tag for rel in user.user_follow_tag]
return Response(
UserDetailResponseSerializer(user).data,
status=status.HTTP_200_OK,
)
urls
urlpatterns = [
path("users", GetCreateUserAPIView.as_view(), name="get_create_user"),
path("users/<str:user_code>", GetUserDetailAPIView.as_view(), name="get_user_detail"),
]
テストコード
こちらも長いので折りたたみます。
TestGetCreateUserAPIView
import pytest
from rest_framework.test import APIClient
@pytest.mark.django_db
class TestGetCreateUserAPIView:
def test_get(self):
"""取得処理正常系"""
from .factories import GroupFactory, GroupUserRelationFactory, UserFactory
user1 = UserFactory()
user2 = UserFactory()
user3 = UserFactory()
group1 = GroupFactory()
group2 = GroupFactory()
GroupUserRelationFactory(user=user1, group=group1)
GroupUserRelationFactory(user=user2, group=group1)
GroupUserRelationFactory(user=user2, group=group2)
GroupUserRelationFactory(user=user3, group=group2)
res = APIClient().get("/users")
assert res.status_code == 200
assert res.json() == [
{
"id": group1.id,
"name": group1.name,
"users": [
{
"id": user1.id,
"name": user1.name,
"code": user1.code,
},
{
"id": user2.id,
"name": user2.name,
"code": user2.code,
},
],
},
{
"id": group2.id,
"name": group2.name,
"users": [
{
"id": user2.id,
"name": user2.name,
"code": user2.code,
},
{
"id": user3.id,
"name": user3.name,
"code": user3.code,
},
],
},
]
def test_post(self):
"""登録処理正常系"""
from .factories import GroupFactory
group = GroupFactory()
res = APIClient().post(
"/users",
{"name": "名前", "code": "user_code", "group_id": group.id},
format="json",
)
assert res.status_code == 201
def test_post_failed_by_request_invalid(self):
"""リクエスト形式が不正"""
res = APIClient().post(
"/users",
{"name": "名前", "code": "user_code"},
format="json",
)
assert res.status_code == 400
def test_post_failed_by_duplicate_user_code(self):
"""ユーザ.コードがすでに存在する"""
from .factories import UserFactory
UserFactory(code="user_code")
res = APIClient().post(
"/users",
{"name": "名前", "code": "user_code", "group_id": 1},
format="json",
)
assert res.status_code == 409
def test_post_failed_by_not_found_group(self):
"""リクエスト.グループIDが存在しない"""
res = APIClient().post(
"/users",
{"name": "名前", "code": "user_code", "group_id": 1},
format="json",
)
assert res.status_code == 404
TestGetUserDetailAPIView
@pytest.mark.django_db
class TestGetUserDetailAPIView:
def test_get(self):
"""取得タグなし正常系"""
from .factories import GroupFactory, GroupUserRelationFactory, UserFactory
user1 = UserFactory()
group1 = GroupFactory()
group2 = GroupFactory()
GroupUserRelationFactory(user=user1, group=group1)
GroupUserRelationFactory(user=user1, group=group2)
res = APIClient().get(f"/users/{user1.code}")
assert res.status_code == 200
assert res.json() == {
"id": user1.id,
"name": user1.name,
"code": user1.code,
"groups": [
{
"id": group1.id,
"name": group1.name,
},
{
"id": group2.id,
"name": group2.name,
},
],
}
def test_get_with_tag(self):
"""取得タグあり正常系"""
from .factories import GroupFactory, GroupUserRelationFactory, TagFactory, UserFactory, UserFollowTagFactory
user1 = UserFactory()
group1 = GroupFactory()
group2 = GroupFactory()
GroupUserRelationFactory(user=user1, group=group1)
GroupUserRelationFactory(user=user1, group=group2)
tag1 = TagFactory()
tag2 = TagFactory()
tag3 = TagFactory()
UserFollowTagFactory(user=user1, tag=tag1)
UserFollowTagFactory(user=user1, tag=tag2)
UserFollowTagFactory(user=user1, tag=tag3)
res = APIClient().get(f"/users/{user1.code}?with_tag=true")
assert res.status_code == 200
assert res.json() == {
"id": user1.id,
"name": user1.name,
"code": user1.code,
"groups": [
{
"id": group1.id,
"name": group1.name,
},
{
"id": group2.id,
"name": group2.name,
},
],
"tags": [
{
"id": tag1.id,
"name": tag1.name,
},
{
"id": tag2.id,
"name": tag2.name,
},
{
"id": tag3.id,
"name": tag3.name,
},
],
}
エンドポイントごとのDBアクセスと操作
上記コードをまとめると下記の通りに操作しています。
- GET
/users
テーブル | 操作 |
---|---|
group | READ |
user | READ |
group_user_relation | READ |
- POST
/users
テーブル | 操作 |
---|---|
group | READ |
user | CREATE, READ |
group_user_relation | CREATE |
- GET
/users/{user_code}?with_tag=(true/false)
tag, user_follow_tagはクエリパラメータに応じて取得する/しないが変わります。
テーブル | 操作 | require |
---|---|---|
group | READ | true |
user | READ | true |
group_user_relation | READ | true |
tag | READ | false |
user_follow_tag | READ | false |
方針
- テストからエンドポイント/HTTPメソッド単位でのDBアクセスがあるテーブルを取得する
- views層のテストで、DBアクセスをモンキーパッチしてクエリを取得
- 取得したクエリからアクセスされたテーブルを抜き出す
- views層以外の層に対してもテストがあるものとする
- views層のテストは rest_framework.test.APIClient を用いる
実装
conftest でDBアクセスのモンキーパッチとクエリ解析、結果の出力まで行います。
クエリの取得
DBアクセスのモンキーパッチと、クエリを取得する処理をまとめます。
DBアクセスのモンキーパッチは データベースの計測 を参考に行います。
QueryLogger
クエリはプレースホルダー化されたSQLと与えられるパラメータが与えられるので、そこからクエリを生成して queries
に格納します。
パラメータがない場合、Noneが与えられるためガードを挟みます。
クエリはプレースホルダー化されているため、sql引数をそのままSQLパーサーに通すとエラーになってしまうため、プレースホルダーに実レコードを入れます。
クエリのプレースホルダーは ?
や :param
形式ではなく、 %s
の形式で入ってきますので、 sql % param
でそのままぶちこめます。
django.db.transaction.atomic
では SAVEPOINT
を含むクエリが入るため、それら関係ないクエリは EXCLUDED_PREFIXES
で明示的に除いています。
pytestはwith句内で処理を行いたい場合、yieldで実行できます。
Use fixtures in classes and modules withにあります。
class QueryLogger:
EXCLUDED_PREFIXES = ("BEGIN", "SAVEPOINT", "RELEASE SAVEPOINT", "ROLLBACK TO SAVEPOINT", "COMMIT")
def __init__(self):
self.queries = []
def __call__(self, execute, sql, params, many, context):
if sql.strip().upper().startswith(self.EXCLUDED_PREFIXES):
return execute(sql, params, many, context)
format_query = sql
if params:
format_query = sql % tuple(repr(p) for p in params)
self.queries.append(format_query)
return execute(sql, params, many, context)
def reset(self):
self.queries.clear()
@pytest.fixture(autouse=True)
def collect_query_by_endpoint():
ql = QueryLogger()
with connection.execute_wrapper(ql):
yield
これでテストごとに発行されるクエリが取れるようになりました。
views層のテストのみに絞る
@pytest.fixture(autouse=True)
def collect_query_by_endpoint(request, monkeypatch):
ql = QueryLogger()
orig_generic = APIClient.generic
def endpoint_key_from(method: str, path: str) -> tuple[str, str]:
try:
m = resolve(path)
return (method, m.url_name)
except Resolver404: # エンドポイント未解決404ケア
pass
return (method, path)
def wrapped_generic(self, method, path, *args, **kwargs):
ql.reset() # 明示的にリセットする
resp = orig_generic(self, method, path, *args, **kwargs)
req = getattr(resp, "wsgi_request", None)
if req and getattr(req, "resolver_match", None):
url_name = req.resolver_match.url_name or req.path_info
ep_key = (method.upper(), url_name)
else:
ep_key = endpoint_key_from(method.upper(), path)
# TODO: 後ほどまとめる
(ep_key, ql) # ((HTTPメソッド, URL名), クエリーロガー)
return resp
monkeypatch.setattr(APIClient, "generic", wrapped_generic, raising=True)
with connection.execute_wrapper(ql):
yield
APIClient.generic
下記より APIClient.generic
をモンキーパッチしてあげればエンドポイントごとのクエリーロガーをまとめられそうです。
- https://github.com/encode/django-rest-framework/blob/main/rest_framework/test.py#L291-L335
- https://github.com/encode/django-rest-framework/blob/main/rest_framework/test.py#L198-L229
monckypatch
下記を読んでください。
モンキーパッチで呼び出すメソッドではAPIClientも想定通り振る舞ってほしいため、内部でオリジナルを呼び出してミドルウェアやデコレータのような振る舞いをするようにします。
orig_generic = APIClient.generic
resp = orig_generic(self, method, path, *args, **kwargs)
def wrapped_generic(self, method, path, *args, **kwargs):
resp = orig_generic(self, method, path, *args, **kwargs)
...
return resp
monkeypatch.setattr(APIClient, "generic", wrapped_generic, raising=True)
endpoint_key_from
DRFのresolveにパスを渡すと、解決できればdjango.urls.ResolverMatchが手に入ります。
ここからメソッド名とURL名を取得してタプルにします。
パスそのままではなくURL名にしている理由は、パスパラメータやクエリパラメータケアです。
タプルである理由は後述。
def endpoint_key_from(method: str, path: str) -> tuple[str, str]:
try:
m = resolve(path)
return (method, m.url_name)
except Resolver404: # エンドポイント未解決404ケア
pass
return (method, path)
取得したクエリからエンドポイントごとのテーブルを取得
class Aggregator:
def __init__(self, outdir: pathlib.Path):
self.data: dict[tuple[str, str], set[str]] = {}
self.outdir = outdir
self.url_name_pattern_map = {}
targets = [url for url in get_resolver().url_patterns if url.app_name not in ["admin"]]
self.dfs_urls(targets)
def merge(self, endpoint_key: tuple[str, str], ql: QueryLogger):
for query in ql.queries:
ast = sqlglot.parse_one(query, read=settings.DB_ENGINE)
self.data.setdefault(endpoint_key, set()).update({table.name for table in ast.find_all(exp.Table)})
def write_file(self):
self.outdir.mkdir(parents=True, exist_ok=True)
result = []
for ep, table_ops in sorted(self.data.items()):
method, url_name = ep
url_pattern = self.url_name_pattern_map.get(url_name, url_name)
result.append(
{
"method": method,
"url_pattern": url_pattern,
"tables": list(table_ops),
}
)
with open(self.outdir / "data.json", "w") as f:
f.write(
json.dumps(
result,
indent=2,
ensure_ascii=False,
)
)
def dfs_urls(self, url_patterns, prefix=""):
for url in url_patterns:
if isinstance(url, URLResolver):
# ネストしたResolverのprefixも含める
self.dfs_urls(url.url_patterns, prefix + str(url.pattern))
elif isinstance(url, URLPattern):
if url.name: # 名前付きのものだけ
self.url_name_pattern_map[url.name] = prefix + str(url.pattern)
HERE = pathlib.Path(__file__).parent
@pytest.fixture(scope="session", autouse=True)
def write_endpoint_reports_after_session(request):
outdir = HERE / "docs"
shutil.rmtree(outdir, ignore_errors=True)
aggregator = Aggregator(outdir=outdir)
request.config._aggregator = aggregator
yield
aggregator.write_file()
@pytest.fixture(autouse=True)
def collect_query_by_endpoint(request, monkeypatch):
aggregator: Aggregator = request.config._aggregator
ql = QueryLogger()
orig_generic = APIClient.generic
def endpoint_key_from(method: str, path: str) -> tuple[str, str]:
...
def wrapped_generic(self, method, path, *args, **kwargs):
ql.reset()
...
if req and getattr(req, "resolver_match", None):
url_name = req.resolver_match.url_name or req.path_info
ep_key = (method.upper(), url_name)
else:
ep_key = endpoint_key_from(method.upper(), path)
aggregator.merge(ep_key, ql) # 変更箇所
return resp
monkeypatch.setattr(APIClient, "generic", wrapped_generic, raising=True)
with connection.execute_wrapper(ql):
yield
sqlglot.parse
SQLGlotでクエリをパースして、テーブルを取得しています。
処理については下記参照。
for query in ql.queries:
ast = sqlglot.parse_one(query, read=settings.DB_ENGINE)
Aggregator.merge
Aggregatorを作成して、エンドポイント/HTTPメソッドごとにテーブルをまとめます。
(HTTPメソッド, URL名)ごとに辞書を作成したいため、タプルをキーにしています。タプルである理由はこれです。
def merge(self, endpoint_key: tuple[str, str], ql: QueryLogger):
for query in ql.queries:
...
self.data.setdefault(endpoint_key, set()).update({table})
aggregator.merge(ep_key, ql)
write_endpoint_reports_after_session
Aggregatorはテストが行われる前にインスタンス化し、すべてのテストが完了したあとにファイル書き込みを行います。
そのため、 @pytest.fixture(scope="session", autouse=True)
としています。
シングルトン実現のために、request.config._aggregator
にインスタンスを格納しています。これは雑です。
@pytest.fixture(scope="session", autouse=True)
def write_endpoint_reports_after_session(request):
aggregator = Aggregator()
request.config._aggregator = aggregator
yield
aggregator.write_file()
@pytest.fixture(autouse=True)
def collect_query_by_endpoint(request, monkeypatch):
aggregator: Aggregator = request.config._aggregator
...
dfs_urls
URL名でタプルを作成しているため、URL名と実エンドポイントのマップを作成します。
下記記事参照。
実装全体
ここまでのコードはわかりやすさ優先でコードを省略していますので、今一度コード全体を記載します。
import json
import pathlib
import sqlglot
from django.conf import settings
from django.urls import URLPattern, URLResolver, get_resolver
from sqlglot import exp
class QueryLogger:
EXCLUDED_PREFIXES = ("BEGIN", "SAVEPOINT", "RELEASE SAVEPOINT", "ROLLBACK TO SAVEPOINT", "COMMIT")
def __init__(self):
self.queries = []
def __call__(self, execute, sql, params, many, context):
if sql.strip().upper().startswith(self.EXCLUDED_PREFIXES):
return execute(sql, params, many, context)
format_query = sql
if params:
format_query = sql % tuple(repr(p) for p in params)
self.queries.append(format_query)
return execute(sql, params, many, context)
def reset(self):
self.queries.clear()
class Aggregator:
def __init__(self, outdir: pathlib.Path):
self.data: dict[tuple[str, str], set[str]] = {}
self.outdir = outdir
self.url_name_pattern_map = {}
targets = [url for url in get_resolver().url_patterns if url.app_name not in ["admin"]]
self.dfs_urls(targets)
def merge(self, endpoint_key: tuple[str, str], ql: QueryLogger):
for query in ql.queries:
ast = sqlglot.parse_one(query, read=settings.DB_ENGINE)
self.data.setdefault(endpoint_key, set()).update({table.name for table in ast.find_all(exp.Table)})
def write_file(self):
self.outdir.mkdir(parents=True, exist_ok=True)
result = []
for ep, table_ops in sorted(self.data.items()):
method, url_name = ep
url_pattern = self.url_name_pattern_map.get(url_name, url_name)
result.append(
{
"method": method,
"url_pattern": url_pattern,
"tables": list(table_ops),
}
)
with open(self.outdir / "data.json", "w") as f:
f.write(
json.dumps(
result,
indent=2,
ensure_ascii=False,
)
)
def dfs_urls(self, url_patterns, prefix=""):
for url in url_patterns:
if isinstance(url, URLResolver):
# ネストしたResolverのprefixも含める
self.dfs_urls(url.url_patterns, prefix + str(url.pattern))
elif isinstance(url, URLPattern):
if url.name: # 名前付きのものだけ
self.url_name_pattern_map[url.name] = prefix + str(url.pattern)
import pathlib
import shutil
import pytest
from django.db import connection
from django.urls import Resolver404, resolve
from rest_framework.test import APIClient
from tests.test_tool.query_logger import Aggregator, QueryLogger
HERE = pathlib.Path(__file__).parent
@pytest.fixture(scope="session", autouse=True)
def write_endpoint_reports_after_session(request):
outdir = HERE / "docs"
shutil.rmtree(outdir, ignore_errors=True)
aggregator = Aggregator(outdir=outdir)
request.config._aggregator = aggregator
yield
aggregator.write_file()
@pytest.fixture(autouse=True)
def collect_query_by_endpoint(request, monkeypatch):
aggregator: Aggregator = request.config._aggregator
ql = QueryLogger()
orig_generic = APIClient.generic
def endpoint_key_from(method: str, path: str) -> tuple[str, str]:
try:
m = resolve(path)
return (method, m.url_name)
except Resolver404: # エンドポイント未解決404ケア
pass
return (method, path)
def wrapped_generic(self, method, path, *args, **kwargs):
ql.reset()
resp = orig_generic(self, method, path, *args, **kwargs)
req = getattr(resp, "wsgi_request", None)
if req and getattr(req, "resolver_match", None):
url_name = req.resolver_match.url_name or req.path_info
ep_key = (method.upper(), url_name)
else:
ep_key = endpoint_key_from(method.upper(), path)
aggregator.merge(ep_key, ql)
return resp
monkeypatch.setattr(APIClient, "generic", wrapped_generic, raising=True)
with connection.execute_wrapper(ql):
yield
動かしてみる
テストを流してみます。
pytest -vv
すると、 conftest.py
のある tests/
配下に docs/data.json
が出来上がっているはずです。
みてみると......
[
{
"method": "GET",
"url_pattern": "users",
"tables": [
"user",
"group",
"group_user_relation"
]
},
{
"method": "GET",
"url_pattern": "users/<str:user_code>",
"tables": [
"group",
"group_user_relation",
"user",
"tag",
"user_follow_tag"
]
},
{
"method": "POST",
"url_pattern": "users",
"tables": [
"user",
"group",
"group_user_relation"
]
}
]
できてますね!
TestGetUserDetailAPIView
の下記部分をコメントアウトしてみます。
タグ取得処理がテスト上通らなくなるので、 users/<str:user_code>
からは tag
と user_follow_tag
が消えるはずです。
@pytest.mark.django_db
class TestGetUserDetailAPIView:
def test_get(self):
"""取得タグなし正常系"""
...
# def test_get_with_tag(self):
# """取得タグあり正常系"""
# ...
これで実行すると...
[
{
"method": "GET",
"url_pattern": "users",
"tables": [
"user",
"group_user_relation",
"group"
]
},
{
"method": "GET",
"url_pattern": "users/<str:user_code>",
"tables": [
"group",
"group_user_relation",
"user"
]
},
{
"method": "POST",
"url_pattern": "users",
"tables": [
"group",
"group_user_relation",
"user"
]
}
]
想定通り消えました!
終わりに
DRFでviews層のテストを実行すればエンドポイントごとにアクセスのあるテーブル一覧を取得することができました!
これで資料作成と影響確認が少し楽になりそうです!
しかし、これには罠があり、views層のテストを内部的な処理を意識して書く必要が出てきます。
views層のテストは可能ならブラックボックステストの観点で行いたいですが、条件に応じてテーブルの結合が変わる場合などでは取得ができなくなってしまうため、ホワイトボックステストの観点でviews層を書く必要が出てきます。
それだけviews層のテストは分厚くもなります。
果たしてこれが嬉しいのか。
実際に運用してみての所感は、テストの堅牢性が増すし資料作成も楽になるしでviews層をホワイトボックステスト観点で書く方に軍配が上がっています。
発展として、クエリが取れていてクエリパースもできているので、CRUD図なんかも作れそうな気がします。
ただそれを記事にするには、記事が長くなりすぎてしまうので今回はここまで。
以上、Django Rest Frameworkで、テストからエンドポイントごとにアクセスのあるテーブル一覧を出力するための実現案でした。
おわりだよー(o・∇・o)