概要
GraphQLの仕様には、SQLにあるような絞り込み条件(where)や件数制限(limit)などは定義されておらず、サーバー側のリゾルバで実装するのが基本になっているようです。
一方で、GraphQL界隈にはRDBから自動的にリゾルバを生成し、RDBをそのままGraphQLサーバーとして公開するようなソリューションもあります。それらがどのように絞り込みを実現しているのか調べてみたところ、whereなどの条件を特定のJSON形式で引数として渡して絞り込む「フィルター」という仕組みを独自に用意して対応しているようです。
これをFastAPI + Strawberry + SQLModelで実装してみました。
フィルタの仕様
見つけた範囲では、Hasura と pg_graphql(Supabase)のフィルター形式がよさそうだと思いました。
pg_graphql は、PostgreSQLをデータベースに使ったFirebase代替を謳うmBaaSであるSupabaseが、自身のGraphQL対応のために作ったソフトウェアのようです。
フィルターについて共通の仕様が決まっているわけではないので両者の形式は異なりますが、とりあえずここでは両方の形式に対応できるように作ってみました。
コード
from typing import Any, Optional, Union

from sqlalchemy import String, and_, or_
from sqlalchemy.sql.elements import ClauseElement
from sqlmodel import select
from strawberry.utils.str_converters import to_camel_case
class QueryFilter:
    """
    Convert a Hasura- or pg_graphql (Supabase)-style JSON filter into a SQLAlchemy query.

    ``convert_query`` accepts:

    * ``where`` (Hasura) / ``filter`` (pg_graphql): field comparison objects such as
      ``{"age": {"_gt": 20}}``, combinable with ``and``/``_and``, ``or``/``_or`` and
      ``not``/``_not``.
    * ``order_by``: a list of ``{field: direction}`` objects (a single object is also accepted).
    * ``limit`` (Hasura) / ``first`` (pg_graphql), and ``offset``.

    Field names in the filter are GraphQL field names; they are mapped back to model
    attribute names using the Strawberry type's field definitions.
    """

    # Comparison operators, keyed by name with leading underscores removed, so that
    # Hasura's "_eq" and pg_graphql's "eq" both resolve to "eq".
    #   call  : builds the SQLAlchemy expression from (column, value)
    #   cast  : convert the value (each element, for list operators) to the column's Python type
    #   fcast : SQL type the column is CAST to before comparing (None = no cast)
    OPERATORS = {
        'eq': {'call': lambda field, value: field == value, 'cast': True, 'fcast': None},
        'neq': {'call': lambda field, value: field != value, 'cast': True, 'fcast': None},
        'gt': {'call': lambda field, value: field > value, 'cast': True, 'fcast': None},
        'lt': {'call': lambda field, value: field < value, 'cast': True, 'fcast': None},
        'gte': {'call': lambda field, value: field >= value, 'cast': True, 'fcast': None},
        'lte': {'call': lambda field, value: field <= value, 'cast': True, 'fcast': None},
        'in': {'call': lambda field, value: field.in_(value), 'cast': True, 'fcast': None},
        'nin': {'call': lambda field, value: field.not_in(value), 'cast': True, 'fcast': None},
        'contains': {'call': lambda field, value: field.contains(value), 'cast': False, 'fcast': String},
        'like': {'call': lambda field, value: field.like(value), 'cast': False, 'fcast': String},
        'ilike': {'call': lambda field, value: field.ilike(value), 'cast': False, 'fcast': String},
        'regexp': {'call': lambda field, value: field.regexp_match(value), 'cast': False, 'fcast': String},
        # "regex" is the spelling used by both Hasura (_regex) and pg_graphql.
        'regex': {'call': lambda field, value: field.regexp_match(value), 'cast': False, 'fcast': String},
    }

    # Operators whose value must be a JSON list.
    LIST_OPERATORS = frozenset({'in', 'nin'})

    # Sort directions -> (ASC/DESC, NULLS FIRST/LAST).
    ORDER_DIRECTIONS = {
        # Hasura format: plain "asc" sorts nulls last, plain "desc" sorts nulls first.
        'asc': {'type': 'asc', 'nulls': 'last'},
        'desc': {'type': 'desc', 'nulls': 'first'},
        'asc_nulls_first': {'type': 'asc', 'nulls': 'first'},
        'asc_nulls_last': {'type': 'asc', 'nulls': 'last'},
        'desc_nulls_first': {'type': 'desc', 'nulls': 'first'},
        'desc_nulls_last': {'type': 'desc', 'nulls': 'last'},
        # pg_graphql format
        'AscNullsFirst': {'type': 'asc', 'nulls': 'first'},
        'AscNullsLast': {'type': 'asc', 'nulls': 'last'},
        'DescNullsFirst': {'type': 'desc', 'nulls': 'first'},
        'DescNullsLast': {'type': 'desc', 'nulls': 'last'},
    }

    # Boolean combinators (pg_graphql spelling and Hasura's underscore-prefixed spelling).
    BOOLEAN_OPERATORS = {
        'and': and_,
        'or': or_,
        '_and': and_,
        '_or': or_,
    }

    # Negation keys (pg_graphql and Hasura spellings).
    NOT_OPERATORS = frozenset({'not', '_not'})

    def __init__(self, model: Any, graphql_type: Any, schema: Any) -> None:
        """
        Initializer.

        :param model: SQLModel model the query selects from
        :param graphql_type: Strawberry GraphQL type exposing the model's fields
        :param schema: Strawberry schema (its name converter decides camel-casing)
        """
        self.model = model
        self.graphql_type = graphql_type
        self.schema = schema
        # GraphQL field name -> model attribute name
        self.field_map = self.get_field_mapping(self.graphql_type, self.schema)

    def get_field_mapping(self, graphql_type: Any, schema: Any) -> dict:
        """
        Map the GraphQL field names of a Strawberry type to their Python/model field names.

        An explicit GraphQL name (``strawberry.field(name=...)``) is used verbatim. Otherwise
        the Python name is camel-cased when the schema's name converter has
        ``auto_camel_case`` enabled.

        :param graphql_type: Strawberry GraphQL type
        :param schema: Strawberry schema
        :return: mapping of GraphQL field name -> original field name
        """
        auto_camel_case = schema.config.name_converter.auto_camel_case
        result = {}
        for field in graphql_type.__strawberry_definition__.fields:
            explicit_name = getattr(field, 'graphql_name', None)
            if explicit_name:
                graphql_name = explicit_name
            elif auto_camel_case:
                graphql_name = to_camel_case(field.name)
            else:
                graphql_name = field.name
            result[graphql_name] = field.name
        return result

    def set_order_directions(self, field: Any, direction: str) -> Any:
        """
        Apply a sort direction (ASC/DESC plus NULLS FIRST/LAST) to a column.

        :param field: SQLAlchemy column attribute
        :param direction: a key of ``ORDER_DIRECTIONS`` (Hasura or pg_graphql spelling)
        :return: the ordering expression
        :raises ValueError: if the direction is unknown
        """
        # isinstance guard: an unhashable direction (e.g. a dict) would otherwise raise TypeError.
        spec = self.ORDER_DIRECTIONS.get(direction) if isinstance(direction, str) else None
        if spec is None:
            raise ValueError(f'Unknown order direction {direction}')
        ordered = field.asc() if spec['type'] == 'asc' else field.desc()
        return ordered.nulls_first() if spec['nulls'] == 'first' else ordered.nulls_last()

    def convert_value_to_field_type(self, field: Any, value: Any) -> Any:
        """
        Convert a filter value to the Python type of the column it is compared with.

        :param field: SQLAlchemy column attribute
        :param value: value taken from the JSON filter
        :return: converted value
        :raises ValueError: if the value cannot be converted
        """
        try:
            field_type = field.type.python_type
        except NotImplementedError:
            # The column type declares no Python type; leave the value for the DB driver.
            return value
        if type(value) is field_type:
            # Already the right type; e.g. datetime(datetime_obj) would raise TypeError.
            return value
        if field_type is bool and isinstance(value, str):
            # bool("false") is True, so parse boolean strings explicitly.
            lowered = value.strip().lower()
            if lowered not in ('true', 'false'):
                raise ValueError(f'Cannot convert {value} to type {field_type}')
            return lowered == 'true'
        try:
            if isinstance(value, str) and hasattr(field_type, 'fromisoformat'):
                # date / datetime / time cannot be constructed from a string directly.
                return field_type.fromisoformat(value)
            return field_type(value)
        except (ValueError, TypeError) as exc:
            raise ValueError(f'Cannot convert {value} to type {field_type}') from exc

    def convert_conditions(self, filters: dict, model: Any, field_map: dict) -> Any:
        """
        Parse a where/filter object into a single SQLAlchemy condition.

        Top-level keys are ANDed together. A boolean combinator whose value is a single
        object is treated as a one-element list, following GraphQL list input coercion.

        :param filters: filter dictionary
        :param model: SQLModel model
        :param field_map: GraphQL field name -> model attribute name
        :return: SQLAlchemy condition
        :raises ValueError: on malformed filters, unknown fields or unknown operators
        """
        if not isinstance(filters, dict) or not filters:
            raise ValueError('Invalid filter')
        conditions = []
        for key, value in filters.items():
            if key in self.BOOLEAN_OPERATORS:
                if isinstance(value, dict):
                    value = [value]
                if not isinstance(value, list) or not value:
                    raise ValueError(f"{key} operator must be a non-empty list.")
                sub_conditions = [self.convert_conditions(sub_filter, model, field_map) for sub_filter in value]
                conditions.append(self.BOOLEAN_OPERATORS[key](*sub_conditions))
            elif key in self.NOT_OPERATORS:
                if not isinstance(value, dict) or not value:
                    raise ValueError(f"{key} operator must be a non-empty dictionary.")
                sub_condition = self.convert_conditions(value, model, field_map)
                if not isinstance(sub_condition, ClauseElement):
                    raise TypeError("Expected SQLAlchemy condition, but got something else.")
                conditions.append(~sub_condition)
            elif isinstance(value, dict):
                conditions.extend(self._convert_comparisons(key, value, model, field_map))
            else:
                raise ValueError(f'Invalid filter: {key}')
        return and_(*conditions)

    def _convert_comparisons(self, key: str, comparisons: dict, model: Any, field_map: dict) -> list:
        """
        Build one SQLAlchemy condition per ``{operator: value}`` entry for field ``key``.

        :raises ValueError: on unknown field/operator, an empty comparison object, or a
            non-list value for a list operator
        """
        if key not in field_map:
            raise ValueError(f'Unknown field {key}')
        if not comparisons:
            # Would add no condition at all (and an argument-less and_() when it is the only key).
            raise ValueError(f'No comparison operator given for field {key}')
        field = getattr(model, field_map[key])
        conditions = []
        for raw_op, val in comparisons.items():
            op = raw_op.lstrip('_')  # Hasura prefixes operators with "_" (e.g. "_eq")
            spec = self.OPERATORS.get(op)
            if spec is None:
                raise ValueError(f'Unknown operator {op}')
            if op in self.LIST_OPERATORS:
                if not isinstance(val, list):
                    raise ValueError(f'{op} operator requires a list value')
                operand = [self.convert_value_to_field_type(field, item) for item in val] if spec['cast'] else val
            else:
                operand = self.convert_value_to_field_type(field, val) if spec['cast'] else val
            target = field if spec['fcast'] is None else field.cast(spec['fcast'])
            conditions.append(spec['call'](target, operand))
        return conditions

    def convert_query(self, where: Optional[dict] = None, filter: Optional[dict] = None,
                      order_by: Optional[Union[list, dict]] = None, limit: Optional[int] = None,
                      first: Optional[int] = None, offset: Optional[int] = None) -> Any:
        """
        Build a SQLAlchemy SELECT from Hasura- or pg_graphql-style arguments.

        When both styles are supplied, the pg_graphql arguments (``filter``, ``first``) win.

        :param where: WHERE part of the query (Hasura format)
        :param filter: filter part of the query (pg_graphql format)
        :param order_by: ORDER BY part: list of ``{field: direction}`` objects, or one such object
        :param limit: LIMIT part of the query (Hasura format)
        :param first: LIMIT part of the query (pg_graphql format)
        :param offset: OFFSET part of the query
        :return: SQLAlchemy select statement
        :raises ValueError: if validation fails or any argument is malformed
        """
        if filter is not None:
            where = filter
        if first is not None:
            limit = first
        # A single {field: direction} object stands for a one-element list (GraphQL coercion).
        if isinstance(order_by, dict):
            order_by = [order_by]

        if not self.validate(self.model, self.graphql_type, self.schema,
                             {'where': where, 'order_by': order_by, 'limit': limit, 'offset': offset}):
            raise ValueError('Invalid filter')

        statement = select(self.model)
        if where is not None:
            if not isinstance(where, dict):
                raise ValueError('Invalid filter')
            statement = statement.where(self.convert_conditions(where, self.model, self.field_map))
        if order_by is not None:
            statement = self._apply_order_by(statement, order_by)
        if limit is not None:
            if limit < 0:
                raise ValueError('limit must be greater than or equal to 0')
            statement = statement.limit(limit)
        if offset is not None:
            if offset < 0:
                raise ValueError('offset must be greater than or equal to 0')
            statement = statement.offset(offset)
        return statement

    def _apply_order_by(self, statement: Any, order_by: list) -> Any:
        """Append each ``{field: direction}`` entry of ``order_by`` to the statement, in order."""
        if not isinstance(order_by, list):
            raise ValueError('order_by must be a list')
        for item in order_by:
            if not isinstance(item, dict):
                raise ValueError('order_by items must be objects of the form {field: direction}')
            for key, direction in item.items():
                if key not in self.field_map:
                    raise ValueError(f'Unknown field {key}')
                field = getattr(self.model, self.field_map[key])
                statement = statement.order_by(self.set_order_directions(field, direction))
        return statement

    def validate(self, model: Any, graphql_type: Any, schema: Any, filter: dict) -> bool:
        """
        Validate the filter before it is converted. This default accepts everything;
        library users override it in a subclass.

        :param model: SQLModel model
        :param graphql_type: Strawberry GraphQL type
        :param schema: Strawberry schema
        :param filter: dict with keys 'where', 'order_by', 'limit', 'offset'
        :return: True if the filter is valid, False otherwise
        """
        return True