LoginSignup
2
0

More than 3 years have passed since last update.

[DRF]PrimaryKeyRelatedFieldを高速化するスニペット

Last updated at Posted at 2020-09-12

はじめに

シリアライザでPrimaryKeyRelatedFieldmany=Trueにすると、pkの配列をリクエストパラメータで渡せますが、pkが大量にあるとめちゃくちゃ遅くなります
そのため原因の調査と高速化方法を考えました

原因

結論これです

ManyRelatedField.py
def to_internal_value(self, data):
    if isinstance(data, str) or not hasattr(data, '__iter__'):
        self.fail('not_a_list', input_type=type(data).__name__)
    if not self.allow_empty and len(data) == 0:
        self.fail('empty')

    return [
        self.child_relation.to_internal_value(item)
        for item in data
    ]

配列分self.child_relation.to_internal_value(item)が呼ばれてるため(to_internal_valueの中でpkをgetしている)遅くなっていました

高速化するスニペット

from rest_framework import serializers
from rest_framework.relations import MANY_RELATION_KWARGS, ManyRelatedField


class PrimaryKeyRelatedFieldEx(serializers.PrimaryKeyRelatedField):
    def __init__(self, **kwargs):
        self.queryset_response = kwargs.pop('queryset_response', False)
        super().__init__(**kwargs)

    class _ManyRelatedFieldEx(ManyRelatedField):
        def to_internal_value(self, data):
            if isinstance(data, str) or not hasattr(data, '__iter__'):
                self.fail('not_a_list', input_type=type(data).__name__)
            if not self.allow_empty and len(data) == 0:
                self.fail('empty')
            return self.child_relation.to_internal_value(data)

    @classmethod
    def many_init(cls, *args, **kwargs):
        list_kwargs = {'child_relation': cls(*args, **kwargs)}
        for key in kwargs:
            if key in MANY_RELATION_KWARGS:
                list_kwargs[key] = kwargs[key]
        return cls._ManyRelatedFieldEx(**list_kwargs)

    def to_internal_value(self, data):
        if isinstance(data, list):
            if self.pk_field is not None:
                data = self.pk_field.to_internal_value(data)
            results = self.get_queryset().filter(pk__in=data)
            # 全てのデータがあるかチェックする
            pk_list = results.values_list('pk', flat=True)
            pk_list = [str(n) for n in pk_list]
            data_list = [str(n) for n in data]
            diff = list(set(data_list) - set(list(pk_list)))
            if len(diff) > 0:
                pk_value = ', '.join(map(str, diff))
                self.fail('does_not_exist', pk_value=pk_value)
            if self.queryset_response:
                return results
            else:
                return list(results)
        else:
            return super().to_internal_value(data)

解説

  • 一つずつgetしていたのをfilterのinで取得するように変更
  • 存在しないpkが含まれていた場合のエラーメッセージは一つしか表示されなかったのを、カンマ区切りで全て表示するように変更
  • queryset_responseのパラメータを追加しています。queryset_response=Trueにするとレスポンスがquerysetになります(今まではquerysetの配列。個人的にはquerysetになってた方が使いやすいんじゃないかなと思う)
2
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
0