Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use a context variable to pass Schema context #2707

Merged
merged 25 commits into from
Jan 5, 2025
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
07c9c18
Add CONTEXT context variable
lafrech Dec 30, 2024
4f4e048
Expose Context context manager
lafrech Dec 30, 2024
a5003b8
Remove Schema.context and Field.context
lafrech Dec 30, 2024
7bb901f
Expose context as field/schema property
lafrech Dec 31, 2024
a454a88
Make current_context a Context class attribute
lafrech Jan 1, 2025
0f285ed
Don't provide None as default context
lafrech Jan 1, 2025
feffc2e
Fix Function field docstring about context
lafrech Jan 1, 2025
6fac57d
Allow passing a default to Context.get
lafrech Jan 1, 2025
1a4eec7
Never pass context to functions in Function field
lafrech Jan 1, 2025
4447c07
Remove utils.get_func_args
lafrech Jan 2, 2025
72de755
Make _CURRENT_CONTEXT a module-level attribute
lafrech Jan 2, 2025
17bd038
Move Context into experimental
lafrech Jan 2, 2025
63abfc1
Add typing to context.py
sloria Jan 2, 2025
c6c4e88
Add tests for decorated processors with context
lafrech Jan 3, 2025
c7d0bca
Merge branch '4.0' into context
lafrech Jan 3, 2025
947de51
Update documentation about removal of context
lafrech Jan 4, 2025
5b10b84
Update versionchanged in docstrings
lafrech Jan 4, 2025
318cae0
Update changelog about Context
lafrech Jan 4, 2025
e638af3
Context: initialize token at __init__
lafrech Jan 4, 2025
5e485cb
Merge branch '4.0' into context
sloria Jan 5, 2025
5604959
Minor edit to upgrading guide
sloria Jan 5, 2025
c89c15a
Add more documentation for Context
sloria Jan 5, 2025
f1cbe27
More complete examples
sloria Jan 5, 2025
63d46aa
Exemplify using type aliases for Context
sloria Jan 5, 2025
5edd8b5
Merge branch '4.0' into context
lafrech Jan 5, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/marshmallow/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from packaging.version import Version

from marshmallow.context import Context
from marshmallow.decorators import (
post_dump,
post_load,
Expand Down Expand Up @@ -66,6 +67,7 @@ def __getattr__(name: str) -> typing.Any:
"EXCLUDE",
"INCLUDE",
"RAISE",
"Context",
"Schema",
"SchemaOpts",
"fields",
Expand Down
23 changes: 23 additions & 0 deletions src/marshmallow/context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
"""Objects related to serializtion/deserialization context"""

import contextlib
import contextvars


class Context(contextlib.AbstractContextManager):
sloria marked this conversation as resolved.
Show resolved Hide resolved
_current_context: contextvars.ContextVar = contextvars.ContextVar("context")
lafrech marked this conversation as resolved.
Show resolved Hide resolved

def __init__(self, context):
self.context = context

def __enter__(self):
self.token = self._current_context.set(self.context)

def __exit__(self, *args, **kwargs):
self._current_context.reset(self.token)

@classmethod
def get(cls, default=...):
if default is not ...:
return cls._current_context.get(default)
return cls._current_context.get()
29 changes: 4 additions & 25 deletions src/marshmallow/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,13 +395,6 @@ def _deserialize(
"""
return value

# Properties

@property
def context(self):
"""The context dictionary for the parent :class:`Schema`."""
return self.parent.context


class Raw(Field):
"""Field that applies no formatting."""
Expand Down Expand Up @@ -498,8 +491,6 @@ def schema(self):
Renamed from `serializer` to `schema`.
"""
if not self._schema:
# Inherit context from parent.
context = getattr(self.parent, "context", {})
if callable(self.nested) and not isinstance(self.nested, type):
nested = self.nested()
else:
Expand All @@ -512,7 +503,6 @@ def schema(self):

if isinstance(nested, SchemaABC):
self._schema = copy.copy(nested)
self._schema.context.update(context)
# Respect only and exclude passed from parent and re-initialize fields
set_class = self._schema.set_class
if self.only is not None:
Expand All @@ -539,7 +529,6 @@ def schema(self):
many=self.many,
only=self.only,
exclude=self.exclude,
context=context,
load_only=self._nested_normalized_option("load_only"),
dump_only=self._nested_normalized_option("dump_only"),
)
Expand Down Expand Up @@ -1909,14 +1898,12 @@ class Function(Field):

:param serialize: A callable from which to retrieve the value.
The function must take a single argument ``obj`` which is the object
to be serialized. It can also optionally take a ``context`` argument,
which is a dictionary of context variables passed to the serializer.
to be serialized.
If no callable is provided then the ```load_only``` flag will be set
to True.
:param deserialize: A callable from which to retrieve the value.
The function must take a single argument ``value`` which is the value
to be deserialized. It can also optionally take a ``context`` argument,
which is a dictionary of context variables passed to the deserializer.
to be deserialized.
If no callable is provided then ```value``` will be passed through
unchanged.

Expand Down Expand Up @@ -1951,21 +1938,13 @@ def __init__(
self.deserialize_func = deserialize and utils.callable_or_raise(deserialize)

def _serialize(self, value, attr, obj, **kwargs):
return self._call_or_raise(self.serialize_func, obj, attr)
return self.serialize_func(obj)

def _deserialize(self, value, attr, data, **kwargs):
if self.deserialize_func:
return self._call_or_raise(self.deserialize_func, value, attr)
return self.deserialize_func(value)
return value

def _call_or_raise(self, func, value, attr):
if len(utils.get_func_args(func)) > 1:
lafrech marked this conversation as resolved.
Show resolved Hide resolved
if self.parent.context is None:
msg = f"No context available for Function field {attr!r}"
raise ValidationError(msg)
return func(value, self.parent.context)
return func(value)


class Constant(Field):
"""A field that (de)serializes to a preset constant. If you only want the
Expand Down
4 changes: 0 additions & 4 deletions src/marshmallow/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,8 +248,6 @@ class AlbumSchema(Schema):
delimiters.
:param many: Should be set to `True` if ``obj`` is a collection
so that the object will be serialized to a list.
:param context: Optional context passed to :class:`fields.Method` and
:class:`fields.Function` fields.
:param load_only: Fields to skip during serialization (write-only fields)
:param dump_only: Fields to skip during deserialization (read-only fields)
:param partial: Whether to ignore missing fields and not require
Expand Down Expand Up @@ -346,7 +344,6 @@ def __init__(
only: types.StrSequenceOrSet | None = None,
exclude: types.StrSequenceOrSet = (),
many: bool | None = None,
context: dict | None = None,
load_only: types.StrSequenceOrSet = (),
dump_only: types.StrSequenceOrSet = (),
partial: bool | types.StrSequenceOrSet | None = None,
Expand All @@ -373,7 +370,6 @@ def __init__(
if unknown is None
else validate_unknown_parameter_value(unknown)
)
self.context = context or {}
self._normalize_nested_options()
#: Dictionary mapping field_names -> :class:`Field` objects
self.fields = {} # type: dict[str, ma_fields.Field]
Expand Down
17 changes: 13 additions & 4 deletions tests/test_deserialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,15 @@

import pytest

from marshmallow import EXCLUDE, INCLUDE, RAISE, Schema, fields, validate
from marshmallow import (
EXCLUDE,
INCLUDE,
RAISE,
Context,
Schema,
fields,
validate,
)
from marshmallow.exceptions import ValidationError
from marshmallow.validate import Equal
from tests.base import (
Expand Down Expand Up @@ -1000,10 +1008,11 @@ class Parent(Schema):

field = fields.Function(
lambda x: None,
deserialize=lambda val, context: val.upper() + context["key"],
deserialize=lambda val: val.upper() + Context.get()["key"],
)
field.parent = Parent(context={"key": "BAR"})
assert field.deserialize("foo") == "FOOBAR"
field.parent = Parent()
with Context({"key": "BAR"}):
assert field.deserialize("foo") == "FOOBAR"

def test_function_field_passed_deserialize_only_is_load_only(self):
field = fields.Function(deserialize=lambda val: val.upper())
Expand Down
126 changes: 62 additions & 64 deletions tests/test_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
EXCLUDE,
INCLUDE,
RAISE,
Context,
Schema,
class_registry,
fields,
Expand Down Expand Up @@ -353,15 +354,13 @@ class NestedSchema(Schema):
bar = fields.Str()

def on_bind_field(self, field_name, field_obj):
field_obj.metadata["fname"] = self.context["fname"]
lafrech marked this conversation as resolved.
Show resolved Hide resolved
assert field_obj.parent is self
field_obj.metadata["fname"] = field_name

foo = fields.Nested(NestedSchema)

schema1 = MySchema(context={"fname": "foobar"})
schema2 = MySchema(context={"fname": "quxquux"})

assert schema1.fields["foo"].schema.fields["bar"].metadata["fname"] == "foobar"
assert schema2.fields["foo"].schema.fields["bar"].metadata["fname"] == "quxquux"
schema = MySchema()
assert schema.fields["foo"].schema.fields["bar"].metadata["fname"] == "bar"


class TestValidate:
Expand Down Expand Up @@ -2158,51 +2157,60 @@ class ValidatingSchema(Schema):

class UserContextSchema(Schema):
is_owner = fields.Method("get_is_owner")
is_collab = fields.Function(lambda user, ctx: user in ctx["blog"])
is_collab = fields.Function(lambda user: user in Context.get()["blog"])

def get_is_owner(self, user):
return self.context["blog"].user.name == user.name
return Context.get()["blog"].user.name == user.name


class TestContext:
def test_context_load_dump(self):
class ContextField(fields.Integer):
def _serialize(self, value, attr, obj, **kwargs):
if (context := Context.get(None)) is not None:
value *= context.get("factor", 1)
return super()._serialize(value, attr, obj, **kwargs)

def _deserialize(self, value, attr, data, **kwargs):
val = super()._deserialize(value, attr, data, **kwargs)
if (context := Context.get(None)) is not None:
val *= context.get("factor", 1)
return val

class ContextSchema(Schema):
ctx_fld = ContextField()

ctx_schema = ContextSchema()

assert ctx_schema.load({"ctx_fld": 1}) == {"ctx_fld": 1}
assert ctx_schema.dump({"ctx_fld": 1}) == {"ctx_fld": 1}
with Context({"factor": 2}):
assert ctx_schema.load({"ctx_fld": 1}) == {"ctx_fld": 2}
assert ctx_schema.dump({"ctx_fld": 1}) == {"ctx_fld": 2}

def test_context_method(self):
owner = User("Joe")
blog = Blog(title="Joe Blog", user=owner)
context = {"blog": blog}
serializer = UserContextSchema()
serializer.context = context
data = serializer.dump(owner)
assert data["is_owner"] is True
nonowner = User("Fred")
data = serializer.dump(nonowner)
assert data["is_owner"] is False
with Context({"blog": blog}):
data = serializer.dump(owner)
assert data["is_owner"] is True
nonowner = User("Fred")
data = serializer.dump(nonowner)
assert data["is_owner"] is False

def test_context_method_function(self):
owner = User("Fred")
blog = Blog("Killer Queen", user=owner)
collab = User("Brian")
blog.collaborators.append(collab)
context = {"blog": blog}
serializer = UserContextSchema()
serializer.context = context
data = serializer.dump(collab)
assert data["is_collab"] is True
noncollab = User("Foo")
data = serializer.dump(noncollab)
assert data["is_collab"] is False

def test_function_field_raises_error_when_context_not_available(self):
lafrech marked this conversation as resolved.
Show resolved Hide resolved
# only has a function field
class UserFunctionContextSchema(Schema):
is_collab = fields.Function(lambda user, ctx: user in ctx["blog"])

owner = User("Joe")
serializer = UserFunctionContextSchema()
# no context
serializer.context = None
msg = "No context available for Function field {!r}".format("is_collab")
with pytest.raises(ValidationError, match=msg):
serializer.dump(owner)
with Context({"blog": blog}):
serializer = UserContextSchema()
data = serializer.dump(collab)
assert data["is_collab"] is True
noncollab = User("Foo")
data = serializer.dump(noncollab)
assert data["is_collab"] is False

def test_function_field_handles_bound_serializer(self):
class SerializeA:
Expand All @@ -2217,32 +2225,21 @@ class UserFunctionContextSchema(Schema):

owner = User("Joe")
serializer = UserFunctionContextSchema()
# no context
serializer.context = None
data = serializer.dump(owner)
assert data["is_collab"] == "value"

def test_fields_context(self):
class CSchema(Schema):
name = fields.String()

ser = CSchema()
ser.context["foo"] = 42

assert ser.fields["name"].context == {"foo": 42}

def test_nested_fields_inherit_context(self):
class InnerSchema(Schema):
likes_bikes = fields.Function(lambda obj, ctx: "bikes" in ctx["info"])
likes_bikes = fields.Function(lambda obj: "bikes" in Context.get()["info"])

class CSchema(Schema):
inner = fields.Nested(InnerSchema)

ser = CSchema()
ser.context["info"] = "i like bikes"
obj = {"inner": {}}
result = ser.dump(obj)
assert result["inner"]["likes_bikes"] is True
with Context({"info": "i like bikes"}):
obj = {"inner": {}}
result = ser.dump(obj)
assert result["inner"]["likes_bikes"] is True

# Regression test for https://github.com/marshmallow-code/marshmallow/issues/820
def test_nested_list_fields_inherit_context(self):
Expand All @@ -2251,19 +2248,19 @@ class InnerSchema(Schema):

@validates("foo")
def validate_foo(self, value):
if "foo_context" not in self.context:
if "foo_context" not in Context.get():
raise ValidationError("Missing context")

class OuterSchema(Schema):
bars = fields.List(fields.Nested(InnerSchema()))

inner = InnerSchema()
inner.context["foo_context"] = "foo"
assert inner.load({"foo": 42})
with Context({"foo_context": "foo"}):
assert inner.load({"foo": 42})

outer = OuterSchema()
outer.context["foo_context"] = "foo"
assert outer.load({"bars": [{"foo": 42}]})
with Context({"foo_context": "foo"}):
assert outer.load({"bars": [{"foo": 42}]})

# Regression test for https://github.com/marshmallow-code/marshmallow/issues/820
def test_nested_dict_fields_inherit_context(self):
Expand All @@ -2272,19 +2269,19 @@ class InnerSchema(Schema):

@validates("foo")
def validate_foo(self, value):
if "foo_context" not in self.context:
if "foo_context" not in Context.get():
raise ValidationError("Missing context")

class OuterSchema(Schema):
bars = fields.Dict(values=fields.Nested(InnerSchema()))

inner = InnerSchema()
inner.context["foo_context"] = "foo"
assert inner.load({"foo": 42})
with Context({"foo_context": "foo"}):
assert inner.load({"foo": 42})

outer = OuterSchema()
outer.context["foo_context"] = "foo"
assert outer.load({"bars": {"test": {"foo": 42}}})
with Context({"foo_context": "foo"}):
assert outer.load({"bars": {"test": {"foo": 42}}})

# Regression test for https://github.com/marshmallow-code/marshmallow/issues/1404
def test_nested_field_with_unpicklable_object_in_context(self):
Expand All @@ -2296,11 +2293,12 @@ class InnerSchema(Schema):
foo = fields.Field()

class OuterSchema(Schema):
inner = fields.Nested(InnerSchema(context={"unp": Unpicklable()}))
inner = fields.Nested(InnerSchema())

outer = OuterSchema()
obj = {"inner": {"foo": 42}}
assert outer.dump(obj)
with Context({"unp": Unpicklable()}):
assert outer.dump(obj)


def test_serializer_can_specify_nested_object_as_attribute(blog):
Expand Down
Loading
Loading