Skip to content

Commit

Permalink
Merge pull request #863 from vitalik/forbid-extra
Browse files Browse the repository at this point in the history
Forbid extra
  • Loading branch information
vitalik authored Sep 26, 2023
2 parents aba664b + bba690f commit 7139f5f
Show file tree
Hide file tree
Showing 4 changed files with 101 additions and 9 deletions.
6 changes: 3 additions & 3 deletions ninja/operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,8 +204,8 @@ def _result_to_response(
return temporal_response

resp_object = ResponseObject(result)
# ^ we need object because getter_dict seems work only with from_orm
result = response_model.from_orm(resp_object).model_dump(
# ^ we need object because getter_dict seems work only with model_validate
result = response_model.model_validate(resp_object).model_dump(
by_alias=self.by_alias,
exclude_unset=self.exclude_unset,
exclude_defaults=self.exclude_defaults,
Expand Down Expand Up @@ -419,7 +419,7 @@ def _not_allowed(self) -> HttpResponse:


class ResponseObject:
"Basically this is just a helper to be able to pass response to pydantic's from_orm"
"Basically this is just a helper to be able to pass response to pydantic's model_validate"

def __init__(self, response: HttpResponse) -> None:
self.response = response
19 changes: 14 additions & 5 deletions ninja/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ def resolve_initials(self, obj):
from django.template import Variable, VariableDoesNotExist
from pydantic import BaseModel, Field, ValidationInfo, model_validator, validator
from pydantic._internal._model_construction import ModelMetaclass
from pydantic.functional_validators import ModelWrapValidatorHandler
from pydantic.json_schema import GenerateJsonSchema, JsonSchemaValue

from ninja.signature.utils import get_args_names, has_kwargs
Expand All @@ -45,7 +46,7 @@ def resolve_initials(self, obj):
class DjangoGetter:
__slots__ = ("_obj", "_schema_cls", "_context")

def __init__(self, obj: Any, schema_cls: "Schema", context: Any = None):
def __init__(self, obj: Any, schema_cls: Type[S], context: Any = None):
self._obj = obj
self._schema_cls = schema_cls
self._context = context
Expand All @@ -54,7 +55,7 @@ def __getattr__(self, key: str) -> Any:
# if key.startswith("__pydantic"):
# return getattr(self._obj, key)

resolver = self._schema_cls._ninja_resolvers.get(key) # type: ignore
resolver = self._schema_cls._ninja_resolvers.get(key)
if resolver:
value = resolver(getter=self)
else:
Expand Down Expand Up @@ -198,10 +199,18 @@ class Schema(BaseModel, metaclass=ResolverMetaclass):
class Config:
from_attributes = True # aka orm_mode

@model_validator(mode="before")
def _run_root_validator(cls, values: Any, info: ValidationInfo) -> Any:
@model_validator(mode="wrap")
@classmethod
def _run_root_validator(
cls, values: Any, handler: ModelWrapValidatorHandler[S], info: ValidationInfo
) -> S:
# when extra is "forbid" we need to perform default pydantic valudation
# as DjangoGetter does not act as dict and pydantic will not be able to validate it
if cls.model_config.get("extra") == "forbid":
handler(values)

values = DjangoGetter(values, cls, info.context)
return values
return handler(values)

@classmethod
def from_orm(cls: Type[S], obj: Any) -> S:
Expand Down
9 changes: 9 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,15 @@ dev = [
"pre-commit",
]


# BLACK ==========================================================

[tool.black]
line-length = 88
skip-string-normalization = false
target-version = ['py311']


[tool.ruff]
select = [
"E", # pycodestyle errors
Expand Down
76 changes: 75 additions & 1 deletion tests/test_request.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,23 @@
from typing import Optional

import pytest
from pydantic import ConfigDict

from ninja import Cookie, Header, Router
from ninja import Body, Cookie, Header, Router, Schema
from ninja.testing import TestClient


class OptionalEmptySchema(Schema):
model_config = ConfigDict(extra="forbid")
name: Optional[str] = None


class ExtraForbidSchema(Schema):
model_config = ConfigDict(extra="forbid")
name: str
metadata: Optional[OptionalEmptySchema] = None


router = Router()


Expand Down Expand Up @@ -41,6 +56,11 @@ def cookies2(request, wpn: str = Cookie(..., alias="weapon")):
return wpn


@router.post("/test-schema")
def test_schema(request, payload: ExtraForbidSchema = Body(...)):
return "ok"


client = TestClient(router)


Expand Down Expand Up @@ -77,3 +97,57 @@ def test_headers(path, expected_status, expected_response):
assert response.status_code == expected_status, response.content
print(response.json())
assert response.json() == expected_response


@pytest.mark.parametrize(
"path,json,expected_status,expected_response",
[
(
"/test-schema",
{"name": "test", "extra_name": "test2"},
422,
{
"detail": [
{
"type": "extra_forbidden",
"loc": ["body", "payload", "extra_name"],
"msg": "Extra inputs are not permitted",
}
]
},
),
(
"/test-schema",
{"name": "test", "metadata": {"extra_name": "xxx"}},
422,
{
"detail": [
{
"loc": ["body", "payload", "metadata", "extra_name"],
"msg": "Extra inputs are not permitted",
"type": "extra_forbidden",
}
]
},
),
(
"/test-schema",
{"name": "test", "metadata": "test2"},
422,
{
"detail": [
{
"type": "model_attributes_type",
"loc": ["body", "payload", "metadata"],
"msg": "Input should be a valid dictionary or object to extract fields from",
}
]
},
),
],
)
def test_pydantic_config(path, json, expected_status, expected_response):
# test extra forbid
response = client.post(path, json=json)
assert response.json() == expected_response
assert response.status_code == expected_status

0 comments on commit 7139f5f

Please sign in to comment.