
Enabling filters on a GraphQL server (strawberry)


Overview

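GraphQL itself does not define the filtering conditions (where) or result-count limits (limit) that SQL has built in; the basic approach seems to be to implement them in the server-side resolvers.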

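That said, the GraphQL ecosystem has solutions that automatically generate resolvers from a relational database, effectively turning the database into a GraphQL server as-is. I wondered how these handle filtering, and it turns out each has built its own mechanism, called a filter, in which conditions such as where are passed as an argument in a specific JSON format.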
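To make the idea concrete, here is a minimal sketch of how such a JSON filter can be interpreted. Instead of generating SQL, it evaluates a where object against plain Python dicts. The function `matches` and its small operator table are hypothetical, chosen for illustration rather than taken from Hasura or pg_graphql; it accepts both the underscore-prefixed (`_and`, `_eq`) and bare (`and`, `eq`) spellings.

```python
import operator
from typing import Any, Callable

# Hypothetical operator table: operator name -> predicate(column_value, argument)
OPS: dict[str, Callable[[Any, Any], bool]] = {
    "eq": operator.eq,
    "neq": operator.ne,
    "gt": operator.gt,
    "lt": operator.lt,
    "gte": operator.ge,
    "lte": operator.le,
    "in": lambda value, arg: value in arg,
}


def matches(row: dict[str, Any], where: dict[str, Any]) -> bool:
    """Return True if `row` satisfies every condition in `where` (top-level keys are ANDed)."""
    for key, value in where.items():
        name = key.lstrip("_")  # "_and" (Hasura) and "and" (pg_graphql) are treated alike
        if name == "and":
            ok = all(matches(row, sub) for sub in value)
        elif name == "or":
            ok = any(matches(row, sub) for sub in value)
        elif name == "not":
            ok = not matches(row, value)
        else:
            # Field condition, e.g. {"age": {"_gte": 20}}: every operator must hold
            ok = all(OPS[op.lstrip("_")](row[key], arg) for op, arg in value.items())
        if not ok:
            return False
    return True


users = [
    {"id": 1, "name": "alice", "age": 31},
    {"id": 2, "name": "bob", "age": 17},
    {"id": 3, "name": "carol", "age": 45},
]
where = {"_or": [{"age": {"_lt": 18}}, {"name": {"_eq": "carol"}}]}
print([u["id"] for u in users if matches(u, where)])  # [2, 3]
```

The SQL version in the code below follows the same recursive shape, except that each branch builds an SQLAlchemy expression (`and_`, `or_`, `~`) instead of a Python boolean.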

I implemented this using FastAPI + strawberry + SQLModel.

Filter specification

Of the ones I found, Hasura and pg_graphql (supabase) seemed like good models.

pg_graphql appears to be software that supabase, an mBaaS billed as a Firebase alternative built on a PostgreSQL database, created to add GraphQL support to its own platform.

Since there is no standard specification, the two differ in their details; for now, I made the implementation accept both formats.
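For reference, here is roughly the same query in each format, written as the Python arguments a resolver would receive. Hasura uses `where`/`limit` with underscore-prefixed operators such as `_lt`, while pg_graphql uses `filter`/`first` with bare operators and PascalCase sort directions such as `DescNullsLast`. The helpers `normalize_where` and `normalize_args` are a hypothetical sketch of the same unification that `convert_query` in the code performs (`filter` to `where`, `first` to `limit`), plus operator-prefix stripping and direction renaming.

```python
from typing import Any

# Roughly the same query in each style, as the Python arguments a resolver would receive.
# Hasura:     users(where: {id: {_lt: 3}}, limit: 10, order_by: [{id: desc_nulls_last}])
hasura_args = {"where": {"id": {"_lt": 3}}, "limit": 10, "order_by": [{"id": "desc_nulls_last"}]}
# pg_graphql: usersCollection(filter: {id: {lt: 3}}, first: 10, orderBy: [{id: DescNullsLast}])
pg_graphql_args = {"filter": {"id": {"lt": 3}}, "first": 10, "order_by": [{"id": "DescNullsLast"}]}

# Hypothetical mapping from pg_graphql sort directions to Hasura spellings.
DIRECTIONS = {
    "AscNullsFirst": "asc_nulls_first",
    "AscNullsLast": "asc_nulls_last",
    "DescNullsFirst": "desc_nulls_first",
    "DescNullsLast": "desc_nulls_last",
}


def normalize_where(where: dict[str, Any]) -> dict[str, Any]:
    """Strip leading underscores from boolean and comparison operators; keep field names."""
    result: dict[str, Any] = {}
    for key, value in where.items():
        bare = key.lstrip("_")
        if bare in ("and", "or"):
            result[bare] = [normalize_where(sub) for sub in value]
        elif bare == "not":
            result[bare] = normalize_where(value)
        else:
            result[key] = {op.lstrip("_"): arg for op, arg in value.items()}
    return result


def normalize_args(args: dict[str, Any]) -> dict[str, Any]:
    """Map either argument style onto one shape (filter -> where, first -> limit)."""
    where = args.get("filter", args.get("where"))
    return {
        "where": normalize_where(where) if where is not None else None,
        "limit": args.get("first", args.get("limit")),
        "order_by": [
            {field: DIRECTIONS.get(direction, direction) for field, direction in item.items()}
            for item in args.get("order_by") or []
        ],
    }


print(normalize_args(pg_graphql_args))
# {'where': {'id': {'lt': 3}}, 'limit': 10, 'order_by': [{'id': 'desc_nulls_last'}]}
```

Normalizing up front like this keeps the SQL-building step format-agnostic; the code below takes a lighter route and accepts both spellings directly in its lookup tables.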

Code

from sqlalchemy import String, and_, or_
from sqlalchemy.sql.elements import ClauseElement
from sqlmodel import select
from strawberry.utils.str_converters import to_camel_case


class QueryFilter:
    """
    Convert a JSON filter in Hasura or pg_graphql (supabase) format into an SQLAlchemy query
    """

    # Operator
    OPERATORS = {
        'eq': {'call': lambda field, value: field == value, 'cast': True, 'fcast': None},
        'neq': {'call': lambda field, value: field != value, 'cast': True, 'fcast': None},
        'gt': {'call': lambda field, value: field > value, 'cast': True, 'fcast': None},
        'lt': {'call': lambda field, value: field < value, 'cast': True, 'fcast': None},
        'gte': {'call': lambda field, value: field >= value, 'cast': True, 'fcast': None},
        'lte': {'call': lambda field, value: field <= value, 'cast': True, 'fcast': None},
        'in': {'call': lambda field, value: field.in_(value), 'cast': False, 'fcast': None},
        'contains': {'call': lambda field, value: field.contains(value), 'cast': False, 'fcast': String},
        'like': {'call': lambda field, value: field.like(value), 'cast': False, 'fcast': String},
        'regexp': {'call': lambda field, value: field.regexp_match(value), 'cast': False, 'fcast': String},
    }

    # sort order direction
    ORDER_DIRECTIONS = {
        # Hasura format
        'asc': {'type': 'asc', 'nulls': 'last'},
        'desc': {'type': 'desc', 'nulls': 'first'},  # Hasura's plain desc sorts NULLs first
        'asc_nulls_first': {'type': 'asc', 'nulls': 'first'},
        'asc_nulls_last': {'type': 'asc', 'nulls': 'last'},
        'desc_nulls_first': {'type': 'desc', 'nulls': 'first'},
        'desc_nulls_last': {'type': 'desc', 'nulls': 'last'},
        # pg_graphql format
        'AscNullsFirst': {'type': 'asc', 'nulls': 'first'},
        'AscNullsLast': {'type': 'asc', 'nulls': 'last'},
        'DescNullsFirst': {'type': 'desc', 'nulls': 'first'},
        'DescNullsLast': {'type': 'desc', 'nulls': 'last'},
    }

    # Define boolean operators
    BOOLEAN_OPERATORS = {
        'and': and_,
        'or': or_,
        '_and': and_,
        '_or': or_,
    }

    def __init__(self, model: any, graphql_type: any, schema: any) -> None:
        """
        Initializer

        :param model: SQLModel model
        :param graphql_type: Strawberry GraphQL type
        :param schema: Strawberry schema
        """
        # Initialize variables
        self.model = model
        self.graphql_type = graphql_type
        self.schema = schema
        # Get field mapping
        self.field_map = self.get_field_mapping(self.graphql_type, self.schema)

    def get_field_mapping(self, graphql_type: any, schema: any) -> dict:
        """
        Build a mapping from the (converted) field names of the auto-generated GraphQL type to the original field names

        :param graphql_type: Strawberry GraphQL type
        :param schema: Strawberry schema

        :return: Mapping between converted and original field names
        """
        result = {}

        for field in graphql_type.__strawberry_definition__.fields:
            graphql_name = to_camel_case(field.name) if schema.config.name_converter.auto_camel_case else field.name
            result[graphql_name] = field.name

        return result

    def set_order_directions(self, field: any, direction: str) -> any:
        """
        Apply the sort direction

        :param field: SQLAlchemy field
        :param direction: Sort direction

        :return: SQLAlchemy field with sort direction
        """
        if direction not in self.ORDER_DIRECTIONS:
            raise ValueError(f'Unknown order direction {direction}')

        direction_type = self.ORDER_DIRECTIONS[direction]['type']
        null_order = self.ORDER_DIRECTIONS[direction]['nulls']
        field = field.asc() if direction_type == 'asc' else field.desc()
        field = field.nulls_first() if null_order == 'first' else field.nulls_last()

        return field

    def convert_value_to_field_type(self, field: any, value: any) -> any:
        """
        Convert a value to the appropriate type based on the model field's type

        :param field: SQLAlchemy field
        :param value: Value to convert

        :return: Converted value
        """
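        if value is None:
            return None  # keep None so eq/neq compile to IS NULL / IS NOT NULL
        field_type = field.type.python_type  # Python type corresponding to the column's SQLAlchemy type
        try:
            return field_type(value)  # type conversion
        except (ValueError, TypeError):
            raise ValueError(f'Cannot convert {value} to type {field_type}')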

    def convert_conditions(self, filters: dict, model: any, field_map: dict) -> any:
        """
        Parse the where (filter) argument and convert it into an SQLAlchemy condition.

        :param filters: Filter dictionary
        :param model: SQLModel model
        :param field_map: Field mapping

        :return: SQLAlchemy condition joining all terms with AND
        """
        conditions = []

        if not isinstance(filters, dict) or not filters:
            raise ValueError('Invalid filter')

        for key, value in filters.items():
            if key in self.BOOLEAN_OPERATORS:
                # AND/OR operator processing
                if not isinstance(value, list) or not value:
                    raise ValueError(f"{key} operator must be a non-empty list.")
                sub_conditions = [self.convert_conditions(sub_filter, model, field_map) for sub_filter in value]
                conditions.append(self.BOOLEAN_OPERATORS[key](*sub_conditions))
            elif key in {'_not', 'not'}:
                # NOT operator processing
                if not isinstance(value, dict) or not value:
                    raise ValueError(f"{key} operator must be a non-empty dictionary.")
                sub_condition = self.convert_conditions(value, model, field_map)
                if not isinstance(sub_condition, ClauseElement):
                    raise TypeError("Expected SQLAlchemy condition, but got something else.")
                conditions.append(~sub_condition)
            elif isinstance(value, dict):
                # Comparison operator processing
                if key not in field_map:
                    raise ValueError(f'Unknown field {key}')
                field = getattr(model, field_map[key])
                for op, val in value.items():
                    op = op.lstrip('_')  # Strip leading underscores (Hasura format)
                    if op not in self.OPERATORS:
                        raise ValueError(f'Unknown operator {op}')
                    convert_val = self.convert_value_to_field_type(field, val) if self.OPERATORS[op]['cast'] else val
                    if self.OPERATORS[op]['fcast'] is not None:
                        conditions.append(self.OPERATORS[op]['call'](field.cast(self.OPERATORS[op]['fcast']), convert_val))
                    else:
                        conditions.append(self.OPERATORS[op]['call'](field, convert_val))
            else:
                raise ValueError(f'Invalid filter: {key}')

        return and_(*conditions)

    def convert_query(self, where: dict = None, filter: dict = None, order_by: list = None, limit: int = None, first: int = None, offset: int = None) -> select:
        """
        Convert a JSON filter in Hasura or pg_graphql format into an SQLAlchemy query

        :param where: WHERE part of the query (Hasura format)
        :param filter: Filter part of the query (pg_graphql format)
        :param order_by: ORDER BY part of the query
        :param limit: LIMIT part of the query (Hasura format)
        :param first: LIMIT part of the query (pg_graphql format)
        :param offset: OFFSET part of the query

        :return: SQLAlchemy query
        """
        # Convert arguments for pg_graphql format to Hasura format arguments
        # ('filter' to 'where' and 'first' to 'limit')
        if filter is not None:
            where = filter

        if first is not None:
            limit = first

        # Filter Validation
        if not self.validate(self.model, self.graphql_type, self.schema, {'where': where, 'order_by': order_by, 'limit': limit, 'offset': offset}):
            raise ValueError('Invalid filter')

        # Base query
        statement = select(self.model)

        # WHERE
        if where is not None:
            if not isinstance(where, dict):
                raise ValueError('Invalid filter')
            statement = statement.where(self.convert_conditions(where, self.model, self.field_map))

        # ORDER BY
        if order_by is not None:
            if not isinstance(order_by, list):
                raise ValueError('order_by must be a list')
            for item in order_by:
                for key, direction in item.items():
                    if key not in self.field_map:
                        raise ValueError(f'Unknown field {key}')
                    field = getattr(self.model, self.field_map[key])
                    statement = statement.order_by(self.set_order_directions(field, direction))

        # LIMIT
        if limit is not None:
            if limit < 0:
                raise ValueError('limit must be greater than or equal to 0')
            statement = statement.limit(limit)

        # OFFSET
        if offset is not None:
            if offset < 0:
                raise ValueError('offset must be greater than or equal to 0')
            statement = statement.offset(offset)

        return statement

    def validate(self, model: any, graphql_type: any, schema: any, filter: dict) -> bool:
        """
        Filter validation (library users should override this method in a subclass)

        :param model: SQLModel model
        :param graphql_type: Strawberry GraphQL type
        :param schema: Strawberry schema
        :param filter: Filter dictionary

        :return: True if the filter is valid, False otherwise
        """
        return True
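
`set_order_directions` leaves NULL placement to the database via `nulls_first()` / `nulls_last()`. To see what the explicit direction names in `ORDER_DIRECTIONS` mean for the resulting order, here is a stdlib-only sketch that applies them to a list of dicts, with `None` standing in for SQL NULL. The helper `sort_rows` is hypothetical and not part of the code above.

```python
from typing import Any

# Explicit direction names accepted by QueryFilter.ORDER_DIRECTIONS (Hasura and pg_graphql spellings)
DIRECTIONS = {
    "asc_nulls_first": ("asc", "first"),
    "asc_nulls_last": ("asc", "last"),
    "desc_nulls_first": ("desc", "first"),
    "desc_nulls_last": ("desc", "last"),
    "AscNullsFirst": ("asc", "first"),
    "AscNullsLast": ("asc", "last"),
    "DescNullsFirst": ("desc", "first"),
    "DescNullsLast": ("desc", "last"),
}


def sort_rows(rows: list[dict[str, Any]], field: str, direction: str) -> list[dict[str, Any]]:
    """Sort rows by one field, placing None (SQL NULL) according to the direction."""
    if direction not in DIRECTIONS:
        raise ValueError(f"Unknown order direction {direction}")
    order, nulls = DIRECTIONS[direction]
    present = [r for r in rows if r[field] is not None]
    missing = [r for r in rows if r[field] is None]
    present.sort(key=lambda r: r[field], reverse=(order == "desc"))
    # NULL placement is independent of the sort order, as with NULLS FIRST / NULLS LAST in SQL
    return missing + present if nulls == "first" else present + missing


rows = [{"id": 1, "score": 70}, {"id": 2, "score": None}, {"id": 3, "score": 90}]
print([r["id"] for r in sort_rows(rows, "score", "desc_nulls_last")])  # [3, 1, 2]
print([r["id"] for r in sort_rows(rows, "score", "AscNullsFirst")])    # [2, 1, 3]
```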