forked from reanahub/reana-server
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Closes reanahub/reana#680
- Loading branch information
Showing
1 changed file
with
188 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,188 @@ | ||
#!/usr/bin/env python3 | ||
# -*- coding: utf-8 -*- | ||
# | ||
# This file is part of REANA. | ||
# Copyright (C) 2022 CERN. | ||
# | ||
# REANA is free software; you can redistribute it and/or modify it | ||
# under the terms of the MIT License; see LICENSE file for more details. | ||
"""OpenAPI specification checker.""" | ||
|
||
import ast | ||
import inspect | ||
import json | ||
import pkgutil | ||
import sys | ||
from typing import Set | ||
|
||
import click | ||
from apispec.utils import load_yaml_from_docstring | ||
|
||
from reana_server.factory import create_app | ||
|
||
|
||
def load_rwc_spec(): | ||
"""Load and parse the OpenAPI specification of RWC.""" | ||
spec = pkgutil.get_data( | ||
"reana_commons", "openapi_specifications/reana_workflow_controller.json" | ||
) | ||
if not spec: | ||
raise RuntimeError("Cannot load RWC specification") | ||
return json.loads(spec) | ||
|
||
|
||
class ViewVisitor(ast.NodeVisitor): | ||
"""Class that analyzes a view function by visiting the nodes of the AST.""" | ||
|
||
def __init__(self, checker: "Checker"): | ||
super().__init__() | ||
self.checker = checker | ||
self.called_rwc_operations: Set[str] = set() | ||
self.status_codes: Set[int] = set() | ||
self.returns_rwc_status_code: bool = False | ||
self.signin_required: bool = False | ||
|
||
def check_rwc_method(self, node: ast.Call) -> None: | ||
"""Check if the given function call is a request to RWC. | ||
Calls to RWC are usually in the form of | ||
``current_rwc_api_client.api.operation_id(...)``. | ||
""" | ||
if not isinstance(node.func, ast.Attribute): | ||
return | ||
operation_id = node.func.attr | ||
api_node = node.func.value | ||
if not isinstance(api_node, ast.Attribute) or api_node.attr != "api": | ||
return | ||
rwc_node = api_node.value | ||
if ( | ||
not isinstance(rwc_node, ast.Name) | ||
or rwc_node.id != "current_rwc_api_client" | ||
): | ||
return | ||
self.called_rwc_operations.add(operation_id) | ||
|
||
def check_status_code(self, node: ast.Return) -> None: | ||
"""Check which status code is returned.""" | ||
ret_value = node.value | ||
if not isinstance(ret_value, ast.Tuple): | ||
self.checker.warning("returned value is not a tuple") | ||
elif len(ret_value.elts) != 2: | ||
self.checker.warning("returned tuple does not have two elements") | ||
else: | ||
status_code = ret_value.elts[1] | ||
if ( | ||
isinstance(status_code, ast.Attribute) | ||
and status_code.attr == "status_code" | ||
): | ||
# probably returning error coming from RWC | ||
self.returns_rwc_status_code = True | ||
elif isinstance(status_code, ast.Constant): | ||
# e.g. return ..., 404 | ||
try: | ||
self.status_codes.add(int(status_code.value)) | ||
except Exception: | ||
self.checker.warning("status code is not an integer") | ||
else: | ||
self.checker.warning("unknown status code") | ||
|
||
def check_signin_required(self, node: ast.FunctionDef) -> None: | ||
"""Check whether `signin_required` decorator is applied to the view function.""" | ||
for decorator in node.decorator_list: | ||
if ( | ||
isinstance(decorator, ast.Call) | ||
and isinstance(decorator.func, ast.Name) | ||
and decorator.func.id == "signin_required" | ||
): | ||
self.signin_required = True | ||
|
||
def visit_Return(self, node: ast.Return) -> None: | ||
self.check_status_code(node) | ||
|
||
def visit_Call(self, node: ast.Call) -> None: | ||
self.check_rwc_method(node) | ||
self.generic_visit(node) | ||
|
||
def visit_FunctionDef(self, node: ast.FunctionDef) -> None: | ||
self.check_signin_required(node) | ||
self.generic_visit(node) | ||
|
||
|
||
class Checker: | ||
"""Check view functions to find common mistakes in their OpenAPI specification.""" | ||
|
||
rwc_spec = load_rwc_spec() | ||
|
||
@classmethod | ||
def get_rwc_status_codes(cls, operation_id): | ||
"""Get the possible status codes returned by a RWC operation.""" | ||
for path in cls.rwc_spec["paths"]: | ||
for operation in cls.rwc_spec["paths"][path].values(): | ||
if operation["operationId"] == operation_id: | ||
return set(map(int, operation["responses"].keys())) | ||
raise ValueError("RWC operation not found") | ||
|
||
def __init__(self): | ||
self.errors: bool = False | ||
|
||
def info(self, msg: str) -> None: | ||
"""Print an info message.""" | ||
print(msg) | ||
|
||
def warning(self, msg: str) -> None: | ||
"""Print a warning message""" | ||
click.secho(f" --> {msg}", fg="yellow") | ||
|
||
def error(self, msg: str) -> None: | ||
"""Print an error message""" | ||
self.errors = True | ||
click.secho(f" --> {msg}", fg="red") | ||
|
||
def check_view(self, name: str, view) -> None: | ||
"""Check for errors in the OpenAPI specification of the given view.""" | ||
self.info(f"Checking {name}") | ||
|
||
spec = load_yaml_from_docstring(view.__doc__) | ||
if not spec: | ||
self.warning("no specification found") | ||
return | ||
spec_codes = {code for op in spec.values() for code in op["responses"]} | ||
|
||
# Parse the view code | ||
tree = ast.parse(inspect.getsource(view)) | ||
visitor = ViewVisitor(self) | ||
visitor.visit(tree) | ||
|
||
if visitor.called_rwc_operations and not visitor.returns_rwc_status_code: | ||
self.warning("detected request to RWC, but status code is not propagated") | ||
|
||
if visitor.returns_rwc_status_code: | ||
# Check that status codes returned by RWC are propagated | ||
for operation_id in visitor.called_rwc_operations: | ||
rwc_codes = self.get_rwc_status_codes(operation_id) | ||
missing_codes = rwc_codes - spec_codes | ||
for code in missing_codes: | ||
self.error(f"missing {code} returned by `{operation_id}` of RWC") | ||
|
||
for code in visitor.status_codes: | ||
if code not in spec_codes: | ||
self.error(f"missing {code} returned by view") | ||
|
||
if visitor.signin_required: | ||
# `signin_required` returns 401 and 403 if credentials are not valid | ||
for code in (401, 403): | ||
if code not in spec_codes: | ||
self.error(f"missing {code} returned by `signin_required`") | ||
|
||
|
||
def main(): | ||
app = create_app() | ||
checker = Checker() | ||
for name, view in app.view_functions.items(): | ||
checker.check_view(name, view) | ||
if checker.errors: | ||
sys.exit(1) | ||
|
||
|
||
if __name__ == "__main__": | ||
main() |