Skip to content

Commit

Permalink
scripts: add openapi checker
Browse files Browse the repository at this point in the history
  • Loading branch information
mdonadoni committed Feb 6, 2023
1 parent b80cf83 commit f7b3414
Showing 1 changed file with 188 additions and 0 deletions.
188 changes: 188 additions & 0 deletions scripts/openapi_checker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,188 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
#
# This file is part of REANA.
# Copyright (C) 2022 CERN.
#
# REANA is free software; you can redistribute it and/or modify it
# under the terms of the MIT License; see LICENSE file for more details.
"""OpenAPI specification checker."""

import ast
import inspect
import json
import pkgutil
import sys
from typing import Set

import click
from apispec.utils import load_yaml_from_docstring

from reana_server.factory import create_app


def load_rwc_spec():
"""Load and parse the OpenAPI specification of RWC."""
spec = pkgutil.get_data(
"reana_commons", "openapi_specifications/reana_workflow_controller.json"
)
if not spec:
raise RuntimeError("Cannot load RWC specification")
return json.loads(spec)


class ViewVisitor(ast.NodeVisitor):
"""Class that analyzes a view function by visiting the nodes of the AST."""

def __init__(self, checker: "Checker"):
super().__init__()
self.checker = checker
self.called_rwc_operations: Set[str] = set()
self.status_codes: Set[int] = set()
self.returns_rwc_status_code: bool = False
self.signin_required: bool = False

def check_rwc_method(self, node: ast.Call) -> None:
"""Check if the given function call is a request to RWC.
Calls to RWC are usually in the form of
``current_rwc_api_client.api.operation_id(...)``.
"""
if not isinstance(node.func, ast.Attribute):
return
operation_id = node.func.attr
api_node = node.func.value
if not isinstance(api_node, ast.Attribute) or api_node.attr != "api":
return
rwc_node = api_node.value
if (
not isinstance(rwc_node, ast.Name)
or rwc_node.id != "current_rwc_api_client"
):
return
self.called_rwc_operations.add(operation_id)

def check_status_code(self, node: ast.Return) -> None:
"""Check which status code is returned."""
ret_value = node.value
if not isinstance(ret_value, ast.Tuple):
self.checker.warning("returned value is not a tuple")
elif len(ret_value.elts) != 2:
self.checker.warning("returned tuple does not have two elements")
else:
status_code = ret_value.elts[1]
if (
isinstance(status_code, ast.Attribute)
and status_code.attr == "status_code"
):
# probably returning error coming from RWC
self.returns_rwc_status_code = True
elif isinstance(status_code, ast.Constant):
# e.g. return ..., 404
try:
self.status_codes.add(int(status_code.value))
except Exception:
self.checker.warning("status code is not an integer")
else:
self.checker.warning("unknown status code")

def check_signin_required(self, node: ast.FunctionDef) -> None:
"""Check whether `signin_required` decorator is applied to the view function."""
for decorator in node.decorator_list:
if (
isinstance(decorator, ast.Call)
and isinstance(decorator.func, ast.Name)
and decorator.func.id == "signin_required"
):
self.signin_required = True

def visit_Return(self, node: ast.Return) -> None:
self.check_status_code(node)

def visit_Call(self, node: ast.Call) -> None:
self.check_rwc_method(node)
self.generic_visit(node)

def visit_FunctionDef(self, node: ast.FunctionDef) -> None:
self.check_signin_required(node)
self.generic_visit(node)


class Checker:
"""Check view functions to find common mistakes in their OpenAPI specification."""

rwc_spec = load_rwc_spec()

@classmethod
def get_rwc_status_codes(cls, operation_id):
"""Get the possible status codes returned by a RWC operation."""
for path in cls.rwc_spec["paths"]:
for operation in cls.rwc_spec["paths"][path].values():
if operation["operationId"] == operation_id:
return set(map(int, operation["responses"].keys()))
raise ValueError("RWC operation not found")

def __init__(self):
self.errors: bool = False

def info(self, msg: str) -> None:
"""Print an info message."""
print(msg)

def warning(self, msg: str) -> None:
"""Print a warning message"""
click.secho(f" --> {msg}", fg="yellow")

def error(self, msg: str) -> None:
"""Print an error message"""
self.errors = True
click.secho(f" --> {msg}", fg="red")

def check_view(self, name: str, view) -> None:
"""Check for errors in the OpenAPI specification of the given view."""
self.info(f"Checking {name}")

spec = load_yaml_from_docstring(view.__doc__)
if not spec:
self.warning("no specification found")
return
spec_codes = {code for op in spec.values() for code in op["responses"]}

# Parse the view code
tree = ast.parse(inspect.getsource(view))
visitor = ViewVisitor(self)
visitor.visit(tree)

if visitor.called_rwc_operations and not visitor.returns_rwc_status_code:
self.warning("detected request to RWC, but status code is not propagated")

if visitor.returns_rwc_status_code:
# Check that status codes returned by RWC are propagated
for operation_id in visitor.called_rwc_operations:
rwc_codes = self.get_rwc_status_codes(operation_id)
missing_codes = rwc_codes - spec_codes
for code in missing_codes:
self.error(f"missing {code} returned by `{operation_id}` of RWC")

for code in visitor.status_codes:
if code not in spec_codes:
self.error(f"missing {code} returned by view")

if visitor.signin_required:
# `signin_required` returns 401 and 403 if credentials are not valid
for code in (401, 403):
if code not in spec_codes:
self.error(f"missing {code} returned by `signin_required`")


def main():
app = create_app()
checker = Checker()
for name, view in app.view_functions.items():
checker.check_view(name, view)
if checker.errors:
sys.exit(1)


if __name__ == "__main__":
main()

0 comments on commit f7b3414

Please sign in to comment.