Skip to content

Commit

Permalink
Added requests validation based on swagger schema.
Browse files Browse the repository at this point in the history
  • Loading branch information
trezorg committed Feb 28, 2018
1 parent be39d48 commit 6aba400
Show file tree
Hide file tree
Showing 10 changed files with 1,006 additions and 25 deletions.
4 changes: 3 additions & 1 deletion aiohttp_swagger/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
generate_doc_from_each_end_point,
load_doc_from_yaml_file,
swagger_path,
swagger_validation,
add_swagger_validation,
)

try:
Expand Down Expand Up @@ -89,7 +91,7 @@ def setup_swagger(app: web.Application,
)

if swagger_validate_schema:
pass
add_swagger_validation(app, swagger_info)

swagger_info = json.dumps(swagger_info)

Expand Down
1 change: 1 addition & 0 deletions aiohttp_swagger/helpers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
from .builders import * # noqa
from .decorators import * # noqa
from .validation import * # noqa
58 changes: 49 additions & 9 deletions aiohttp_swagger/helpers/builders.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import logging
from typing import (
MutableMapping,
Mapping,
Expand All @@ -13,18 +14,21 @@
from aiohttp import web
from aiohttp.hdrs import METH_ANY, METH_ALL
from jinja2 import Template

try:
import ujson as json
except ImportError: # pragma: no cover
import json

from .validation import validate_decorator


SWAGGER_TEMPLATE = abspath(join(dirname(__file__), "..", "templates"))


def _extract_swagger_docs(end_point_doc, method="get"):
# Find Swagger start point in doc
def _extract_swagger_docs(end_point_doc: str) -> Mapping:
"""
Find Swagger start point in doc.
"""
end_point_swagger_start = 0
for i, doc_line in enumerate(end_point_doc):
if "---" in doc_line:
Expand All @@ -42,7 +46,7 @@ def _extract_swagger_docs(end_point_doc, method="get"):
"from docstring ⚠",
"tags": ["Invalid Swagger"]
}
return {method: end_point_swagger_doc}
return end_point_swagger_doc


def _build_doc_from_func_doc(route):
Expand All @@ -58,16 +62,14 @@ def _build_doc_from_func_doc(route):
method = getattr(route.handler, method_name)
if method.__doc__ is not None and "---" in method.__doc__:
end_point_doc = method.__doc__.splitlines()
out.update(
_extract_swagger_docs(end_point_doc, method=method_name))
out[method_name] = _extract_swagger_docs(end_point_doc)

else:
try:
end_point_doc = route.handler.__doc__.splitlines()
except AttributeError:
return {}
out.update(_extract_swagger_docs(
end_point_doc, method=route.method.lower()))
out[route.method.lower()] = _extract_swagger_docs(end_point_doc)
return out


Expand Down Expand Up @@ -150,7 +152,45 @@ def load_doc_from_yaml_file(doc_path: str) -> MutableMapping:
return yaml.load(open(doc_path, "r").read())


def add_swagger_validation(app, swagger_info: Mapping):
for route in app.router.routes():
method = route.method.lower()
handler = route.handler
formatter = route.get_info()['formatter']
if method != '*':
swagger_endpoint_info_for_method = \
swagger_info['paths'].get(formatter, {}).get(method)
swagger_endpoint_info = \
{method: swagger_endpoint_info_for_method} if \
swagger_endpoint_info_for_method is not None else {}
else:
# all methods
swagger_endpoint_info = swagger_info['paths'].get(formatter, {})
for method, info in swagger_endpoint_info.items():
logging.debug(
'Added validation for method: {}. Path: {}'.
format(method.upper(), formatter)
)
if issubclass(handler, web.View) and route.method == METH_ANY:
# whole class validation
should_be_validated = getattr(handler, 'validation', False)
cls_method = getattr(handler, method, None)
if cls_method is not None:
if not should_be_validated:
# method validation
should_be_validated = \
getattr(handler, 'validation', False)
if should_be_validated:
new_cls_method = validate_decorator(info)(cls_method)
setattr(handler, method, new_cls_method)
else:
should_be_validated = getattr(handler, 'validation', False)
if should_be_validated:
route._handler = validate_decorator(info)(handler)


__all__ = (
"generate_doc_from_each_end_point",
"load_doc_from_yaml_file"
"load_doc_from_yaml_file",
"add_swagger_validation",
)
22 changes: 21 additions & 1 deletion aiohttp_swagger/helpers/decorators.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,27 @@
class swagger_path(object):
from functools import partial
from inspect import isfunction, isclass

__all__ = (
'swagger_path',
'swagger_validation',
)


class swagger_path:

def __init__(self, swagger_file):
self.swagger_file = swagger_file

def __call__(self, f):
f.swagger_file = self.swagger_file
return f


def swagger_validation(func=None, *, validation=True):

if func is None or not (isfunction(func) or isclass(func)):
validation = func
return partial(swagger_validation, validation=validation)

func.validation = validation
return func
205 changes: 205 additions & 0 deletions aiohttp_swagger/helpers/validation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,205 @@
from copy import deepcopy
import sys
import json
import logging
from functools import wraps
from traceback import format_exc
from itertools import groupby
from operator import itemgetter
from typing import (
Mapping,
Iterable,
Optional,
)

from aiohttp import web
from aiohttp.web import (
Request,
Response,
json_response,
)
from collections import defaultdict
from jsonschema import (
validate,
ValidationError,
FormatChecker,
)
from jsonschema.validators import validator_for


__all__ = (
'validate_decorator',
)


logger = logging.getLogger(__name__)


def serialize_error_response(message: str, code: int, padding='error',
traceback: bool=False, **kwargs):
obj = {padding: {'message': message, 'code': code, **kwargs}}
if traceback and sys.exc_info()[0]:
obj[padding]['traceback'] = format_exc()
return json.dumps(obj, default=lambda x: str(x))


def multi_dict_to_dict(mld: Mapping) -> Mapping:
return {
key: value[0]
if isinstance(value, (list, tuple)) and len(value) == 1 else value
for key, value in mld.items()
}


def validate_schema(obj: Mapping, schema: Mapping):
validate(obj, schema, format_checker=FormatChecker())


def validate_multi_dict(obj, schema):
validate(multi_dict_to_dict(obj), schema, format_checker=FormatChecker())


def validate_content_type(swagger: Mapping, content_type: str):
consumes = swagger.get('consumes')
if consumes and not any(content_type == consume for consume in consumes):
raise ValidationError(
message='Unsupported content type: {}'.format(content_type))


async def validate_request(
request: Request,
parameter_groups: Mapping,
swagger: Mapping):
validate_content_type(swagger, request.content_type)
for group_name, group_schemas in parameter_groups.items():
if group_name == 'header':
headers = request.headers
for schema in group_schemas:
validate_multi_dict(headers, schema)
if group_name == 'query':
query = request.query
for schema in group_schemas:
validate_multi_dict(query, schema)
if group_name == 'formData':
try:
data = await request.post()
except ValueError:
data = None
for schema in group_schemas:
validate_multi_dict(data, schema)
if group_name == 'body':
try:
content = await request.json()
except json.JSONDecodeError:
content = None
for schema in group_schemas:
validate_schema(content, schema)
if group_name == 'path':
params = dict(request.match_info)
for schema in group_schemas:
validate_schema(params, schema)


def adjust_swagger_item_to_json_schemes(*schemes: Mapping) -> Mapping:
new_schema = {
'type': 'object',
'properties': {},
}
required_fields = []
for schema in schemes:
required = schema.get('required', False)
name = schema['name']
_schema = schema.get('schema')
if _schema is not None:
new_schema['properties'][name] = _schema
else:
new_schema['properties'][name] = {
key: value for key, value in schema.items()
if key not in ('required',)
}
if required:
required_fields.append(name)
if required_fields:
new_schema['required'] = required_fields
validator_for(new_schema).check_schema(new_schema)
return new_schema


def adjust_swagger_body_item_to_json_schema(schema: Mapping) -> Mapping:
required = schema.get('required', False)
_schema = schema.get('schema')
new_schema = deepcopy(_schema)
if not required:
new_schema = {
'anyOf': [
{'type': 'null'},
new_schema,
]
}
validator_for(new_schema).check_schema(new_schema)
return new_schema


def adjust_swagger_to_json_schema(parameter_groups: Iterable) -> Mapping:
res = defaultdict(list)
for group_name, group_schemas in parameter_groups:
if group_name in ('query', 'header', 'path', 'formData'):
json_schema = adjust_swagger_item_to_json_schemes(*group_schemas)
res[group_name].append(json_schema)
else:
# only one possible schema for in: body
schema = list(group_schemas)[0]
json_schema = adjust_swagger_body_item_to_json_schema(schema)
res[group_name].append(json_schema)
return res


def validation_exc_to_dict(exc, code=400):
paths = list(exc.path)
field = str(paths[-1]) if paths else ''
value = exc.instance
validator = exc.validator
message = exc.message
try:
schema = dict(exc.schema)
except TypeError:
schema = {}
return {
'message': message,
'code': code,
'description': {
'validator': validator,
'schema': schema,
'field': field,
'value': value,
}
}


def validate_decorator(swagger: Mapping):

parameters = swagger.get('parameters', [])
parameter_groups = adjust_swagger_to_json_schema(
groupby(parameters, key=itemgetter('in'))
)

def _func_wrapper(func):

@wraps(func)
async def _wrapper(*args, **kwargs) -> Response:
request = args[0].request \
if isinstance(args[0], web.View) else args[0]
try:
await validate_request(request, parameter_groups, swagger)
except ValidationError as exc:
logger.exception(exc)
exc_dict = validation_exc_to_dict(exc)
return json_response(
text=serialize_error_response(**exc_dict),
status=400
)
return await func(*args, **kwargs)

return _wrapper

return _func_wrapper
Loading

0 comments on commit 6aba400

Please sign in to comment.