Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Support for marshmallow.fields.Enum in marshmallow ≥ v3.18 #170

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 42 additions & 10 deletions marshmallow_jsonschema/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,21 @@
from marshmallow.class_registry import get_class
from marshmallow.decorators import post_dump
from marshmallow.utils import _Missing

from marshmallow import INCLUDE, EXCLUDE, RAISE
# marshmallow.fields.Enum support has been added in marshmallow v3.18
# see https://github.com/marshmallow-code/marshmallow/blob/dev/CHANGELOG.rst#3180-2022-09-15
from marshmallow import __version__ as _MarshmallowVersion
# the package "packaging" is a requirement of marshmallow itself => we don't need to install it separately
# see https://github.com/marshmallow-code/marshmallow/blob/ddbe06f923befe754e213e03fb95be54e996403d/setup.py#L61
from packaging.version import Version


def marshmallow_version_supports_native_enums() -> bool:
"""
returns true if and only if the version of marshmallow installed supports enums natively
"""
return Version(_MarshmallowVersion) >= Version("3.18")


try:
from marshmallow_union import Union
Expand All @@ -20,11 +33,15 @@
ALLOW_UNIONS = False

try:
from marshmallow_enum import EnumField, LoadDumpOptions
from marshmallow_enum import EnumField as MarshmallowEnumEnumField, LoadDumpOptions

ALLOW_ENUMS = True
ALLOW_MARSHMALLOW_ENUM_ENUMS = True
except ImportError:
ALLOW_ENUMS = False
ALLOW_MARSHMALLOW_ENUM_ENUMS = False

ALLOW_MARSHMALLOW_NATIVE_ENUMS = marshmallow_version_supports_native_enums()
if ALLOW_MARSHMALLOW_NATIVE_ENUMS:
from marshmallow.fields import Enum as MarshmallowNativeEnumField

from .exceptions import UnsupportedValueError
from .validation import (
Expand Down Expand Up @@ -92,10 +109,12 @@
(fields.Nested, dict),
]

if ALLOW_ENUMS:
if ALLOW_MARSHMALLOW_NATIVE_ENUMS:
MARSHMALLOW_TO_PY_TYPES_PAIRS.append((MarshmallowNativeEnumField, Enum))
if ALLOW_MARSHMALLOW_ENUM_ENUMS:
# We currently only support loading enum's from their names. So the possible
# values will always map to string in the JSONSchema
MARSHMALLOW_TO_PY_TYPES_PAIRS.append((EnumField, Enum))
MARSHMALLOW_TO_PY_TYPES_PAIRS.append((MarshmallowEnumEnumField, Enum))


FIELD_VALIDATORS = {
Expand Down Expand Up @@ -191,8 +210,10 @@ def _from_python_type(self, obj, field, pytype) -> typing.Dict[str, typing.Any]:
if field.default is not missing and not callable(field.default):
json_schema["default"] = field.default

if ALLOW_ENUMS and isinstance(field, EnumField):
json_schema["enum"] = self._get_enum_values(field)
if ALLOW_MARSHMALLOW_NATIVE_ENUMS and isinstance(field, MarshmallowNativeEnumField):
json_schema["enum"] = self._get_marshmallow_native_enum_values(field)
elif ALLOW_MARSHMALLOW_ENUM_ENUMS and isinstance(field, MarshmallowEnumEnumField):
json_schema["enum"] = self._get_marshmallow_enum_enum_values(field)

if field.allow_none:
previous_type = json_schema["type"]
Expand All @@ -218,8 +239,8 @@ def _from_python_type(self, obj, field, pytype) -> typing.Dict[str, typing.Any]:
)
return json_schema

def _get_enum_values(self, field) -> typing.List[str]:
assert ALLOW_ENUMS and isinstance(field, EnumField)
def _get_marshmallow_enum_enum_values(self, field) -> typing.List[str]:
assert ALLOW_MARSHMALLOW_ENUM_ENUMS and isinstance(field, MarshmallowEnumEnumField)

if field.load_by == LoadDumpOptions.value:
# Python allows enum values to be almost anything, so it's easier to just load from the
Expand All @@ -229,6 +250,17 @@ def _get_enum_values(self, field) -> typing.List[str]:
)

return [value.name for value in field.enum]
def _get_marshmallow_native_enum_values(self, field) -> typing.List[str]:
assert ALLOW_MARSHMALLOW_NATIVE_ENUMS and isinstance(field, MarshmallowNativeEnumField)

if field.by_value:
# Python allows enum values to be almost anything, so it's easier to just load from the
# names of the enum's which will have to be strings.
raise NotImplementedError(
"Currently do not support JSON schema for enums loaded by value"
)

return [value.name for value in field.enum]

def _from_union_schema(
self, obj, field
Expand Down
58 changes: 53 additions & 5 deletions tests/test_dump.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,18 @@

import pytest
from marshmallow import Schema, fields, validate
from marshmallow_enum import EnumField
from marshmallow_enum import EnumField as MarshmallowEnumEnumField
from marshmallow_union import Union

import marshmallow_jsonschema
from marshmallow_jsonschema import JSONSchema, UnsupportedValueError
from . import UserSchema, validate_and_dump

TEST_MARSHMALLOW_NATIVE_ENUM = marshmallow_jsonschema.base.marshmallow_version_supports_native_enums()
try:
from marshmallow.fields import Enum as MarshmallowNativeEnumField
except ImportError:
assert TEST_MARSHMALLOW_NATIVE_ENUM is False

def test_dump_schema():
schema = UserSchema()
Expand Down Expand Up @@ -648,14 +654,14 @@ class Meta:
assert properties_names == ["d", "c", "a"]


def test_enum_based():
def test_marshmallow_enum_enum_based():
class TestEnum(Enum):
value_1 = 0
value_2 = 1
value_3 = 2

class TestSchema(Schema):
enum_prop = EnumField(TestEnum)
enum_prop = MarshmallowEnumEnumField(TestEnum)

# Should be sorting of fields
schema = TestSchema()
Expand All @@ -671,15 +677,39 @@ class TestSchema(Schema):
)
assert received_enum_values == ["value_1", "value_2", "value_3"]

def test_native_marshmallow_enum_based():
if not TEST_MARSHMALLOW_NATIVE_ENUM:
return
class TestEnum(Enum):
value_1 = 0
value_2 = 1
value_3 = 2

class TestSchema(Schema):
enum_prop = MarshmallowNativeEnumField(TestEnum)

# Should be sorting of fields
schema = TestSchema()

json_schema = JSONSchema()
data = json_schema.dump(schema)

assert (
data["definitions"]["TestSchema"]["properties"]["enum_prop"]["type"] == "string"
)
received_enum_values = sorted(
data["definitions"]["TestSchema"]["properties"]["enum_prop"]["enum"]
)
assert received_enum_values == ["value_1", "value_2", "value_3"]

def test_enum_based_load_dump_value():
def test_marshmallow_enum_enum_based_load_dump_value():
class TestEnum(Enum):
value_1 = 0
value_2 = 1
value_3 = 2

class TestSchema(Schema):
enum_prop = EnumField(TestEnum, by_value=True)
enum_prop = MarshmallowEnumEnumField(TestEnum, by_value=True)

# Should be sorting of fields
schema = TestSchema()
Expand All @@ -689,6 +719,24 @@ class TestSchema(Schema):
with pytest.raises(NotImplementedError):
validate_and_dump(json_schema.dump(schema))

def test_native_marshmallow_enum_based_load_dump_value():
if not TEST_MARSHMALLOW_NATIVE_ENUM:
return
class TestEnum(Enum):
value_1 = 0
value_2 = 1
value_3 = 2

class TestSchema(Schema):
enum_prop = MarshmallowNativeEnumField(TestEnum, by_value=True)

# Should be sorting of fields
schema = TestSchema()

json_schema = JSONSchema()

with pytest.raises(NotImplementedError):
validate_and_dump(json_schema.dump(schema))

def test_union_based():
class TestNestedSchema(Schema):
Expand Down
2 changes: 1 addition & 1 deletion tests/test_imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def test_import_marshmallow_enum(monkeypatch):

base = importlib.reload(marshmallow_jsonschema.base)

assert not base.ALLOW_ENUMS
assert not base.ALLOW_MARSHMALLOW_ENUM_ENUMS

monkeypatch.undo()

Expand Down