Skip to content

Commit

Permalink
feat: add sqlakeyset integration
Browse files Browse the repository at this point in the history
  • Loading branch information
bellini666 committed Nov 4, 2023
1 parent 9ff8723 commit 05bf3e6
Show file tree
Hide file tree
Showing 5 changed files with 477 additions and 17 deletions.
21 changes: 19 additions & 2 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ sqlalchemy = {extras = ["asyncio"], version = ">=1.4"}
strawberry-graphql = ">=0.95"
sentinel = ">=0.3,<1.1"
greenlet = {version = ">=3.0.0rc1", python = ">=3.12"}
sqlakeyset = "^2.0.1695177552"

[tool.poetry.group.dev.dependencies]
asyncpg = "^0.28.0"
Expand Down
97 changes: 83 additions & 14 deletions src/strawberry_sqlalchemy_mapper/field.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from __future__ import annotations

import asyncio
import contextlib
import contextvars
import dataclasses
import inspect
from collections import defaultdict
Expand All @@ -9,6 +11,7 @@
Callable,
DefaultDict,
Dict,
Generator,
Iterable,
Iterator,
List,
Expand All @@ -25,6 +28,7 @@
)
from typing_extensions import Annotated, TypeAlias

from sqlakeyset.types import Keyset
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Query, Session
from strawberry import relay
Expand Down Expand Up @@ -55,6 +59,25 @@
assert argument # type: ignore[truthy-function]


connection_session: contextvars.ContextVar[
Union[Session, AsyncSession, None]
] = contextvars.ContextVar(
"connection-session",
default=None,
)


@contextlib.contextmanager
def set_connection_session(
s: Union[Session, AsyncSession, None]
) -> Generator[None, None, None]:
token = connection_session.set(s)
try:
yield
finally:
connection_session.reset(token)


class StrawberrySQLAlchemyField(StrawberryField):
"""
Base field for SQLAlchemy types.
Expand All @@ -63,9 +86,11 @@ class StrawberrySQLAlchemyField(StrawberryField):
def __init__(
self,
sessionmaker: _SessionMaker | None = None,
keyset: Keyset | None = None,
**kwargs,
):
self.sessionmaker = sessionmaker
self.keyset = keyset
super().__init__(**kwargs)


Expand Down Expand Up @@ -260,6 +285,7 @@ class StrawberrySQLAlchemyConnectionExtension(relay.ConnectionExtension):
def apply(self, field: StrawberrySQLAlchemyField) -> None: # type: ignore[override]
from strawberry_sqlalchemy_mapper.mapper import StrawberrySQLAlchemyType

self.field = field
strawberry_definition = get_object_definition(field.type, strict=True)
node_type = strawberry_definition.type_var_map.get("NodeType")
if node_type is None:
Expand Down Expand Up @@ -295,35 +321,74 @@ def default_resolver(
info: Info,
**kwargs: Any,
) -> Iterable[Any]:
session = field_sessionmaker()
session = connection_session.get()
if session is None:
session = field_sessionmaker()

if isinstance(session, AsyncSession):
def _get_query(s: Session):
if root is not None:
return cast(
Iterable[Any],
StrawberrySQLAlchemyAsyncQuery(
session=session,
query=getattr(root, field.python_name),
),
)
# root won't be None when resolving nested connections.
# TODO: Maybe we want to send this to a dataloader?
query = getattr(root, field.python_name)
else:
query = s.query(model)

if field.keyset is not None:
query = query.order_by(*field.keyset)

return query

if isinstance(session, AsyncSession):
return cast(
Iterable[Any],
StrawberrySQLAlchemyAsyncQuery(
session=session,
query=lambda s: s.query(model),
query=lambda s: _get_query(s),
),
)

if root is not None:
return getattr(root, field.python_name)

return session.query(model)
return _get_query(session)

field.base_resolver = StrawberryResolver(default_resolver)

return super().apply(field)

def resolve(self, *args, **kwargs) -> Any:
if (field_sessionmaker := self.field.sessionmaker) is None:
raise TypeError(f"Missing `sessionmaker` argument for field {field.name}")

session = field_sessionmaker()

if isinstance(session, AsyncSession):
super_meth = super().resolve

async def inner_resolve_async():
async with session as s:
with set_connection_session(s):
retval = super_meth(*args, **kwargs)
if inspect.isawaitable(retval):
retval = await retval
return retval

return inner_resolve_async()

with session as s, set_connection_session(s):
return super().resolve(*args, **kwargs)

async def resolve_async(self, *args, **kwargs) -> Any:
if (field_sessionmaker := self.field.sessionmaker) is None:
raise TypeError(f"Missing `sessionmaker` argument for field {field.name}")

session = field_sessionmaker()

if isinstance(session, AsyncSession):
async with session as s:
with set_connection_session(s):
return await super().resolve_async(*args, **kwargs)

with session as s, set_connection_session(s):
return await super().resolve_async(*args, **kwargs)


@overload
def field(
Expand Down Expand Up @@ -527,6 +592,7 @@ def connection(
directives: Sequence[object] | None = (),
extensions: Sequence[FieldExtension] = (),
sessionmaker: _SessionMaker | None = None,
keyset: Keyset | None = None,
) -> Any:
...

Expand All @@ -549,6 +615,7 @@ def connection(
directives: Sequence[object] | None = (),
extensions: Sequence[FieldExtension] = (),
sessionmaker: _SessionMaker | None = None,
keyset: Keyset | None = None,
) -> Any:
...

Expand All @@ -569,6 +636,7 @@ def connection(
directives: Sequence[object] | None = (),
extensions: Sequence[FieldExtension] = (),
sessionmaker: _SessionMaker | None = None,
keyset: Keyset | None = None,
# This init parameter is used by pyright to determine whether this field
# is added in the constructor or not. It is not used to change
# any behavior at the moment.
Expand Down Expand Up @@ -648,6 +716,7 @@ def connection(
directives=directives or (),
extensions=extensions,
sessionmaker=sessionmaker,
keyset=keyset,
)

if resolver:
Expand Down
105 changes: 104 additions & 1 deletion src/strawberry_sqlalchemy_mapper/relay.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
TYPE_CHECKING,
Any,
Iterable,
List,
Optional,
Type,
TypeVar,
Expand All @@ -12,19 +13,25 @@
overload,
)

import sqlakeyset
import strawberry
from sqlalchemy import and_, or_
from sqlalchemy.exc import NoResultFound
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.inspection import inspect as sqlalchemy_inspect
from strawberry import relay
from strawberry.relay.exceptions import NodeIDAnnotationError
from strawberry.relay.types import NodeType
from strawberry.type import StrawberryContainer, get_object_definition

if TYPE_CHECKING:
from typing_extensions import Literal
from typing_extensions import Literal, Self

from sqlalchemy.orm import Query, Session
from strawberry.types.info import Info
from strawberry.utils.await_maybe import AwaitableOrValue

from strawberry_sqlalchemy_mapper.field import StrawberrySQLAlchemyAsyncQuery
from strawberry_sqlalchemy_mapper.mapper import (
WithStrawberrySQLAlchemyObjectDefinition,
)
Expand All @@ -41,6 +48,102 @@
]


@strawberry.type(description="An edge in a connection.")
class Edge(relay.Edge[NodeType]):
@classmethod
def resolve_edge(cls, node: NodeType, *, cursor: Any = None) -> Self:
return cls(cursor=cursor, node=node)


@strawberry.type(name="Connection", description="A connection to a list of items.")
class KeysetConnection(relay.Connection[NodeType]):
edges: List[Edge[NodeType]] = strawberry.field(
description="Contains the nodes in this connection",
)

@classmethod
def resolve_connection(
cls,
nodes: Union[Query, StrawberrySQLAlchemyAsyncQuery],
*,
info: Info,
before: Optional[str] = None,
after: Optional[str] = None,
first: Optional[int] = None,
last: Optional[int] = None,
**kwargs: Any,
) -> AwaitableOrValue[Self]:
from .field import StrawberrySQLAlchemyAsyncQuery, connection_session

if first and last:
raise ValueError("Cannot provide both `first` and `last`")
elif first and before:
raise ValueError("`first` cannot be provided with `before`")
elif last and after:
raise ValueError("`last` cannot be provided with `after`")

max_results = info.schema.config.relay_max_results
per_page = first or last or max_results
if per_page > max_results:
raise ValueError(f"Argument 'last' cannot be higher than {max_results}.")

session = connection_session.get()
assert session is not None

def resolve_connection(page: sqlakeyset.Page):
type_def = get_object_definition(cls)
assert type_def
field_def = type_def.get_field("edges")
assert field_def

field = field_def.resolve_type(type_definition=type_def)
while isinstance(field, StrawberryContainer):
field = field.of_type

edge_class = cast(Edge[NodeType], field)

return cls(
page_info=relay.PageInfo(
has_next_page=page.paging.has_next,
has_previous_page=page.paging.has_previous,
start_cursor=page.paging.get_bookmark_at(0) if page else None,
end_cursor=page.paging.get_bookmark_at(-1) if page else None,
),
edges=[
edge_class.resolve_edge(n, cursor=page.paging.get_bookmark_at(i))
for i, n in enumerate(page)
],
)

def resolve_nodes(s: Session, nodes=nodes):
if isinstance(nodes, StrawberrySQLAlchemyAsyncQuery):
nodes = nodes.query(s)

return resolve_connection(
sqlakeyset.get_page(
nodes,
before=(
sqlakeyset.unserialize_bookmark(before).place
if before
else None
),
after=(
sqlakeyset.unserialize_bookmark(after).place if after else None
),
per_page=per_page,
)
)

if isinstance(session, AsyncSession):

async def resolve_async(nodes=nodes):
return await session.run_sync(lambda s: resolve_nodes(s))

return resolve_async()

return resolve_nodes(session)


@overload
def resolve_model_nodes(
source: Union[
Expand Down
Loading

0 comments on commit 05bf3e6

Please sign in to comment.