from typing import Any, Optional

from sqlalchemy import String, and_, or_
from sqlalchemy.sql.elements import ClauseElement
from sqlmodel import select
from strawberry.utils.str_converters import to_camel_case
class QueryFilter:
    """
    Convert JSON filters in Hasura or pg_graphql (Supabase) format into SQLAlchemy queries.
    """

    # Comparison operators, keyed by name without the Hasura leading underscore.
    #   call  - builds the SQLAlchemy expression from (field, value)
    #   cast  - when True, the value is first converted to the column's Python type
    #   fcast - when not None, the column is cast to this SQL type before 'call' is applied
    OPERATORS = {
        'eq': {'call': lambda field, value: field == value, 'cast': True, 'fcast': None},
        'neq': {'call': lambda field, value: field != value, 'cast': True, 'fcast': None},
        'gt': {'call': lambda field, value: field > value, 'cast': True, 'fcast': None},
        'lt': {'call': lambda field, value: field < value, 'cast': True, 'fcast': None},
        'gte': {'call': lambda field, value: field >= value, 'cast': True, 'fcast': None},
        'lte': {'call': lambda field, value: field <= value, 'cast': True, 'fcast': None},
        'in': {'call': lambda field, value: field.in_(value), 'cast': False, 'fcast': None},
        'contains': {'call': lambda field, value: field.contains(value), 'cast': False, 'fcast': String},
        'like': {'call': lambda field, value: field.like(value), 'cast': False, 'fcast': String},
        'regexp': {'call': lambda field, value: field.regexp_match(value), 'cast': False, 'fcast': String},
    }

    # Sort order directions. Plain 'asc' / 'desc' place NULLs last.
    ORDER_DIRECTIONS = {
        # Hasura format
        'asc': {'type': 'asc', 'nulls': 'last'},
        'desc': {'type': 'desc', 'nulls': 'last'},
        'asc_nulls_first': {'type': 'asc', 'nulls': 'first'},
        'asc_nulls_last': {'type': 'asc', 'nulls': 'last'},
        'desc_nulls_first': {'type': 'desc', 'nulls': 'first'},
        'desc_nulls_last': {'type': 'desc', 'nulls': 'last'},
        # pg_graphql format
        'AscNullsFirst': {'type': 'asc', 'nulls': 'first'},
        'AscNullsLast': {'type': 'asc', 'nulls': 'last'},
        'DescNullsFirst': {'type': 'desc', 'nulls': 'first'},
        'DescNullsLast': {'type': 'desc', 'nulls': 'last'},
    }

    # Boolean operators (pg_graphql spelling and Hasura underscore-prefixed spelling).
    BOOLEAN_OPERATORS = {
        'and': and_,
        'or': or_,
        '_and': and_,
        '_or': or_,
    }

    def __init__(self, model: Any, graphql_type: Any, schema: Any) -> None:
        """
        Store the model/type/schema triple and precompute the field-name mapping.

        :param model: SQLModel model
        :param graphql_type: Strawberry GraphQL type
        :param schema: Strawberry schema
        """
        # Initialize variables
        self.model = model
        self.graphql_type = graphql_type
        self.schema = schema
        # Get field mapping (GraphQL-facing name -> Python attribute name)
        self.field_map = self.get_field_mapping(self.graphql_type, self.schema)

    def get_field_mapping(self, graphql_type: Any, schema: Any) -> dict:
        """
        Build the mapping from GraphQL field names to the original field names.

        :param graphql_type: Strawberry GraphQL type
        :param schema: Strawberry schema
        :return: Mapping between converted and original field names
        """
        # The schema setting does not change per field, so read it once.
        auto_camel_case = schema.config.name_converter.auto_camel_case
        return {
            (to_camel_case(field.name) if auto_camel_case else field.name): field.name
            for field in graphql_type.__strawberry_definition__.fields
        }

    def set_order_directions(self, field: Any, direction: str) -> Any:
        """
        Apply a sort direction and NULL placement to a field.

        :param field: SQLAlchemy field
        :param direction: Sort direction (a key of ORDER_DIRECTIONS)
        :return: SQLAlchemy field with sort direction
        :raises ValueError: If the direction is not a key of ORDER_DIRECTIONS
        """
        if direction not in self.ORDER_DIRECTIONS:
            raise ValueError(f'Unknown order direction {direction}')
        spec = self.ORDER_DIRECTIONS[direction]
        ordered = field.asc() if spec['type'] == 'asc' else field.desc()
        return ordered.nulls_first() if spec['nulls'] == 'first' else ordered.nulls_last()

    def convert_value_to_field_type(self, field: Any, value: Any) -> Any:
        """
        Convert a filter value to the Python type of the given field.

        :param field: SQLAlchemy field
        :param value: Value to convert
        :return: Converted value
        :raises ValueError: If the value cannot be converted
        """
        # Resolve the type outside the try block so the error message below can
        # always reference it.
        field_type = field.type.python_type  # Get SQLAlchemy type
        try:
            return field_type(value)  # type conversion
        except (ValueError, TypeError) as exc:
            raise ValueError(f'Cannot convert {value} to type {field_type}') from exc

    def convert_conditions(self, filters: dict, model: Any, field_map: dict) -> Any:
        """
        Recursively convert a filter dictionary into a single SQLAlchemy condition.

        Every entry of ``filters`` contributes one or more conditions, and all of
        them are combined with AND.

        :param filters: Filter dictionary
        :param model: SQLModel model
        :param field_map: Field mapping
        :return: SQLAlchemy condition combining all entries with AND
        :raises ValueError: If the filter is malformed or names an unknown field/operator
        :raises TypeError: If a NOT sub-filter does not produce a SQLAlchemy condition
        """
        if not isinstance(filters, dict) or not filters:
            raise ValueError('Invalid filter')

        conditions = []
        for key, value in filters.items():
            if key in self.BOOLEAN_OPERATORS:
                # AND/OR operator processing
                if not isinstance(value, list) or not value:
                    raise ValueError(f"{key} operator must be a non-empty list.")
                sub_conditions = [self.convert_conditions(sub_filter, model, field_map) for sub_filter in value]
                # Combine the group; otherwise its sub-filters would be silently dropped.
                conditions.append(self.BOOLEAN_OPERATORS[key](*sub_conditions))
            elif key in {'_not', 'not'}:
                # NOT operator processing
                if not isinstance(value, dict) or not value:
                    raise ValueError(f"{key} operator must be a non-empty dictionary.")
                sub_condition = self.convert_conditions(value, model, field_map)
                if not isinstance(sub_condition, ClauseElement):
                    raise TypeError("Expected SQLAlchemy condition, but got something else.")
                # '~' is SQLAlchemy's negation operator for clause elements.
                conditions.append(~sub_condition)
            elif isinstance(value, dict):
                # Comparison operator processing
                if key not in field_map:
                    raise ValueError(f'Unknown field {key}')
                field = getattr(model, field_map[key])
                for op, val in value.items():
                    op = op.lstrip('_')  # Remove leading underscores (Hasura format)
                    if op not in self.OPERATORS:
                        raise ValueError(f'Unknown operator {op}')
                    spec = self.OPERATORS[op]
                    convert_val = self.convert_value_to_field_type(field, val) if spec['cast'] else val
                    # Exactly one condition per operator: cast the column first when
                    # the operator requires it, otherwise use the column as-is.
                    target = field.cast(spec['fcast']) if spec['fcast'] is not None else field
                    conditions.append(spec['call'](target, convert_val))
            else:
                raise ValueError(f'Invalid filter: {key}')
        return and_(*conditions)

    def convert_query(
        self,
        where: Optional[dict] = None,
        filter: Optional[dict] = None,
        order_by: Optional[list] = None,
        limit: Optional[int] = None,
        first: Optional[int] = None,
        offset: Optional[int] = None,
    ) -> Any:
        """
        Convert a JSON filter in Hasura or pg_graphql format into a SQLAlchemy query.

        When both spellings of an argument are given, the pg_graphql one
        (``filter`` / ``first``) takes precedence.

        :param where: WHERE part of the query (Hasura format)
        :param filter: Filter part of the query (pg_graphql format)
        :param order_by: ORDER BY part of the query (list of {field: direction} dictionaries)
        :param limit: LIMIT part of the query (Hasura format)
        :param first: LIMIT part of the query (pg_graphql format)
        :param offset: OFFSET part of the query
        :return: SQLAlchemy query (the statement built from ``select(self.model)``)
        :raises ValueError: If any part of the query is invalid
        """
        # Convert arguments for pg_graphql format to Hasura format arguments
        # ('filter' to 'where' and 'first' to 'limit')
        if filter is not None:
            where = filter
        if first is not None:
            limit = first

        # Filter Validation
        if not self.validate(self.model, self.graphql_type, self.schema, {'where': where, 'order_by': order_by, 'limit': limit, 'offset': offset}):
            raise ValueError('Invalid filter')

        # Base query
        statement = select(self.model)

        if where is not None:
            if not isinstance(where, dict):
                raise ValueError('Invalid filter')
            statement = statement.where(self.convert_conditions(where, self.model, self.field_map))

        if order_by is not None:
            if not isinstance(order_by, list):
                raise ValueError('order_by must be a list')
            for item in order_by:
                for key, direction in item.items():
                    if key not in self.field_map:
                        raise ValueError(f'Unknown field {key}')
                    field = getattr(self.model, self.field_map[key])
                    statement = statement.order_by(self.set_order_directions(field, direction))

        if limit is not None:
            if limit < 0:
                raise ValueError('limit must be greater than or equal to 0')
            statement = statement.limit(limit)

        if offset is not None:
            if offset < 0:
                raise ValueError('offset must be greater than or equal to 0')
            statement = statement.offset(offset)

        return statement

    def validate(self, model: Any, graphql_type: Any, schema: Any, filter: dict) -> bool:
        """
        Validate a query description before conversion.

        NOTE: this is currently a placeholder that accepts every input.

        :param model: SQLModel model
        :param graphql_type: Strawberry GraphQL type
        :param schema: Strawberry schema
        :param filter: Filter dictionary
        :return: True if the filter is valid, False otherwise (always True for now)
        """
        return True