diff --git a/CHANGELOG.rst b/CHANGELOG.rst index ceed9c6bf..eb2588eb9 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -26,7 +26,7 @@ Other changes: As a consequence of this change: - Time with time offsets are now supported. - YYYY-MM-DD is now accepted as a datetime and deserialized as naive 00:00 AM. - - `from_iso_date`, `from_iso_time` and `from_iso_datetime` are removed from `marshmallow.utils` + - `from_iso_date`, `from_iso_time` and `from_iso_datetime` are removed from `marshmallow.utils`. - *Backwards-incompatible*: Custom validators must raise a `ValidationError ` for invalid values. Returning `False` is no longer supported (:issue:`1775`). @@ -56,6 +56,33 @@ As a consequence of this change: Thanks :user:`ddelange` for the PR. +- *Backwards-incompatible*: Remove `Schema `'s ``context`` attribute. Passing a context + should be done using `contextvars.ContextVar` (:issue:`1826`). + marshmallow 4 provides an experimental `Context ` + manager class that can be used to both set and retrieve context. + +.. code-block:: python + + import typing + + from marshmallow import Schema, fields + from marshmallow.experimental.context import Context + + + class UserContext(typing.TypedDict): + suffix: str + + + class UserSchema(Schema): + name_suffixed = fields.Function( + lambda obj: obj["name"] + Context[UserContext].get()["suffix"] + ) + + + with Context[UserContext]({"suffix": "bar"}): + UserSchema().dump({"name": "foo"}) + # {'name_suffixed': 'foobar'} + Deprecations/Removals: - *Backwards-incompatible*: Remove implicit field creation, i.e. using the ``fields`` or ``additional`` class Meta options with undeclared fields (:issue:`1356`). diff --git a/docs/api_reference.rst b/docs/api_reference.rst index 68bc78484..d38e56d2c 100644 --- a/docs/api_reference.rst +++ b/docs/api_reference.rst @@ -10,6 +10,7 @@ API Reference marshmallow.decorators marshmallow.validate marshmallow.utils + marshmallow.experimental.context marshmallow.error_store marshmallow.class_registry marshmallow.exceptions diff --git a/docs/custom_fields.rst b/docs/custom_fields.rst index a5f8cf94f..f14b54a96 100644 --- a/docs/custom_fields.rst +++ b/docs/custom_fields.rst @@ -95,38 +95,72 @@ Both :class:`Function ` and :class:`Method 100.0 -.. _adding-context: +.. _using_context: -Adding context to `Method` and `Function` fields ------------------------------------------------- +Using context +------------- -A :class:`Function ` or :class:`Method ` field may need information about its environment to know how to serialize a value. +A field may need information about its environment to know how to (de)serialize a value. -In these cases, you can set the ``context`` attribute (a dictionary) of a `Schema`. :class:`Function ` and :class:`Method ` fields will have access to this dictionary. +You can use the experimental `Context ` class +to set and retrieve context. -As an example, you might want your ``UserSchema`` to output whether or not a ``User`` is the author of a ``Blog`` or whether a certain word appears in a ``Blog's`` title. +Let's say your ``UserSchema`` needs to output +whether or not a ``User`` is the author of a ``Blog`` or +whether a certain word appears in a ``Blog's`` title. .. code-block:: python + import typing + from dataclasses import dataclass + + from marshmallow import Schema, fields + from marshmallow.experimental.context import Context + + + @dataclass + class User: + name: str + + + @dataclass + class Blog: + title: str + author: User + + + class ContextDict(typing.TypedDict): + blog: Blog + + class UserSchema(Schema): name = fields.String() - # Function fields optionally receive context argument - is_author = fields.Function(lambda user, context: user == context["blog"].author) + + is_author = fields.Function( + lambda user: user == Context[ContextDict].get()["blog"].author + ) likes_bikes = fields.Method("writes_about_bikes") - def writes_about_bikes(self, user): - return "bicycle" in self.context["blog"].title.lower() + def writes_about_bikes(self, user: User) -> bool: + return "bicycle" in Context[ContextDict].get()["blog"].title.lower() +.. note:: + You can use `Context.get ` + within custom fields, pre-/post-processing methods, and validators. + +When (de)serializing, set the context by using `Context ` as a context manager. + +.. code-block:: python - schema = UserSchema() user = User("Freddie Mercury", "fred@queen.com") blog = Blog("Bicycle Blog", author=user) - schema.context = {"blog": blog} - result = schema.dump(user) - result["is_author"] # => True - result["likes_bikes"] # => True + schema = UserSchema() + with Context({"blog": blog}): + result = schema.dump(user) + print(result["is_author"]) # => True + print(result["likes_bikes"]) # => True Customizing error messages diff --git a/docs/extending.rst b/docs/extending.rst index 7accd4b43..54b96acd0 100644 --- a/docs/extending.rst +++ b/docs/extending.rst @@ -454,19 +454,6 @@ Our application schemas can now inherit from our custom schema class. result = ser.dump(user) result # {"user": {"name": "Keith", "email": "keith@stones.com"}} -Using context -------------- - -The ``context`` attribute of a `Schema` is a general-purpose store for extra information that may be needed for (de)serialization. It may be used in both ``Schema`` and ``Field`` methods. - -.. code-block:: python - - schema = UserSchema() - # Make current HTTP request available to - # custom fields, schema methods, schema validators, etc. - schema.context["request"] = request - schema.dump(user) - Custom error messages --------------------- diff --git a/docs/marshmallow.experimental.context.rst b/docs/marshmallow.experimental.context.rst new file mode 100644 index 000000000..50f8e0e61 --- /dev/null +++ b/docs/marshmallow.experimental.context.rst @@ -0,0 +1,5 @@ +Context (experimental) +====================== + +.. automodule:: marshmallow.experimental.context + :members: diff --git a/docs/upgrading.rst b/docs/upgrading.rst index 27a51fa66..7e1d76561 100644 --- a/docs/upgrading.rst +++ b/docs/upgrading.rst @@ -124,6 +124,58 @@ If you want to use anonymous functions, you can use this helper function. class UserSchema(Schema): password = fields.String(validate=predicate(lambda x: x == "password")) +New context API +*************** + +Passing context to `Schema ` classes is no longer supported. Use `contextvars.ContextVar` for passing context to +fields, pre-/post-processing methods, and validators instead. + +marshmallow 4 provides an experimental `Context ` +manager class that can be used to both set and retrieve context. + +.. code-block:: python + + # 3.x + from marshmallow import Schema, fields + + + class UserSchema(Schema): + name_suffixed = fields.Function( + lambda obj, context: obj["name"] + context["suffix"] + ) + + + user_schema = UserSchema() + user_schema.context = {"suffix": "bar"} + user_schema.dump({"name": "foo"}) + # {'name_suffixed': 'foobar'} + + # 4.x + import typing + + from marshmallow import Schema, fields + from marshmallow.experimental.context import Context + + + class UserContext(typing.TypedDict): + suffix: str + + + UserSchemaContext = Context[UserContext] + + + class UserSchema(Schema): + name_suffixed = fields.Function( + lambda obj: obj["name"] + UserSchemaContext.get()["suffix"] + ) + + + with UserSchemaContext({"suffix": "bar"}): + UserSchema().dump({"name": "foo"}) + # {'name_suffixed': 'foobar'} + +See :ref:`using_context` for more information. + Implicit field creation is removed ********************************** @@ -237,8 +289,8 @@ if you need to change the final output type. ``pass_many`` is renamed to ``pass_collection`` in decorators ************************************************************* -The ``pass_many`` argument to `pre_load `, -`post_load `, `pre_dump `, +The ``pass_many`` argument to `pre_load `, +`post_load `, `pre_dump `, and `post_dump ` is renamed to ``pass_collection``. The behavior is unchanged. @@ -309,7 +361,7 @@ Upgrading to 3.13 ``load_default`` and ``dump_default`` +++++++++++++++++++++++++++++++++++++ -The ``missing`` and ``default`` parameters of fields are renamed to +The ``missing`` and ``default`` parameters of fields are renamed to ``load_default`` and ``dump_default``, respectively. .. code-block:: python @@ -330,6 +382,7 @@ The ``missing`` and ``default`` parameters of fields are renamed to ``load_default`` and ``dump_default`` are passed to the field constructor as keyword arguments. + Upgrading to 3.3 ++++++++++++++++ diff --git a/docs/why.rst b/docs/why.rst index be1fa8849..9e3b61ccd 100644 --- a/docs/why.rst +++ b/docs/why.rst @@ -55,39 +55,6 @@ In this example, a single schema produced three different outputs! The dynamic n .. _Django REST Framework: https://www.django-rest-framework.org/ .. _Flask-RESTful: https://flask-restful.readthedocs.io/ - -Context-aware serialization ---------------------------- - -marshmallow schemas can modify their output based on the context in which they are used. Field objects have access to a ``context`` dictionary that can be changed at runtime. - -Here's a simple example that shows how a `Schema ` can anonymize a person's name when a boolean is set on the context. - -.. code-block:: python - - class PersonSchema(Schema): - id = fields.Integer() - name = fields.Method("get_name") - - def get_name(self, person, context): - if context.get("anonymize"): - return "" - return person.name - - - person = Person(name="Monty") - schema = PersonSchema() - schema.dump(person) # {'id': 143, 'name': 'Monty'} - - # In a different context, anonymize the name - schema.context["anonymize"] = True - schema.dump(person) # {'id': 143, 'name': ''} - - -.. seealso:: - - See the relevant section of the :ref:`usage guide ` to learn more about context-aware serialization. - Advanced schema nesting ----------------------- diff --git a/src/marshmallow/experimental/__init__.py b/src/marshmallow/experimental/__init__.py new file mode 100644 index 000000000..b8f6f65ba --- /dev/null +++ b/src/marshmallow/experimental/__init__.py @@ -0,0 +1,5 @@ +"""Experimental features. + +The features in this subpackage are experimental. Breaking changes may be +introduced in minor marshmallow versions. +""" diff --git a/src/marshmallow/experimental/context.py b/src/marshmallow/experimental/context.py new file mode 100644 index 000000000..bd06d5fb8 --- /dev/null +++ b/src/marshmallow/experimental/context.py @@ -0,0 +1,61 @@ +"""Helper API for setting serialization/deserialization context. + +Example usage: + +.. code-block:: python + + import typing + + from marshmallow import Schema, fields + from marshmallow.experimental.context import Context + + + class UserContext(typing.TypedDict): + suffix: str + + + UserSchemaContext = Context[UserContext] + + + class UserSchema(Schema): + name_suffixed = fields.Function( + lambda user: user["name"] + UserSchemaContext.get()["suffix"] + ) + + + with UserSchemaContext({"suffix": "bar"}): + print(UserSchema().dump({"name": "foo"})) + # {'name_suffixed': 'foobar'} +""" + +import contextlib +import contextvars +import typing + +_T = typing.TypeVar("_T") +_CURRENT_CONTEXT: contextvars.ContextVar = contextvars.ContextVar("context") + + +class Context(contextlib.AbstractContextManager, typing.Generic[_T]): + """Context manager for setting and retrieving context.""" + + def __init__(self, context: _T) -> None: + self.context = context + self.token: contextvars.Token | None = None + + def __enter__(self) -> None: + self.token = _CURRENT_CONTEXT.set(self.context) + + def __exit__(self, *args, **kwargs) -> None: + _CURRENT_CONTEXT.reset(typing.cast(contextvars.Token, self.token)) + + @classmethod + def get(cls, default=...) -> _T: + """Get the current context. + + :param default: Default value to return if no context is set. + If not provided and no context is set, a :exc:`LookupError` is raised. + """ + if default is not ...: + return _CURRENT_CONTEXT.get(default) + return _CURRENT_CONTEXT.get() diff --git a/src/marshmallow/fields.py b/src/marshmallow/fields.py index 072ab1f97..3a96aa048 100644 --- a/src/marshmallow/fields.py +++ b/src/marshmallow/fields.py @@ -164,6 +164,9 @@ class Field(typing.Generic[_InternalType]): .. versionchanged:: 3.13.0 Replace ``missing`` and ``default`` parameters with ``load_default`` and ``dump_default``. + + .. versionchanged:: 4.0.0 + Remove ``context`` property. """ # Some fields, such as Method fields and Function fields, are not expected @@ -442,15 +445,6 @@ def _deserialize( """ return value - # Properties - - @property - def context(self) -> dict | None: - """The context dictionary for the parent :class:`Schema`.""" - if self.parent: - return self.parent.context - return None - class Raw(Field[typing.Any]): """Field that applies no formatting.""" @@ -540,8 +534,6 @@ def __init__( def schema(self) -> Schema: """The nested Schema object.""" if not self._schema: - # Inherit context from parent. - context = getattr(self.parent, "context", {}) if callable(self.nested) and not isinstance(self.nested, type): nested = self.nested() else: @@ -554,7 +546,6 @@ def schema(self) -> Schema: if isinstance(nested, Schema): self._schema = copy.copy(nested) - self._schema.context.update(context) # Respect only and exclude passed from parent and re-initialize fields set_class = typing.cast(type[set], self._schema.set_class) if self.only is not None: @@ -581,7 +572,6 @@ def schema(self) -> Schema: many=self.many, only=self.only, exclude=self.exclude, - context=context, load_only=self._nested_normalized_option("load_only"), dump_only=self._nested_normalized_option("dump_only"), ) @@ -1994,14 +1984,12 @@ class Function(Field): :param serialize: A callable from which to retrieve the value. The function must take a single argument ``obj`` which is the object - to be serialized. It can also optionally take a ``context`` argument, - which is a dictionary of context variables passed to the serializer. + to be serialized. If no callable is provided then the ```load_only``` flag will be set to True. :param deserialize: A callable from which to retrieve the value. The function must take a single argument ``value`` which is the value - to be deserialized. It can also optionally take a ``context`` argument, - which is a dictionary of context variables passed to the deserializer. + to be deserialized. If no callable is provided then ```value``` will be passed through unchanged. @@ -2010,6 +1998,9 @@ class Function(Field): .. versionchanged:: 3.0.0a1 Removed ``func`` parameter. + + .. versionchanged:: 4.0.0 + Don't pass context to serialization and deserialization functions. """ _CHECK_ATTRIBUTE = False @@ -2036,21 +2027,13 @@ def __init__( self.deserialize_func = deserialize and utils.callable_or_raise(deserialize) def _serialize(self, value, attr, obj, **kwargs): - return self._call_or_raise(self.serialize_func, obj, attr) + return self.serialize_func(obj) def _deserialize(self, value, attr, data, **kwargs): if self.deserialize_func: - return self._call_or_raise(self.deserialize_func, value, attr) + return self.deserialize_func(value) return value - def _call_or_raise(self, func, value, attr): - if len(utils.get_func_args(func)) > 1: - if self.parent.context is None: - msg = f"No context available for Function field {attr!r}" - raise ValidationError(msg) - return func(value, self.parent.context) - return func(value) - _ContantType = typing.TypeVar("_ContantType") diff --git a/src/marshmallow/schema.py b/src/marshmallow/schema.py index 1bd72c501..51a909a6d 100644 --- a/src/marshmallow/schema.py +++ b/src/marshmallow/schema.py @@ -237,8 +237,6 @@ class AlbumSchema(Schema): delimiters. :param many: Should be set to `True` if ``obj`` is a collection so that the object will be serialized to a list. - :param context: Optional context passed to :class:`fields.Method` and - :class:`fields.Function` fields. :param load_only: Fields to skip during serialization (write-only fields) :param dump_only: Fields to skip during deserialization (read-only fields) :param partial: Whether to ignore missing fields and not require @@ -249,7 +247,10 @@ class AlbumSchema(Schema): fields in the data. Use `EXCLUDE`, `INCLUDE` or `RAISE`. .. versionchanged:: 3.0.0 - `prefix` parameter removed. + Remove ``prefix`` parameter. + + .. versionchanged:: 4.0.0 + Remove ``context`` parameter. """ TYPE_MAPPING: dict[type, type[ma_fields.Field]] = { @@ -329,7 +330,6 @@ def __init__( only: types.StrSequenceOrSet | None = None, exclude: types.StrSequenceOrSet = (), many: bool | None = None, - context: dict | None = None, load_only: types.StrSequenceOrSet = (), dump_only: types.StrSequenceOrSet = (), partial: bool | types.StrSequenceOrSet | None = None, @@ -355,7 +355,6 @@ def __init__( if unknown is None else validate_unknown_parameter_value(unknown) ) - self.context = context or {} self._normalize_nested_options() #: Dictionary mapping field_names -> :class:`Field` objects self.fields: dict[str, ma_fields.Field] = {} diff --git a/src/marshmallow/utils.py b/src/marshmallow/utils.py index 70b556d2a..3f416f272 100644 --- a/src/marshmallow/utils.py +++ b/src/marshmallow/utils.py @@ -3,7 +3,6 @@ from __future__ import annotations import datetime as dt -import functools import inspect import typing from collections.abc import Mapping @@ -239,21 +238,6 @@ def _signature(func: typing.Callable) -> list[str]: return list(inspect.signature(func).parameters.keys()) -def get_func_args(func: typing.Callable) -> list[str]: - """Given a callable, return a list of argument names. Handles - `functools.partial` objects and class-based callables. - - .. versionchanged:: 3.0.0a1 - Do not return bound arguments, eg. ``self``. - """ - if inspect.isfunction(func) or inspect.ismethod(func): - return _signature(func) - if isinstance(func, functools.partial): - return _signature(func.func) - # Callable class - return _signature(func) - - def timedelta_to_microseconds(value: dt.timedelta) -> int: """Compute the total microseconds of a timedelta. diff --git a/tests/test_context.py b/tests/test_context.py new file mode 100644 index 000000000..53fa83476 --- /dev/null +++ b/tests/test_context.py @@ -0,0 +1,249 @@ +import typing + +import pytest + +from marshmallow import ( + Schema, + fields, + post_dump, + post_load, + pre_dump, + pre_load, + validates, + validates_schema, +) +from marshmallow.exceptions import ValidationError +from marshmallow.experimental.context import Context +from tests.base import Blog, User + + +class UserContextSchema(Schema): + is_owner = fields.Method("get_is_owner") + is_collab = fields.Function( + lambda user: user in Context[dict[str, typing.Any]].get()["blog"] + ) + + def get_is_owner(self, user): + return Context.get()["blog"].user.name == user.name + + +class TestContext: + def test_context_load_dump(self): + class ContextField(fields.Integer): + def _serialize(self, value, attr, obj, **kwargs): + if (context := Context.get(None)) is not None: + value *= context.get("factor", 1) + return super()._serialize(value, attr, obj, **kwargs) + + def _deserialize(self, value, attr, data, **kwargs): + val = super()._deserialize(value, attr, data, **kwargs) + if (context := Context.get(None)) is not None: + val *= context.get("factor", 1) + return val + + class ContextSchema(Schema): + ctx_fld = ContextField() + + ctx_schema = ContextSchema() + + assert ctx_schema.load({"ctx_fld": 1}) == {"ctx_fld": 1} + assert ctx_schema.dump({"ctx_fld": 1}) == {"ctx_fld": 1} + with Context({"factor": 2}): + assert ctx_schema.load({"ctx_fld": 1}) == {"ctx_fld": 2} + assert ctx_schema.dump({"ctx_fld": 1}) == {"ctx_fld": 2} + + def test_context_method(self): + owner = User("Joe") + blog = Blog(title="Joe Blog", user=owner) + serializer = UserContextSchema() + with Context({"blog": blog}): + data = serializer.dump(owner) + assert data["is_owner"] is True + nonowner = User("Fred") + data = serializer.dump(nonowner) + assert data["is_owner"] is False + + def test_context_function(self): + owner = User("Fred") + blog = Blog("Killer Queen", user=owner) + collab = User("Brian") + blog.collaborators.append(collab) + with Context({"blog": blog}): + serializer = UserContextSchema() + data = serializer.dump(collab) + assert data["is_collab"] is True + noncollab = User("Foo") + data = serializer.dump(noncollab) + assert data["is_collab"] is False + + def test_function_field_handles_bound_serializer(self): + class SerializeA: + def __call__(self, value): + return "value" + + serialize = SerializeA() + + # only has a function field + class UserFunctionContextSchema(Schema): + is_collab = fields.Function(serialize) + + owner = User("Joe") + serializer = UserFunctionContextSchema() + data = serializer.dump(owner) + assert data["is_collab"] == "value" + + def test_nested_fields_inherit_context(self): + class InnerSchema(Schema): + likes_bikes = fields.Function(lambda obj: "bikes" in Context.get()["info"]) + + class CSchema(Schema): + inner = fields.Nested(InnerSchema) + + ser = CSchema() + with Context({"info": "i like bikes"}): + obj = {"inner": {}} + result = ser.dump(obj) + assert result["inner"]["likes_bikes"] is True + + # Regression test for https://github.com/marshmallow-code/marshmallow/issues/820 + def test_nested_list_fields_inherit_context(self): + class InnerSchema(Schema): + foo = fields.Field() + + @validates("foo") + def validate_foo(self, value): + if "foo_context" not in Context.get(): + raise ValidationError("Missing context") + + class OuterSchema(Schema): + bars = fields.List(fields.Nested(InnerSchema())) + + inner = InnerSchema() + with Context({"foo_context": "foo"}): + assert inner.load({"foo": 42}) + + outer = OuterSchema() + with Context({"foo_context": "foo"}): + assert outer.load({"bars": [{"foo": 42}]}) + + # Regression test for https://github.com/marshmallow-code/marshmallow/issues/820 + def test_nested_dict_fields_inherit_context(self): + class InnerSchema(Schema): + foo = fields.Field() + + @validates("foo") + def validate_foo(self, value): + if "foo_context" not in Context.get(): + raise ValidationError("Missing context") + + class OuterSchema(Schema): + bars = fields.Dict(values=fields.Nested(InnerSchema())) + + inner = InnerSchema() + with Context({"foo_context": "foo"}): + assert inner.load({"foo": 42}) + + outer = OuterSchema() + with Context({"foo_context": "foo"}): + assert outer.load({"bars": {"test": {"foo": 42}}}) + + # Regression test for https://github.com/marshmallow-code/marshmallow/issues/1404 + def test_nested_field_with_unpicklable_object_in_context(self): + class Unpicklable: + def __deepcopy__(self, _): + raise NotImplementedError + + class InnerSchema(Schema): + foo = fields.Field() + + class OuterSchema(Schema): + inner = fields.Nested(InnerSchema()) + + outer = OuterSchema() + obj = {"inner": {"foo": 42}} + with Context({"unp": Unpicklable()}): + assert outer.dump(obj) + + def test_function_field_passed_serialize_with_context(self, user): + class Parent(Schema): + pass + + field = fields.Function( + serialize=lambda obj: obj.name.upper() + Context.get()["key"] + ) + field.parent = Parent() + with Context({"key": "BAR"}): + assert field.serialize("key", user) == "MONTYBAR" + + def test_function_field_deserialization_with_context(self): + class Parent(Schema): + pass + + field = fields.Function( + lambda x: None, + deserialize=lambda val: val.upper() + Context.get()["key"], + ) + field.parent = Parent() + with Context({"key": "BAR"}): + assert field.deserialize("foo") == "FOOBAR" + + def test_decorated_processors_with_context(self): + class MySchema(Schema): + f_1 = fields.Integer() + f_2 = fields.Integer() + f_3 = fields.Integer() + f_4 = fields.Integer() + + @pre_dump + def multiply_f_1(self, item, **kwargs): + item["f_1"] *= Context.get()[1] + return item + + @pre_load + def multiply_f_2(self, data, **kwargs): + data["f_2"] *= Context.get()[2] + return data + + @post_dump + def multiply_f_3(self, item, **kwargs): + item["f_3"] *= Context.get()[3] + return item + + @post_load + def multiply_f_4(self, data, **kwargs): + data["f_4"] *= Context.get()[4] + return data + + schema = MySchema() + + with Context({1: 2, 2: 3, 3: 4, 4: 5}): + assert schema.dump({"f_1": 1, "f_2": 1, "f_3": 1, "f_4": 1}) == { + "f_1": 2, + "f_2": 1, + "f_3": 4, + "f_4": 1, + } + assert schema.load({"f_1": 1, "f_2": 1, "f_3": 1, "f_4": 1}) == { + "f_1": 1, + "f_2": 3, + "f_3": 1, + "f_4": 5, + } + + def test_validates_schema_with_context(self): + class MySchema(Schema): + f_1 = fields.Integer() + f_2 = fields.Integer() + + @validates_schema + def validate_schema(self, data, **kwargs): + if data["f_2"] != data["f_1"] * Context.get(): + raise ValidationError("Fail") + + schema = MySchema() + + with Context(2): + schema.load({"f_1": 1, "f_2": 2}) + with pytest.raises(ValidationError) as excinfo: + schema.load({"f_1": 1, "f_2": 3}) + assert excinfo.value.messages["_schema"] == ["Fail"] diff --git a/tests/test_deserialization.py b/tests/test_deserialization.py index 66e7a2b75..6ba50c423 100644 --- a/tests/test_deserialization.py +++ b/tests/test_deserialization.py @@ -7,7 +7,14 @@ import pytest -from marshmallow import EXCLUDE, INCLUDE, RAISE, Schema, fields, validate +from marshmallow import ( + EXCLUDE, + INCLUDE, + RAISE, + Schema, + fields, + validate, +) from marshmallow.exceptions import ValidationError from tests.base import ( ALL_FIELDS, @@ -994,17 +1001,6 @@ def test_function_field_deserialization_with_callable(self): field = fields.Function(lambda x: None, deserialize=lambda val: val.upper()) assert field.deserialize("foo") == "FOO" - def test_function_field_deserialization_with_context(self): - class Parent(Schema): - pass - - field = fields.Function( - lambda x: None, - deserialize=lambda val, context: val.upper() + context["key"], - ) - field.parent = Parent(context={"key": "BAR"}) - assert field.deserialize("foo") == "FOOBAR" - def test_function_field_passed_deserialize_only_is_load_only(self): field = fields.Function(deserialize=lambda val: val.upper()) assert field.load_only is True diff --git a/tests/test_schema.py b/tests/test_schema.py index 796773ec8..5ec92e410 100644 --- a/tests/test_schema.py +++ b/tests/test_schema.py @@ -350,15 +350,13 @@ class NestedSchema(Schema): bar = fields.Str() def on_bind_field(self, field_name, field_obj): - field_obj.metadata["fname"] = self.context["fname"] + assert field_obj.parent is self + field_obj.metadata["fname"] = field_name foo = fields.Nested(NestedSchema) - schema1 = MySchema(context={"fname": "foobar"}) - schema2 = MySchema(context={"fname": "quxquux"}) - - assert schema1.fields["foo"].schema.fields["bar"].metadata["fname"] == "foobar" - assert schema2.fields["foo"].schema.fields["bar"].metadata["fname"] == "quxquux" + schema = MySchema() + assert schema.fields["foo"].schema.fields["bar"].metadata["fname"] == "bar" class TestValidate: @@ -2156,153 +2154,6 @@ class ValidatingSchema(Schema): assert "Color must be red or blue" in errors["color"] -class UserContextSchema(Schema): - is_owner = fields.Method("get_is_owner") - is_collab = fields.Function(lambda user, ctx: user in ctx["blog"]) - - def get_is_owner(self, user): - return self.context["blog"].user.name == user.name - - -class TestContext: - def test_context_method(self): - owner = User("Joe") - blog = Blog(title="Joe Blog", user=owner) - context = {"blog": blog} - serializer = UserContextSchema() - serializer.context = context - data = serializer.dump(owner) - assert data["is_owner"] is True - nonowner = User("Fred") - data = serializer.dump(nonowner) - assert data["is_owner"] is False - - def test_context_method_function(self): - owner = User("Fred") - blog = Blog("Killer Queen", user=owner) - collab = User("Brian") - blog.collaborators.append(collab) - context = {"blog": blog} - serializer = UserContextSchema() - serializer.context = context - data = serializer.dump(collab) - assert data["is_collab"] is True - noncollab = User("Foo") - data = serializer.dump(noncollab) - assert data["is_collab"] is False - - def test_function_field_raises_error_when_context_not_available(self): - # only has a function field - class UserFunctionContextSchema(Schema): - is_collab = fields.Function(lambda user, ctx: user in ctx["blog"]) - - owner = User("Joe") - serializer = UserFunctionContextSchema() - # no context - serializer.context = None - msg = "No context available for Function field {!r}".format("is_collab") - with pytest.raises(ValidationError, match=msg): - serializer.dump(owner) - - def test_function_field_handles_bound_serializer(self): - class SerializeA: - def __call__(self, value): - return "value" - - serialize = SerializeA() - - # only has a function field - class UserFunctionContextSchema(Schema): - is_collab = fields.Function(serialize) - - owner = User("Joe") - serializer = UserFunctionContextSchema() - # no context - serializer.context = None - data = serializer.dump(owner) - assert data["is_collab"] == "value" - - def test_fields_context(self): - class CSchema(Schema): - name = fields.String() - - ser = CSchema() - ser.context["foo"] = 42 - - assert ser.fields["name"].context == {"foo": 42} - - def test_nested_fields_inherit_context(self): - class InnerSchema(Schema): - likes_bikes = fields.Function(lambda obj, ctx: "bikes" in ctx["info"]) - - class CSchema(Schema): - inner = fields.Nested(InnerSchema) - - ser = CSchema() - ser.context["info"] = "i like bikes" - obj = {"inner": {}} - result = ser.dump(obj) - assert result["inner"]["likes_bikes"] is True - - # Regression test for https://github.com/marshmallow-code/marshmallow/issues/820 - def test_nested_list_fields_inherit_context(self): - class InnerSchema(Schema): - foo = fields.Raw() - - @validates("foo") - def validate_foo(self, value): - if "foo_context" not in self.context: - raise ValidationError("Missing context") - - class OuterSchema(Schema): - bars = fields.List(fields.Nested(InnerSchema())) - - inner = InnerSchema() - inner.context["foo_context"] = "foo" - assert inner.load({"foo": 42}) - - outer = OuterSchema() - outer.context["foo_context"] = "foo" - assert outer.load({"bars": [{"foo": 42}]}) - - # Regression test for https://github.com/marshmallow-code/marshmallow/issues/820 - def test_nested_dict_fields_inherit_context(self): - class InnerSchema(Schema): - foo = fields.Raw() - - @validates("foo") - def validate_foo(self, value): - if "foo_context" not in self.context: - raise ValidationError("Missing context") - - class OuterSchema(Schema): - bars = fields.Dict(values=fields.Nested(InnerSchema())) - - inner = InnerSchema() - inner.context["foo_context"] = "foo" - assert inner.load({"foo": 42}) - - outer = OuterSchema() - outer.context["foo_context"] = "foo" - assert outer.load({"bars": {"test": {"foo": 42}}}) - - # Regression test for https://github.com/marshmallow-code/marshmallow/issues/1404 - def test_nested_field_with_unpicklable_object_in_context(self): - class Unpicklable: - def __deepcopy__(self, _): - raise NotImplementedError - - class InnerSchema(Schema): - foo = fields.Raw() - - class OuterSchema(Schema): - inner = fields.Nested(InnerSchema(context={"unp": Unpicklable()})) - - outer = OuterSchema() - obj = {"inner": {"foo": 42}} - assert outer.dump(obj) - - def test_serializer_can_specify_nested_object_as_attribute(blog): class BlogUsernameSchema(Schema): author_name = fields.String(attribute="user.name") diff --git a/tests/test_serialization.py b/tests/test_serialization.py index bdc7940b3..7c30520dc 100644 --- a/tests/test_serialization.py +++ b/tests/test_serialization.py @@ -85,16 +85,6 @@ def test_function_field_load_only(self): field = fields.Function(deserialize=lambda obj: None) assert field.load_only - def test_function_field_passed_serialize_with_context(self, user, monkeypatch): - class Parent(Schema): - pass - - field = fields.Function( - serialize=lambda obj, context: obj.name.upper() + context["key"] - ) - field.parent = Parent(context={"key": "BAR"}) - assert "FOOBAR" == field.serialize("key", user) - def test_function_field_passed_uncallable_object(self): with pytest.raises(TypeError): fields.Function("uncallable") diff --git a/tests/test_utils.py b/tests/test_utils.py index 9c4981523..84b812907 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,7 +1,6 @@ import datetime as dt from collections import namedtuple from copy import copy, deepcopy -from functools import partial import pytest @@ -200,22 +199,6 @@ def test_from_timestamp_with_overflow_value(): utils.from_timestamp(value) -def test_get_func_args(): - def f1(foo, bar): - pass - - f2 = partial(f1, "baz") - - class F3: - def __call__(self, foo, bar): - pass - - f3 = F3() - - for func in [f1, f2, f3]: - assert utils.get_func_args(func) == ["foo", "bar"] - - # Regression test for https://github.com/marshmallow-code/marshmallow/issues/540 def test_function_field_using_type_annotation(): def get_split_words(value: str): # noqa