Skip to content

Commit

Permalink
Optional Pagination (#168)
Browse files Browse the repository at this point in the history
* Pre-commit-update

* Add option to exclude relay pagination

* Use standard resolver for exclude relay

* Clarify docstring

Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com>

* Documentation

* Add exclude relay test

* Add release.md

* Use property of class

* Restore relay.py

* Upd test to use list

* fix tests

* adding type ignore on return

* fix test

* use isinstance on test

---------

Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com>
Co-authored-by: jojo <[email protected]>
Co-authored-by: Ckk3 <[email protected]>
  • Loading branch information
4 people authored Nov 12, 2024
1 parent d1fc069 commit 7cd14e8
Show file tree
Hide file tree
Showing 5 changed files with 64 additions and 30 deletions.
8 changes: 4 additions & 4 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
repos:
- repo: https://github.com/psf/black
rev: 23.9.1
rev: 24.4.2
hooks:
- id: black
exclude: ^tests/\w+/snapshots/

- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.0.289
rev: v0.4.5
hooks:
- id: ruff
exclude: ^tests/\w+/snapshots/
Expand All @@ -18,13 +18,13 @@ repos:
exclude: (CHANGELOG|TWEET).md

- repo: https://github.com/pre-commit/mirrors-prettier
rev: v3.0.3
rev: v4.0.0-alpha.8
hooks:
- id: prettier
files: '^docs/.*\.mdx?$'

- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.4.0
rev: v4.6.0
hooks:
- id: trailing-whitespace
- id: check-merge-conflict
Expand Down
4 changes: 4 additions & 0 deletions RELEASE.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
Release type: minor

Add an optional function to exclude relationships from relay pagination and use traditional strawberry lists.
Default behavior preserves original behavior for backwords compatibilty.
27 changes: 18 additions & 9 deletions src/strawberry_sqlalchemy_mapper/mapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,13 +150,11 @@ class StrawberrySQLAlchemyType(Generic[BaseModelType]):

@overload
@classmethod
def from_type(cls, type_: type, *, strict: Literal[True]) -> Self:
...
def from_type(cls, type_: type, *, strict: Literal[True]) -> Self: ...

@overload
@classmethod
def from_type(cls, type_: type, *, strict: bool = False) -> Optional[Self]:
...
def from_type(cls, type_: type, *, strict: bool = False) -> Optional[Self]: ...

@classmethod
def from_type(
Expand Down Expand Up @@ -374,7 +372,7 @@ def _convert_column_to_strawberry_type(
return type_annotation

def _convert_relationship_to_strawberry_type(
self, relationship: RelationshipProperty
self, relationship: RelationshipProperty, use_list: bool = False
) -> Union[Type[Any], ForwardRef]:
"""
Given a SQLAlchemy relationship, return the type annotation for the field in the
Expand All @@ -387,6 +385,10 @@ def _convert_relationship_to_strawberry_type(
else:
self._related_type_models.add(relationship_model)
if relationship.uselist:
# Use list if excluding relay pagination
if use_list:
return List[ForwardRef(type_name)] # type: ignore

return self._connection_type_for(type_name)
else:
if self._get_relationship_is_optional(relationship):
Expand Down Expand Up @@ -524,14 +526,14 @@ async def resolve(self, info: Info):
return resolve

def connection_resolver_for(
self, relationship: RelationshipProperty
self, relationship: RelationshipProperty, use_list=False
) -> Callable[..., Awaitable[Any]]:
"""
Return an async field resolver for the given relationship that
returns a Connection instead of an array of objects.
"""
relationship_resolver = self.relationship_resolver_for(relationship)
if relationship.uselist:
if relationship.uselist and not use_list:
return self.make_connection_wrapper_resolver(
relationship_resolver,
self.model_to_type_or_interface_name(relationship.entity.entity), # type: ignore[arg-type]
Expand Down Expand Up @@ -666,6 +668,7 @@ def convert(type_: Any) -> Any:
generated_field_keys = []

excluded_keys = getattr(type_, "__exclude__", [])
list_keys = getattr(type_, "__use_list__", [])

# if the type inherits from another mapped type, then it may have
# generated resolvers. These will be treated by dataclasses as having
Expand All @@ -690,7 +693,8 @@ def convert(type_: Any) -> Any:
):
continue
strawberry_type = self._convert_relationship_to_strawberry_type(
relationship
relationship,
key in list_keys,
)
self._add_annotation(
type_,
Expand All @@ -700,7 +704,12 @@ def convert(type_: Any) -> Any:
)
sqlalchemy_field = cast(
StrawberryField,
field(resolver=self.connection_resolver_for(relationship)),
field(
resolver=self.connection_resolver_for(
relationship,
key in list_keys,
)
),
)
assert not sqlalchemy_field.init
setattr(
Expand Down
21 changes: 7 additions & 14 deletions src/strawberry_sqlalchemy_mapper/relay.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,8 +158,7 @@ def resolve_model_nodes(
info: Optional[Info] = None,
node_ids: Iterable[Union[str, relay.GlobalID]],
required: Literal[True],
) -> AwaitableOrValue[Iterable[_T]]:
...
) -> AwaitableOrValue[Iterable[_T]]: ...


@overload
Expand All @@ -174,8 +173,7 @@ def resolve_model_nodes(
info: Optional[Info] = None,
node_ids: None = None,
required: Literal[True],
) -> AwaitableOrValue[Iterable[_T]]:
...
) -> AwaitableOrValue[Iterable[_T]]: ...


@overload
Expand All @@ -190,8 +188,7 @@ def resolve_model_nodes(
info: Optional[Info] = None,
node_ids: Iterable[Union[str, relay.GlobalID]],
required: Literal[False],
) -> AwaitableOrValue[Iterable[Optional[_T]]]:
...
) -> AwaitableOrValue[Iterable[Optional[_T]]]: ...


@overload
Expand All @@ -206,8 +203,7 @@ def resolve_model_nodes(
info: Optional[Info] = None,
node_ids: None = None,
required: Literal[False],
) -> AwaitableOrValue[Optional[Iterable[_T]]]:
...
) -> AwaitableOrValue[Optional[Iterable[_T]]]: ...


@overload
Expand All @@ -229,8 +225,7 @@ def resolve_model_nodes(
Iterable[Optional[_T]],
Optional[Query[_T]],
]
]:
...
]: ...


def resolve_model_nodes(
Expand Down Expand Up @@ -307,8 +302,7 @@ def resolve_model_node(
session: Session,
info: Optional[Info] = ...,
required: Literal[False] = ...,
) -> AwaitableOrValue[Optional[_T]]:
...
) -> AwaitableOrValue[Optional[_T]]: ...


@overload
Expand All @@ -323,8 +317,7 @@ def resolve_model_node(
session: Session,
info: Optional[Info] = ...,
required: Literal[True],
) -> AwaitableOrValue[_T]:
...
) -> AwaitableOrValue[_T]: ...


def resolve_model_node(
Expand Down
34 changes: 31 additions & 3 deletions tests/test_mapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from sqlalchemy.dialects.postgresql.array import ARRAY
from sqlalchemy.orm import relationship
from strawberry.scalars import JSON as StrawberryJSON
from strawberry.types.base import StrawberryOptional
from strawberry.types.base import StrawberryList, StrawberryOptional
from strawberry_sqlalchemy_mapper import StrawberrySQLAlchemyMapper


Expand Down Expand Up @@ -263,6 +263,35 @@ class Lawyer:
assert {"Employee", "Lawyer"} == {t.__name__ for t in additional_types}


def test_use_list(employee_and_department_tables, mapper):
Employee, Department = employee_and_department_tables

@mapper.type(Employee)
class Employee:
pass

@mapper.type(Department)
class Department:
__use_list__ = ["employees"]

mapper.finalize()
additional_types = list(mapper.mapped_types.values())
assert len(additional_types) == 2
mapped_employee_type = additional_types[0]
assert mapped_employee_type.__name__ == "Employee"
mapped_department_type = additional_types[1]
assert mapped_department_type.__name__ == "Department"
assert len(mapped_department_type.__strawberry_definition__.fields) == 3
department_type_fields = mapped_department_type.__strawberry_definition__.fields

name = next(
(field for field in department_type_fields if field.name == "employees"), None
)
assert name is not None
assert isinstance(name.type, StrawberryOptional) is False
assert isinstance(name.type, StrawberryList) is True


def test_type_relationships(employee_and_department_tables, mapper):
Employee, _ = employee_and_department_tables

Expand Down Expand Up @@ -297,8 +326,7 @@ class Department:
@strawberry.type
class Query:
@strawberry.field
def departments(self) -> Department:
...
def departments(self) -> Department: ...

mapper.finalize()
schema = strawberry.Schema(query=Query)
Expand Down

0 comments on commit 7cd14e8

Please sign in to comment.