-
-
Notifications
You must be signed in to change notification settings - Fork 387
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
fix: DTO factory narrowed with a generic alias. #2791
base: main
Are you sure you want to change the base?
Changes from 1 commit
2f2df02
7c318ec
a02e76d
26f7dfc
2226f5e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,6 +3,8 @@ | |
from dataclasses import MISSING, fields, replace | ||
from typing import TYPE_CHECKING, Generic, TypeVar | ||
|
||
from typing_extensions import get_origin | ||
|
||
from litestar.dto.base_dto import AbstractDTO | ||
from litestar.dto.data_structures import DTOFieldDefinition | ||
from litestar.dto.field import DTO_FIELD_META_KEY, DTOField | ||
|
@@ -29,7 +31,8 @@ class DataclassDTO(AbstractDTO[T], Generic[T]): | |
def generate_field_definitions( | ||
cls, model_type: type[DataclassProtocol] | ||
) -> Generator[DTOFieldDefinition, None, None]: | ||
dc_fields = {f.name: f for f in fields(model_type)} | ||
model_origin = get_origin(model_type) or model_type | ||
dc_fields = {f.name: f for f in fields(model_origin)} | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We'll have to test this against all of the other dto factory types that support generics too, b/c there is bound to be cases where the |
||
for key, field_definition in cls.get_model_type_hints(model_type).items(): | ||
if not (dc_field := dc_fields.get(key)): | ||
continue | ||
|
@@ -41,7 +44,7 @@ def generate_field_definitions( | |
field_definition=field_definition, | ||
default_factory=default_factory, | ||
dto_field=dc_field.metadata.get(DTO_FIELD_META_KEY, DTOField()), | ||
model_name=model_type.__name__, | ||
model_name=model_origin.__name__, | ||
), | ||
name=key, | ||
default=default, | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,7 +4,21 @@ | |
from copy import deepcopy | ||
from dataclasses import dataclass, is_dataclass, replace | ||
from inspect import Parameter, Signature | ||
from typing import Any, AnyStr, Callable, Collection, ForwardRef, Literal, Mapping, Protocol, Sequence, TypeVar, cast | ||
from typing import ( # type: ignore[attr-defined] | ||
Any, | ||
AnyStr, | ||
Callable, | ||
ClassVar, | ||
Collection, | ||
ForwardRef, | ||
Literal, | ||
Mapping, | ||
Protocol, | ||
Sequence, | ||
TypeVar, | ||
_GenericAlias, # pyright: ignore | ||
cast, | ||
) | ||
|
||
from msgspec import UnsetType | ||
from typing_extensions import NotRequired, Required, Self, get_args, get_origin, get_type_hints, is_typeddict | ||
|
@@ -442,6 +456,19 @@ def is_subclass_of(self, cl: type[Any] | tuple[type[Any], ...]) -> bool: | |
if self.origin in UnionTypes: | ||
return all(t.is_subclass_of(cl) for t in self.inner_types) | ||
|
||
if isinstance(self.annotation, _GenericAlias) and self.origin not in (ClassVar, Literal): | ||
cl_args = get_args(cl) | ||
cl_origin = get_origin(cl) or cl | ||
return ( | ||
issubclass(self.origin, cl_origin) | ||
and (len(cl_args) == len(self.args) if cl_args else True) | ||
and ( | ||
all(t.is_subclass_of(cl_arg) for t, cl_arg in zip(self.inner_types, cl_args)) | ||
if cl_args | ||
else True | ||
) | ||
) | ||
|
||
Comment on lines
+459
to
+471
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Trying to determine when a This says that when
With 2 and 3, if There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'll have to add a lot of tests for this. |
||
return self.origin not in UnionTypes and is_class_and_subclass(self.origin, cl) | ||
|
||
if self.annotation is AnyStr: | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -249,7 +249,9 @@ def get_type_hints_with_generics_resolved( | |
if origin is None: | ||
# Implies the generic types have not been specified in the annotation | ||
type_hints = get_type_hints(annotation, globalns=globalns, localns=localns, include_extras=include_extras) | ||
typevar_map = {p: p for p in annotation.__parameters__} | ||
if not (parameters := getattr(annotation, "__parameters__", None)): | ||
return type_hints | ||
typevar_map = {p: p for p in parameters} | ||
Comment on lines
+252
to
+254
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is to support using this function without knowing up front if the type we are passing in is a generic type or not. If not, it won't have the parameters attribute, and so we just return the type hints without any post processing. |
||
else: | ||
type_hints = get_type_hints(origin, globalns=globalns, localns=localns, include_extras=include_extras) | ||
# the __parameters__ is only available on the origin itself and not the annotation | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,14 +2,15 @@ | |
from __future__ import annotations | ||
|
||
from dataclasses import dataclass | ||
from typing import TYPE_CHECKING, Tuple, TypeVar, Union | ||
from typing import TYPE_CHECKING, Generic, Tuple, TypeVar, Union | ||
|
||
import pytest | ||
from typing_extensions import Annotated | ||
|
||
from litestar import Request | ||
from litestar.dto import DataclassDTO, DTOConfig | ||
from litestar.exceptions.dto_exceptions import InvalidAnnotationException | ||
from litestar.types.empty import Empty | ||
from litestar.typing import FieldDefinition | ||
|
||
from . import Model | ||
|
@@ -19,7 +20,8 @@ | |
|
||
from litestar.dto._backend import DTOBackend | ||
|
||
T = TypeVar("T", bound=Model) | ||
T = TypeVar("T") | ||
ModelT = TypeVar("ModelT", bound=Model) | ||
|
||
|
||
def get_backend(dto_type: type[DataclassDTO[Any]]) -> DTOBackend: | ||
|
@@ -77,7 +79,7 @@ def test_extra_annotated_metadata_ignored() -> None: | |
|
||
def test_overwrite_config() -> None: | ||
first = DTOConfig(exclude={"a"}) | ||
generic_dto = DataclassDTO[Annotated[T, first]] # pyright: ignore | ||
generic_dto = DataclassDTO[Annotated[ModelT, first]] # pyright: ignore | ||
second = DTOConfig(exclude={"b"}) | ||
dto = generic_dto[Annotated[Model, second]] # pyright: ignore | ||
assert dto.config is second | ||
|
@@ -86,13 +88,13 @@ def test_overwrite_config() -> None: | |
def test_existing_config_not_overwritten() -> None: | ||
assert getattr(DataclassDTO, "_config", None) is None | ||
first = DTOConfig(exclude={"a"}) | ||
generic_dto = DataclassDTO[Annotated[T, first]] # pyright: ignore | ||
generic_dto = DataclassDTO[Annotated[ModelT, first]] # pyright: ignore | ||
dto = generic_dto[Model] # pyright: ignore | ||
assert dto.config is first | ||
|
||
|
||
def test_config_assigned_via_subclassing() -> None: | ||
class CustomGenericDTO(DataclassDTO[T]): | ||
class CustomGenericDTO(DataclassDTO[ModelT]): | ||
config = DTOConfig(exclude={"a"}) | ||
|
||
concrete_dto = CustomGenericDTO[Model] | ||
|
@@ -161,3 +163,28 @@ class SubType(Model): | |
assert ( | ||
dto_type._dto_backends["handler_id"]["data_backend"].parsed_field_definitions[-1].name == "c" # pyright: ignore | ||
) | ||
|
||
|
||
def test_type_narrowing_with_generic_type() -> None: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. should name it |
||
@dataclass | ||
class Foo(Generic[T]): | ||
foo: T | ||
|
||
hints = DataclassDTO.get_model_type_hints(Foo[int]) | ||
assert hints == { | ||
"foo": FieldDefinition( | ||
raw=int, | ||
annotation=int, | ||
type_wrappers=(), | ||
origin=None, | ||
args=(), | ||
metadata=(), | ||
instantiable_origin=None, | ||
safe_generic_origin=None, | ||
inner_types=(), | ||
default=Empty, | ||
extra={}, | ||
kwarg_definition=None, | ||
name="foo", | ||
) | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
With this feature, anywhere that we could previously assume that the type that narrowed the dto was just a regular class, we now have to account for the fact that it could be an instance of
_GenericAlias
.In the my first pass of this PR I've taken the quickest and dirtiest approach to get things passing, but we might need some abstraction over the type that handles the differences.