Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optional Pagination #168

Merged
merged 15 commits into from
Nov 12, 2024
8 changes: 4 additions & 4 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
repos:
- repo: https://github.com/psf/black
rev: 23.9.1
rev: 24.4.2
hooks:
- id: black
exclude: ^tests/\w+/snapshots/

- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.0.289
rev: v0.4.5
hooks:
- id: ruff
exclude: ^tests/\w+/snapshots/
Expand All @@ -18,13 +18,13 @@ repos:
exclude: (CHANGELOG|TWEET).md

- repo: https://github.com/pre-commit/mirrors-prettier
rev: v3.0.3
rev: v4.0.0-alpha.8
hooks:
- id: prettier
files: '^docs/.*\.mdx?$'

- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.4.0
rev: v4.6.0
hooks:
- id: trailing-whitespace
- id: check-merge-conflict
Expand Down
4 changes: 4 additions & 0 deletions RELEASE.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
Release type: minor

Add an optional function to exclude relationships from relay pagination and use traditional strawberry lists.
Default behavior preserves original behavior for backwords compatibilty.
27 changes: 18 additions & 9 deletions src/strawberry_sqlalchemy_mapper/mapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,13 +150,11 @@ class StrawberrySQLAlchemyType(Generic[BaseModelType]):

@overload
@classmethod
def from_type(cls, type_: type, *, strict: Literal[True]) -> Self:
...
def from_type(cls, type_: type, *, strict: Literal[True]) -> Self: ...

@overload
@classmethod
def from_type(cls, type_: type, *, strict: bool = False) -> Optional[Self]:
...
def from_type(cls, type_: type, *, strict: bool = False) -> Optional[Self]: ...

@classmethod
def from_type(
Expand Down Expand Up @@ -374,7 +372,7 @@ def _convert_column_to_strawberry_type(
return type_annotation

def _convert_relationship_to_strawberry_type(
self, relationship: RelationshipProperty
self, relationship: RelationshipProperty, use_list: bool = False
) -> Union[Type[Any], ForwardRef]:
"""
Given a SQLAlchemy relationship, return the type annotation for the field in the
Expand All @@ -387,6 +385,10 @@ def _convert_relationship_to_strawberry_type(
else:
self._related_type_models.add(relationship_model)
if relationship.uselist:
# Use list if excluding relay pagination
if use_list:
return List[ForwardRef(type_name)]

return self._connection_type_for(type_name)
else:
if self._get_relationship_is_optional(relationship):
Expand Down Expand Up @@ -524,14 +526,14 @@ async def resolve(self, info: Info):
return resolve

def connection_resolver_for(
self, relationship: RelationshipProperty
self, relationship: RelationshipProperty, use_list=False
) -> Callable[..., Awaitable[Any]]:
"""
Return an async field resolver for the given relationship that
returns a Connection instead of an array of objects.
"""
relationship_resolver = self.relationship_resolver_for(relationship)
if relationship.uselist:
if relationship.uselist and not use_list:
return self.make_connection_wrapper_resolver(
relationship_resolver,
self.model_to_type_or_interface_name(relationship.entity.entity), # type: ignore[arg-type]
Expand Down Expand Up @@ -666,6 +668,7 @@ def convert(type_: Any) -> Any:
generated_field_keys = []

excluded_keys = getattr(type_, "__exclude__", [])
list_keys = getattr(type_, "__use_list__", [])

# if the type inherits from another mapped type, then it may have
# generated resolvers. These will be treated by dataclasses as having
Expand All @@ -690,7 +693,8 @@ def convert(type_: Any) -> Any:
):
continue
strawberry_type = self._convert_relationship_to_strawberry_type(
relationship
relationship,
key in list_keys,
)
self._add_annotation(
type_,
Expand All @@ -700,7 +704,12 @@ def convert(type_: Any) -> Any:
)
sqlalchemy_field = cast(
StrawberryField,
field(resolver=self.connection_resolver_for(relationship)),
field(
resolver=self.connection_resolver_for(
relationship,
key in list_keys,
)
),
)
assert not sqlalchemy_field.init
setattr(
Expand Down
21 changes: 7 additions & 14 deletions src/strawberry_sqlalchemy_mapper/relay.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,8 +158,7 @@ def resolve_model_nodes(
info: Optional[Info] = None,
node_ids: Iterable[Union[str, relay.GlobalID]],
required: Literal[True],
) -> AwaitableOrValue[Iterable[_T]]:
...
) -> AwaitableOrValue[Iterable[_T]]: ...


@overload
Expand All @@ -174,8 +173,7 @@ def resolve_model_nodes(
info: Optional[Info] = None,
node_ids: None = None,
required: Literal[True],
) -> AwaitableOrValue[Iterable[_T]]:
...
) -> AwaitableOrValue[Iterable[_T]]: ...


@overload
Expand All @@ -190,8 +188,7 @@ def resolve_model_nodes(
info: Optional[Info] = None,
node_ids: Iterable[Union[str, relay.GlobalID]],
required: Literal[False],
) -> AwaitableOrValue[Iterable[Optional[_T]]]:
...
) -> AwaitableOrValue[Iterable[Optional[_T]]]: ...


@overload
Expand All @@ -206,8 +203,7 @@ def resolve_model_nodes(
info: Optional[Info] = None,
node_ids: None = None,
required: Literal[False],
) -> AwaitableOrValue[Optional[Iterable[_T]]]:
...
) -> AwaitableOrValue[Optional[Iterable[_T]]]: ...


@overload
Expand All @@ -229,8 +225,7 @@ def resolve_model_nodes(
Iterable[Optional[_T]],
Optional[Query[_T]],
]
]:
...
]: ...


def resolve_model_nodes(
Expand Down Expand Up @@ -307,8 +302,7 @@ def resolve_model_node(
session: Session,
info: Optional[Info] = ...,
required: Literal[False] = ...,
) -> AwaitableOrValue[Optional[_T]]:
...
) -> AwaitableOrValue[Optional[_T]]: ...


@overload
Expand All @@ -323,8 +317,7 @@ def resolve_model_node(
session: Session,
info: Optional[Info] = ...,
required: Literal[True],
) -> AwaitableOrValue[_T]:
...
) -> AwaitableOrValue[_T]: ...


def resolve_model_node(
Expand Down
28 changes: 26 additions & 2 deletions tests/test_mapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,31 @@ class Lawyer:
assert {"Employee", "Lawyer"} == {t.__name__ for t in additional_types}


def test_use_list(employee_and_department_tables, mapper):
Employee, Department = employee_and_department_tables

@mapper.type(Employee)
class Employee:
pass

@mapper.type(Department)
class Department:
__use_list__ = ["employees"]

mapper.finalize()
additional_types = list(mapper.mapped_types.values())
assert len(additional_types) == 2
mapped_employee_type = additional_types[0]
assert mapped_employee_type.__name__ == "Employee"
mapped_department_type = additional_types[1]
assert mapped_department_type.__name__ == "Department"
assert len(mapped_department_type.__strawberry_definition__._fields) == 3
department_type_fields = mapped_department_type.__strawberry_definition__._fields
name = next(iter(filter(lambda f: f.name == "employees", department_type_fields)))
assert type(name.type) != StrawberryOptional
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@fruitymedley, I believe the comparison should be done using isinstance(name.type, StrawberryOptional) instead of using !=. This aligns with the conventions we're following in the codebase when checking if a type is StrawberryOptional.

I came across a couple of instances in the code that confirm this pattern:

Here, we're using isinstance to check for StrawberryOptional in the type_.of_type.
Similarly, here, the same pattern is used for type_.

What do you think?

assert type(name.type) == List[mapped_employee_type]


def test_type_relationships(employee_and_department_tables, mapper):
Employee, _ = employee_and_department_tables

Expand Down Expand Up @@ -297,8 +322,7 @@ class Department:
@strawberry.type
class Query:
@strawberry.field
def departments(self) -> Department:
...
def departments(self) -> Department: ...

mapper.finalize()
schema = strawberry.Schema(query=Query)
Expand Down
Loading