Skip to content

Commit

Permalink
Added requests validation based on swagger schema.
Browse files Browse the repository at this point in the history
  • Loading branch information
trezorg committed Feb 28, 2018
1 parent be39d48 commit 08f089e
Show file tree
Hide file tree
Showing 14 changed files with 1,158 additions and 32 deletions.
17 changes: 11 additions & 6 deletions aiohttp_swagger/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,21 @@
generate_doc_from_each_end_point,
load_doc_from_yaml_file,
swagger_path,
swagger_validation,
add_swagger_validation,
)

try:
import ujson as json
except ImportError:
import json

__all__ = (
"setup_swagger",
"swagger_path",
"swagger_validation",
)


@asyncio.coroutine
def _swagger_home(request):
Expand Down Expand Up @@ -89,7 +97,7 @@ def setup_swagger(app: web.Application,
)

if swagger_validate_schema:
pass
add_swagger_validation(app, swagger_info)

swagger_info = json.dumps(swagger_info)

Expand Down Expand Up @@ -119,12 +127,9 @@ def setup_swagger(app: web.Application,
with open(join(STATIC_PATH, "index.html"), "r") as f:
app["SWAGGER_TEMPLATE_CONTENT"] = (
f.read()
.replace("##SWAGGER_CONFIG##", '/{}{}'.
.replace("##SWAGGER_CONFIG##", '{}{}'.
format(api_base_url.lstrip('/'), _swagger_def_url))
.replace("##STATIC_PATH##", '/{}{}'.
.replace("##STATIC_PATH##", '{}{}'.
format(api_base_url.lstrip('/'), statics_path))
.replace("##SWAGGER_VALIDATOR_URL##", swagger_validator_url)
)


__all__ = ("setup_swagger", "swagger_path")
1 change: 1 addition & 0 deletions aiohttp_swagger/helpers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
from .builders import * # noqa
from .decorators import * # noqa
from .validation import * # noqa
62 changes: 53 additions & 9 deletions aiohttp_swagger/helpers/builders.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import logging
from typing import (
MutableMapping,
Mapping,
Expand All @@ -13,18 +14,21 @@
from aiohttp import web
from aiohttp.hdrs import METH_ANY, METH_ALL
from jinja2 import Template

try:
import ujson as json
except ImportError: # pragma: no cover
import json

from .validation import validate_decorator


SWAGGER_TEMPLATE = abspath(join(dirname(__file__), "..", "templates"))


def _extract_swagger_docs(end_point_doc, method="get"):
# Find Swagger start point in doc
def _extract_swagger_docs(end_point_doc: str) -> Mapping:
"""
Find Swagger start point in doc.
"""
end_point_swagger_start = 0
for i, doc_line in enumerate(end_point_doc):
if "---" in doc_line:
Expand All @@ -42,7 +46,7 @@ def _extract_swagger_docs(end_point_doc, method="get"):
"from docstring ⚠",
"tags": ["Invalid Swagger"]
}
return {method: end_point_swagger_doc}
return end_point_swagger_doc


def _build_doc_from_func_doc(route):
Expand All @@ -58,16 +62,14 @@ def _build_doc_from_func_doc(route):
method = getattr(route.handler, method_name)
if method.__doc__ is not None and "---" in method.__doc__:
end_point_doc = method.__doc__.splitlines()
out.update(
_extract_swagger_docs(end_point_doc, method=method_name))
out[method_name] = _extract_swagger_docs(end_point_doc)

else:
try:
end_point_doc = route.handler.__doc__.splitlines()
except AttributeError:
return {}
out.update(_extract_swagger_docs(
end_point_doc, method=route.method.lower()))
out[route.method.lower()] = _extract_swagger_docs(end_point_doc)
return out


Expand Down Expand Up @@ -150,7 +152,49 @@ def load_doc_from_yaml_file(doc_path: str) -> MutableMapping:
return yaml.load(open(doc_path, "r").read())


def add_swagger_validation(app, swagger_info: Mapping):
for route in app.router.routes():
method = route.method.lower()
handler = route.handler
url_info = route.get_info()
url = url_info.get('path') or url_info.get('formatter')

if method != '*':
swagger_endpoint_info_for_method = \
swagger_info['paths'].get(url, {}).get(method)
swagger_endpoint_info = \
{method: swagger_endpoint_info_for_method} if \
swagger_endpoint_info_for_method is not None else {}
else:
# all methods
swagger_endpoint_info = swagger_info['paths'].get(url, {})
for method, info in swagger_endpoint_info.items():
logging.debug(
'Added validation for method: {}. Path: {}'.
format(method.upper(), url)
)
if issubclass(handler, web.View) and route.method == METH_ANY:
# whole class validation
should_be_validated = getattr(handler, 'validation', False)
cls_method = getattr(handler, method, None)
if cls_method is not None:
if not should_be_validated:
# method validation
should_be_validated = \
getattr(handler, 'validation', False)
if should_be_validated:
new_cls_method = \
validate_decorator(swagger_info, info)(cls_method)
setattr(handler, method, new_cls_method)
else:
should_be_validated = getattr(handler, 'validation', False)
if should_be_validated:
route._handler = \
validate_decorator(swagger_info, info)(handler)


__all__ = (
"generate_doc_from_each_end_point",
"load_doc_from_yaml_file"
"load_doc_from_yaml_file",
"add_swagger_validation",
)
22 changes: 21 additions & 1 deletion aiohttp_swagger/helpers/decorators.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,27 @@
class swagger_path(object):
from functools import partial
from inspect import isfunction, isclass

__all__ = (
'swagger_path',
'swagger_validation',
)


class swagger_path:

def __init__(self, swagger_file):
self.swagger_file = swagger_file

def __call__(self, f):
f.swagger_file = self.swagger_file
return f


def swagger_validation(func=None, *, validation=True):

if func is None or not (isfunction(func) or isclass(func)):
validation = func
return partial(swagger_validation, validation=validation)

func.validation = validation
return func
Loading

0 comments on commit 08f089e

Please sign in to comment.