Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add oneOf+const JSON Schema Option for Literals #9029

Closed
wants to merge 9 commits into from
8 changes: 7 additions & 1 deletion pydantic/_internal/_std_types_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,13 +61,19 @@ def get_enum_core_schema(enum_type: type[Enum], config: ConfigDict) -> CoreSchem

enum_ref = get_type_ref(enum_type)
description = None if not enum_type.__doc__ else inspect.cleandoc(enum_type.__doc__)
case_descriptions = [
(c.value, inspect.cleandoc(c.__doc__))
for c in cases
if c.__doc__ is not None and inspect.cleandoc(c.__doc__) != description
]
if description == 'An enumeration.': # This is the default value provided by enum.EnumMeta.__new__; don't use it
description = None
updates = {'title': enum_type.__name__, 'description': description}
updates = {k: v for k, v in updates.items() if v is not None}

def get_json_schema(_, handler: GetJsonSchemaHandler) -> JsonSchemaValue:
json_schema = handler(core_schema.literal_schema([x.value for x in cases], ref=enum_ref))
metadata = {'enum_case_descriptions': case_descriptions}
json_schema = handler(core_schema.literal_schema([x.value for x in cases], ref=enum_ref, metadata=metadata))
original_schema = handler.resolve_ref_schema(json_schema)
update_json_schema(original_schema, updates)
return json_schema
Expand Down
42 changes: 38 additions & 4 deletions pydantic/json_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,6 @@
from ._internal._schema_generation_shared import GetJsonSchemaFunction
from .main import BaseModel


CoreSchemaOrFieldType = Literal[core_schema.CoreSchemaType, core_schema.CoreSchemaFieldType]
"""
A type alias for defined schema types that represents a union of
Expand Down Expand Up @@ -242,6 +241,7 @@ class GenerateJsonSchema:
this value can be modified on subclasses to easily control which warnings are emitted.
by_alias: Whether to use field aliases when generating the schema.
ref_template: The format string used when generating reference names.
literal_type: Whether to generate Literal values using `enum` or `oneOf` + `const`.
core_to_json_refs: A mapping of core refs to JSON refs.
core_to_defs_refs: A mapping of core refs to definition refs.
defs_to_core_refs: A mapping of definition refs to core refs.
Expand All @@ -262,9 +262,15 @@ class GenerateJsonSchema:
# this value can be modified on subclasses to easily control which warnings are emitted
ignored_warning_kinds: set[JsonSchemaWarningKind] = {'skipped-choice'}

def __init__(self, by_alias: bool = True, ref_template: str = DEFAULT_REF_TEMPLATE):
def __init__(
self,
by_alias: bool = True,
ref_template: str = DEFAULT_REF_TEMPLATE,
literal_type: Literal['enum', 'oneof-const'] = 'enum',
):
self.by_alias = by_alias
self.ref_template = ref_template
self.literal_type = literal_type

self.core_to_json_refs: dict[CoreModeRef, JsonRef] = {}
self.core_to_defs_refs: dict[CoreModeRef, DefsRef] = {}
Expand Down Expand Up @@ -732,10 +738,31 @@ def literal_schema(self, schema: core_schema.LiteralSchema) -> JsonSchemaValue:
# jsonify the expected values
expected = [to_jsonable_python(v) for v in expected]

result: dict[str, Any] = {'enum': expected}
result: dict[str, Any] = {}
if len(expected) == 1:
result['const'] = expected[0]

if self.literal_type == 'enum':
result['enum'] = expected
elif self.literal_type == 'oneof-const':
# TODO (rmehyde): do we want this condition or not? why do we still produce 'enum' for single values?
if len(expected) > 1:
descriptions = schema.get('metadata', {}).get('enum_case_descriptions', [])
members = []
for e in expected:
member = {'const': e}

try:
description_idx = [d[0] for d in descriptions].index(e)
member['description'] = descriptions[description_idx][1]
except ValueError:
pass

members.append(member)
result['oneOf'] = members
else:
raise ValueError(f"Unknown literal type '{self.literal_type}'")

types = {type(e) for e in expected}
if types == {str}:
result['type'] = 'string'
Expand Down Expand Up @@ -2157,6 +2184,7 @@ def model_json_schema(
ref_template: str = DEFAULT_REF_TEMPLATE,
schema_generator: type[GenerateJsonSchema] = GenerateJsonSchema,
mode: JsonSchemaMode = 'validation',
literal_type: Literal['enum', 'oneof-const'] = 'enum',
) -> dict[str, Any]:
"""Utility function to generate a JSON Schema for a model.

Expand All @@ -2170,13 +2198,19 @@ def model_json_schema(

- 'validation': Generate a JSON Schema for validating data.
- 'serialization': Generate a JSON Schema for serializing data.
literal_type: Whether to generate Literal values using `enum` or `oneOf` + `const`. `oneof-const` will
add docstrings to member `description`s if available.

Returns:
The generated JSON Schema.
"""
from .main import BaseModel

schema_generator_instance = schema_generator(by_alias=by_alias, ref_template=ref_template)
schema_generator_instance = schema_generator(
by_alias=by_alias,
ref_template=ref_template,
literal_type=literal_type,
)
if isinstance(cls.__pydantic_validator__, _mock_val_ser.MockValSer):
cls.__pydantic_validator__.rebuild()

Expand Down
12 changes: 10 additions & 2 deletions pydantic/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import typing
import warnings
from copy import copy, deepcopy
from typing import Any, ClassVar, Dict, Generator, Set, Tuple, TypeVar, Union
from typing import Any, ClassVar, Dict, Generator, Literal, Set, Tuple, TypeVar, Union

import pydantic_core
import typing_extensions
Expand Down Expand Up @@ -383,6 +383,7 @@ def model_json_schema(
ref_template: str = DEFAULT_REF_TEMPLATE,
schema_generator: type[GenerateJsonSchema] = GenerateJsonSchema,
mode: JsonSchemaMode = 'validation',
literal_type: Literal['enum', 'oneof-const'] = 'enum',
) -> dict[str, Any]:
"""Generates a JSON schema for a model class.

Expand All @@ -392,12 +393,19 @@ def model_json_schema(
schema_generator: To override the logic used to generate the JSON schema, as a subclass of
`GenerateJsonSchema` with your desired modifications
mode: The mode in which to generate the schema.
literal_type: Whether to generate Literal values using `enum` or `oneOf` + `const`. `oneof-const` will
add docstrings to member `description`s if available.

Returns:
The JSON schema for the given model class.
"""
return model_json_schema(
cls, by_alias=by_alias, ref_template=ref_template, schema_generator=schema_generator, mode=mode
cls,
by_alias=by_alias,
ref_template=ref_template,
schema_generator=schema_generator,
mode=mode,
literal_type=literal_type,
)

@classmethod
Expand Down
128 changes: 128 additions & 0 deletions tests/test_json_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -502,6 +502,134 @@ class Model(BaseModel):
}


def test_enum_schema_oneof_const():
class FooBar(str, Enum):
"""
This enum Foos and Bars
"""

foo = 'foo'
bar = 'bar'

class Model(BaseModel):
enum: FooBar

assert Model.model_json_schema(literal_type='oneof-const') == {
'title': 'Model',
'type': 'object',
'properties': {'enum': {'$ref': '#/$defs/FooBar'}},
'required': ['enum'],
'$defs': {
'FooBar': {
'title': 'FooBar',
'description': 'This enum Foos and Bars',
'oneOf': [
{'const': 'foo'},
{'const': 'bar'},
],
'type': 'string',
}
},
}


def test_enum_schema_oneof_const_member_docstring():
class DocumentedStrEnum(str, Enum):
"""
Courtesy of Ethan Furman: https://stackoverflow.com/a/50473952
"""

def __new__(cls, value, doc=None):
self = str.__new__(cls)
self._value_ = value
if doc is not None:
self.__doc__ = doc
return self

class FooBar(DocumentedStrEnum):
"""
This enum Foos and Bars
"""

foo = 'foo', 'this foos'
bar = 'bar', 'this bars'

class Model(BaseModel):
enum: FooBar

assert Model.model_json_schema(literal_type='oneof-const') == {
'title': 'Model',
'type': 'object',
'properties': {'enum': {'$ref': '#/$defs/FooBar'}},
'required': ['enum'],
'$defs': {
'FooBar': {
'title': 'FooBar',
'description': 'This enum Foos and Bars',
'oneOf': [
{'const': 'foo', 'description': 'this foos'},
{'const': 'bar', 'description': 'this bars'},
],
'type': 'string',
}
},
}


def test_enum_schema_oneof_const_single_value():
class FooEnum(str, Enum):
"""
The Foo Enum
"""

foo = 'foo'

class Model(BaseModel):
enum: FooEnum

assert Model.model_json_schema(literal_type='oneof-const') == {
'title': 'Model',
'type': 'object',
'properties': {'enum': {'$ref': '#/$defs/FooEnum'}},
'required': ['enum'],
'$defs': {
'FooEnum': {
'title': 'FooEnum',
'description': 'The Foo Enum',
'const': 'foo',
'type': 'string',
}
},
}


@pytest.mark.skipif(sys.version_info[:2] == (3, 8), reason="ListEnum doesn't work in 3.8")
def test_enum_schema_oneof_const_list_enum():
class ListEnum(List[int], Enum):
a = [123]
b = [456]

class Model(BaseModel):
enum: ListEnum

assert Model.model_json_schema(literal_type='oneof-const') == {
'title': 'Model',
'type': 'object',
'properties': {'enum': {'$ref': '#/$defs/ListEnum'}},
'required': ['enum'],
'$defs': {
'ListEnum': {
'title': 'ListEnum',
'oneOf': [
{'const': [123]},
{'const': [456]},
],
'type': 'array',
}
},
}


def test_decimal_json_schema():
class Model(BaseModel):
a: bytes = b'foobar'
Expand Down