From 05bf3e6fb2f48c37f2a25858e74ff9cdc0884073 Mon Sep 17 00:00:00 2001 From: Thiago Bellini Ribeiro Date: Sat, 4 Nov 2023 15:42:47 -0300 Subject: [PATCH] feat: add sqlakeyset integration --- poetry.lock | 21 +- pyproject.toml | 1 + src/strawberry_sqlalchemy_mapper/field.py | 97 ++++++-- src/strawberry_sqlalchemy_mapper/relay.py | 105 ++++++++- tests/relay/test_connection.py | 270 ++++++++++++++++++++++ 5 files changed, 477 insertions(+), 17 deletions(-) diff --git a/poetry.lock b/poetry.lock index 0b91462..12cfe01 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.6.1 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.7.0 and should not be changed by hand. [[package]] name = "argcomplete" @@ -1168,6 +1168,23 @@ files = [ {file = "six-1.16.0.tar.gz", hash = "sha256:1e61c37477a1626458e36f7b1d82aa5c9b094fa4802892072e49de9c60c4c926"}, ] +[[package]] +name = "sqlakeyset" +version = "2.0.1695177552" +description = "offset-free paging for sqlalchemy" +optional = false +python-versions = ">=3.7,<4.0" +files = [ + {file = "sqlakeyset-2.0.1695177552-py3-none-any.whl", hash = "sha256:8086d3e8fc0e50f01325077b40e9e94ba9e53e8561f3749e9dbb49c0557186ee"}, + {file = "sqlakeyset-2.0.1695177552.tar.gz", hash = "sha256:cbb0864dd7d04c86debbfa5a338aea0a4d2385ef759b0210eb90fbae42a55ce3"}, +] + +[package.dependencies] +packaging = ">=20.0" +python-dateutil = "*" +sqlalchemy = ">=1.3.11" +typing_extensions = ">=4,<5" + [[package]] name = "sqlalchemy" version = "2.0.20" @@ -1397,4 +1414,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p [metadata] lock-version = "2.0" python-versions = "^3.8" -content-hash = "ced476e10b04eb5c6461ef3caee90d57b694da635c2561eddbf7d7d6b1c44308" +content-hash = "5b3c092d52d75696e0dffa05460203092826ec2fd30176ab1d3f8f75375526e2" diff --git a/pyproject.toml b/pyproject.toml index 04dfbfa..f25dc18 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,6 +44,7 @@ sqlalchemy = {extras = ["asyncio"], version = ">=1.4"} strawberry-graphql = ">=0.95" sentinel = ">=0.3,<1.1" greenlet = {version = ">=3.0.0rc1", python = ">=3.12"} +sqlakeyset = "^2.0.1695177552" [tool.poetry.group.dev.dependencies] asyncpg = "^0.28.0" diff --git a/src/strawberry_sqlalchemy_mapper/field.py b/src/strawberry_sqlalchemy_mapper/field.py index 632020b..65fd821 100644 --- a/src/strawberry_sqlalchemy_mapper/field.py +++ b/src/strawberry_sqlalchemy_mapper/field.py @@ -1,6 +1,8 @@ from __future__ import annotations import asyncio +import contextlib +import contextvars import dataclasses import inspect from collections import defaultdict @@ -9,6 +11,7 @@ Callable, DefaultDict, Dict, + Generator, Iterable, Iterator, List, @@ -25,6 +28,7 @@ ) from typing_extensions import Annotated, TypeAlias +from sqlakeyset.types import Keyset from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import Query, Session from strawberry import relay @@ -55,6 +59,25 @@ assert argument # type: ignore[truthy-function] +connection_session: contextvars.ContextVar[ + Union[Session, AsyncSession, None] +] = contextvars.ContextVar( + "connection-session", + default=None, +) + + +@contextlib.contextmanager +def set_connection_session( + s: Union[Session, AsyncSession, None] +) -> Generator[None, None, None]: + token = connection_session.set(s) + try: + yield + finally: + connection_session.reset(token) + + class StrawberrySQLAlchemyField(StrawberryField): """ Base field for SQLAlchemy types. @@ -63,9 +86,11 @@ class StrawberrySQLAlchemyField(StrawberryField): def __init__( self, sessionmaker: _SessionMaker | None = None, + keyset: Keyset | None = None, **kwargs, ): self.sessionmaker = sessionmaker + self.keyset = keyset super().__init__(**kwargs) @@ -260,6 +285,7 @@ class StrawberrySQLAlchemyConnectionExtension(relay.ConnectionExtension): def apply(self, field: StrawberrySQLAlchemyField) -> None: # type: ignore[override] from strawberry_sqlalchemy_mapper.mapper import StrawberrySQLAlchemyType + self.field = field strawberry_definition = get_object_definition(field.type, strict=True) node_type = strawberry_definition.type_var_map.get("NodeType") if node_type is None: @@ -295,35 +321,74 @@ def default_resolver( info: Info, **kwargs: Any, ) -> Iterable[Any]: - session = field_sessionmaker() + session = connection_session.get() + if session is None: + session = field_sessionmaker() - if isinstance(session, AsyncSession): + def _get_query(s: Session): if root is not None: - return cast( - Iterable[Any], - StrawberrySQLAlchemyAsyncQuery( - session=session, - query=getattr(root, field.python_name), - ), - ) + # root won't be None when resolving nested connections. + # TODO: Maybe we want to send this to a dataloader? + query = getattr(root, field.python_name) + else: + query = s.query(model) + + if field.keyset is not None: + query = query.order_by(*field.keyset) + return query + + if isinstance(session, AsyncSession): return cast( Iterable[Any], StrawberrySQLAlchemyAsyncQuery( session=session, - query=lambda s: s.query(model), + query=lambda s: _get_query(s), ), ) - if root is not None: - return getattr(root, field.python_name) - - return session.query(model) + return _get_query(session) field.base_resolver = StrawberryResolver(default_resolver) return super().apply(field) + def resolve(self, *args, **kwargs) -> Any: + if (field_sessionmaker := self.field.sessionmaker) is None: + raise TypeError(f"Missing `sessionmaker` argument for field {field.name}") + + session = field_sessionmaker() + + if isinstance(session, AsyncSession): + super_meth = super().resolve + + async def inner_resolve_async(): + async with session as s: + with set_connection_session(s): + retval = super_meth(*args, **kwargs) + if inspect.isawaitable(retval): + retval = await retval + return retval + + return inner_resolve_async() + + with session as s, set_connection_session(s): + return super().resolve(*args, **kwargs) + + async def resolve_async(self, *args, **kwargs) -> Any: + if (field_sessionmaker := self.field.sessionmaker) is None: + raise TypeError(f"Missing `sessionmaker` argument for field {field.name}") + + session = field_sessionmaker() + + if isinstance(session, AsyncSession): + async with session as s: + with set_connection_session(s): + return await super().resolve_async(*args, **kwargs) + + with session as s, set_connection_session(s): + return await super().resolve_async(*args, **kwargs) + @overload def field( @@ -527,6 +592,7 @@ def connection( directives: Sequence[object] | None = (), extensions: Sequence[FieldExtension] = (), sessionmaker: _SessionMaker | None = None, + keyset: Keyset | None = None, ) -> Any: ... @@ -549,6 +615,7 @@ def connection( directives: Sequence[object] | None = (), extensions: Sequence[FieldExtension] = (), sessionmaker: _SessionMaker | None = None, + keyset: Keyset | None = None, ) -> Any: ... @@ -569,6 +636,7 @@ def connection( directives: Sequence[object] | None = (), extensions: Sequence[FieldExtension] = (), sessionmaker: _SessionMaker | None = None, + keyset: Keyset | None = None, # This init parameter is used by pyright to determine whether this field # is added in the constructor or not. It is not used to change # any behavior at the moment. @@ -648,6 +716,7 @@ def connection( directives=directives or (), extensions=extensions, sessionmaker=sessionmaker, + keyset=keyset, ) if resolver: diff --git a/src/strawberry_sqlalchemy_mapper/relay.py b/src/strawberry_sqlalchemy_mapper/relay.py index 7e538e6..3777145 100644 --- a/src/strawberry_sqlalchemy_mapper/relay.py +++ b/src/strawberry_sqlalchemy_mapper/relay.py @@ -4,6 +4,7 @@ TYPE_CHECKING, Any, Iterable, + List, Optional, Type, TypeVar, @@ -12,19 +13,25 @@ overload, ) +import sqlakeyset +import strawberry from sqlalchemy import and_, or_ from sqlalchemy.exc import NoResultFound +from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.inspection import inspect as sqlalchemy_inspect from strawberry import relay from strawberry.relay.exceptions import NodeIDAnnotationError +from strawberry.relay.types import NodeType +from strawberry.type import StrawberryContainer, get_object_definition if TYPE_CHECKING: - from typing_extensions import Literal + from typing_extensions import Literal, Self from sqlalchemy.orm import Query, Session from strawberry.types.info import Info from strawberry.utils.await_maybe import AwaitableOrValue + from strawberry_sqlalchemy_mapper.field import StrawberrySQLAlchemyAsyncQuery from strawberry_sqlalchemy_mapper.mapper import ( WithStrawberrySQLAlchemyObjectDefinition, ) @@ -41,6 +48,102 @@ ] +@strawberry.type(description="An edge in a connection.") +class Edge(relay.Edge[NodeType]): + @classmethod + def resolve_edge(cls, node: NodeType, *, cursor: Any = None) -> Self: + return cls(cursor=cursor, node=node) + + +@strawberry.type(name="Connection", description="A connection to a list of items.") +class KeysetConnection(relay.Connection[NodeType]): + edges: List[Edge[NodeType]] = strawberry.field( + description="Contains the nodes in this connection", + ) + + @classmethod + def resolve_connection( + cls, + nodes: Union[Query, StrawberrySQLAlchemyAsyncQuery], + *, + info: Info, + before: Optional[str] = None, + after: Optional[str] = None, + first: Optional[int] = None, + last: Optional[int] = None, + **kwargs: Any, + ) -> AwaitableOrValue[Self]: + from .field import StrawberrySQLAlchemyAsyncQuery, connection_session + + if first and last: + raise ValueError("Cannot provide both `first` and `last`") + elif first and before: + raise ValueError("`first` cannot be provided with `before`") + elif last and after: + raise ValueError("`last` cannot be provided with `after`") + + max_results = info.schema.config.relay_max_results + per_page = first or last or max_results + if per_page > max_results: + raise ValueError(f"Argument 'last' cannot be higher than {max_results}.") + + session = connection_session.get() + assert session is not None + + def resolve_connection(page: sqlakeyset.Page): + type_def = get_object_definition(cls) + assert type_def + field_def = type_def.get_field("edges") + assert field_def + + field = field_def.resolve_type(type_definition=type_def) + while isinstance(field, StrawberryContainer): + field = field.of_type + + edge_class = cast(Edge[NodeType], field) + + return cls( + page_info=relay.PageInfo( + has_next_page=page.paging.has_next, + has_previous_page=page.paging.has_previous, + start_cursor=page.paging.get_bookmark_at(0) if page else None, + end_cursor=page.paging.get_bookmark_at(-1) if page else None, + ), + edges=[ + edge_class.resolve_edge(n, cursor=page.paging.get_bookmark_at(i)) + for i, n in enumerate(page) + ], + ) + + def resolve_nodes(s: Session, nodes=nodes): + if isinstance(nodes, StrawberrySQLAlchemyAsyncQuery): + nodes = nodes.query(s) + + return resolve_connection( + sqlakeyset.get_page( + nodes, + before=( + sqlakeyset.unserialize_bookmark(before).place + if before + else None + ), + after=( + sqlakeyset.unserialize_bookmark(after).place if after else None + ), + per_page=per_page, + ) + ) + + if isinstance(session, AsyncSession): + + async def resolve_async(nodes=nodes): + return await session.run_sync(lambda s: resolve_nodes(s)) + + return resolve_async() + + return resolve_nodes(session) + + @overload def resolve_model_nodes( source: Union[ diff --git a/tests/relay/test_connection.py b/tests/relay/test_connection.py index 0f838c0..31160c0 100644 --- a/tests/relay/test_connection.py +++ b/tests/relay/test_connection.py @@ -8,6 +8,7 @@ from sqlalchemy.orm import sessionmaker from strawberry import relay from strawberry_sqlalchemy_mapper import StrawberrySQLAlchemyMapper, connection +from strawberry_sqlalchemy_mapper.relay import KeysetConnection @pytest.fixture @@ -485,3 +486,272 @@ class Query: ] } } + + +def test_query_keyset( + base: Any, + engine: Engine, + sessionmaker: sessionmaker, + fruit_table, +): + base.metadata.create_all(engine) + mapper = StrawberrySQLAlchemyMapper() + + @mapper.type(fruit_table) + class Fruit(relay.Node): + id: relay.NodeID[int] + name: str + + @strawberry.type + class Query: + fruits: KeysetConnection[Fruit] = connection( + sessionmaker=sessionmaker, + keyset=(fruit_table.name,), + ) + + schema = strawberry.Schema(query=Query) + + query = """\ + query Fruits($first: Int, $after: String) { + fruits(first: $first, after: $after) { + pageInfo { + hasNextPage + hasPreviousPage + startCursor + endCursor + } + edges { + cursor + node { + name + } + } + } + } + """ + + with sessionmaker() as session: + f1 = fruit_table(name="Banana", color="Yellow") + f2 = fruit_table(name="Apple", color="Red") + f3 = fruit_table(name="Orange", color="Orange") + f4 = fruit_table(name="Mango", color="Orange") + f5 = fruit_table(name="Grape", color="Purple") + session.add_all([f1, f2, f3, f4, f5]) + session.commit() + + result = schema.execute_sync(query) + assert result.errors is None + assert result.data == { + "fruits": { + "edges": [ + { + "cursor": ">s:Apple", + "node": {"name": "Apple"}, + }, + { + "cursor": ">s:Banana", + "node": {"name": "Banana"}, + }, + { + "cursor": ">s:Grape", + "node": {"name": "Grape"}, + }, + { + "cursor": ">s:Mango", + "node": {"name": "Mango"}, + }, + { + "cursor": ">s:Orange", + "node": {"name": "Orange"}, + }, + ], + "pageInfo": { + "endCursor": ">s:Orange", + "hasNextPage": False, + "hasPreviousPage": False, + "startCursor": ">s:Apple", + }, + } + } + + result = schema.execute_sync(query, {"first": 2}) + assert result.errors is None + assert result.data == { + "fruits": { + "edges": [ + { + "cursor": ">s:Apple", + "node": {"name": "Apple"}, + }, + { + "cursor": ">s:Banana", + "node": {"name": "Banana"}, + }, + ], + "pageInfo": { + "endCursor": ">s:Banana", + "hasNextPage": True, + "hasPreviousPage": False, + "startCursor": ">s:Apple", + }, + } + } + + result = schema.execute_sync(query, {"first": 2, "after": ">s:Banana"}) + assert result.errors is None + assert result.data == { + "fruits": { + "edges": [ + { + "cursor": ">s:Grape", + "node": {"name": "Grape"}, + }, + { + "cursor": ">s:Mango", + "node": {"name": "Mango"}, + }, + ], + "pageInfo": { + "endCursor": ">s:Mango", + "hasNextPage": True, + "hasPreviousPage": True, + "startCursor": ">s:Grape", + }, + } + } + + +@pytest.mark.asyncio +async def test_query_keyset_async( + base: Any, + async_engine: AsyncEngine, + sessionmaker: sessionmaker, + async_sessionmaker, + fruit_table, +): + async with async_engine.begin() as conn: + await conn.run_sync(base.metadata.create_all) + mapper = StrawberrySQLAlchemyMapper() + + @mapper.type(fruit_table) + class Fruit(relay.Node): + id: relay.NodeID[int] + name: str + + @strawberry.type + class Query: + fruits: KeysetConnection[Fruit] = connection( + sessionmaker=async_sessionmaker, + keyset=(fruit_table.name,), + ) + + schema = strawberry.Schema(query=Query) + + query = """\ + query Fruits($first: Int, $after: String) { + fruits(first: $first, after: $after) { + pageInfo { + hasNextPage + hasPreviousPage + startCursor + endCursor + } + edges { + cursor + node { + name + } + } + } + } + """ + + async with async_sessionmaker(expire_on_commit=False) as session: + f1 = fruit_table(name="Banana", color="Yellow") + f2 = fruit_table(name="Apple", color="Red") + f3 = fruit_table(name="Orange", color="Orange") + f4 = fruit_table(name="Mango", color="Orange") + f5 = fruit_table(name="Grape", color="Purple") + session.add_all([f1, f2, f3, f4, f5]) + await session.commit() + + result = await schema.execute(query) + assert result.errors is None + assert result.data == { + "fruits": { + "edges": [ + { + "cursor": ">s:Apple", + "node": {"name": "Apple"}, + }, + { + "cursor": ">s:Banana", + "node": {"name": "Banana"}, + }, + { + "cursor": ">s:Grape", + "node": {"name": "Grape"}, + }, + { + "cursor": ">s:Mango", + "node": {"name": "Mango"}, + }, + { + "cursor": ">s:Orange", + "node": {"name": "Orange"}, + }, + ], + "pageInfo": { + "endCursor": ">s:Orange", + "hasNextPage": False, + "hasPreviousPage": False, + "startCursor": ">s:Apple", + }, + } + } + + result = await schema.execute(query, {"first": 2}) + assert result.errors is None + assert result.data == { + "fruits": { + "edges": [ + { + "cursor": ">s:Apple", + "node": {"name": "Apple"}, + }, + { + "cursor": ">s:Banana", + "node": {"name": "Banana"}, + }, + ], + "pageInfo": { + "endCursor": ">s:Banana", + "hasNextPage": True, + "hasPreviousPage": False, + "startCursor": ">s:Apple", + }, + } + } + + result = await schema.execute(query, {"first": 2, "after": ">s:Banana"}) + assert result.errors is None + assert result.data == { + "fruits": { + "edges": [ + { + "cursor": ">s:Grape", + "node": {"name": "Grape"}, + }, + { + "cursor": ">s:Mango", + "node": {"name": "Mango"}, + }, + ], + "pageInfo": { + "endCursor": ">s:Mango", + "hasNextPage": True, + "hasPreviousPage": True, + "startCursor": ">s:Grape", + }, + } + }