diff --git a/CHANGELOG.md b/CHANGELOG.md index 87a64ed2c..5d7d79e9c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -25,6 +25,8 @@ Semantic versioning in our case means: - Fixes `CognitiveModuleComplexityViolation` to not trigger for a single-item modules - Fixes that `ConstantConditionViolation` was not reported for a BoolOp +- Functions and methods marked as `@overload` or `@typing.overload` + do not count in complexity rules ## 0.15.1 diff --git a/tests/test_visitors/test_ast/test_complexity/test_classes/test_method_counts.py b/tests/test_visitors/test_ast/test_complexity/test_classes/test_method_counts.py index 0a40aab8a..e3c77ca66 100644 --- a/tests/test_visitors/test_ast/test_complexity/test_classes/test_method_counts.py +++ b/tests/test_visitors/test_ast/test_complexity/test_classes/test_method_counts.py @@ -89,6 +89,17 @@ async def method(cls): ... async def method2(cls): ... """ +# regression1779 + +class_with_overloades = """ +class First(object): + @overload + def my_method(self): ... + + @typing.overload + def my_method(self): ... +""" + @pytest.mark.parametrize('code', [ module_without_methods, @@ -102,6 +113,7 @@ async def method2(cls): ... class_with_async_and_usual_class_methods, class_with_staticmethods, class_with_async_staticmethods, + class_with_overloades, ]) def test_method_counts_normal( assert_errors, @@ -144,3 +156,22 @@ def test_method_counts_violation( assert_errors(visitor, [TooManyMethodsViolation]) assert_error_text(visitor, '2', option_values.max_methods) + + +@pytest.mark.parametrize('code', [ + class_with_overloades, +]) +def test_method_counts_exceptions( + assert_errors, + parse_ast_tree, + code, + options, +): + """Testing that violations are raised not when using special cases.""" + tree = parse_ast_tree(code) + + option_values = options(max_methods=0) + visitor = MethodMembersVisitor(option_values, tree=tree) + visitor.run() + + assert_errors(visitor, []) diff --git a/tests/test_visitors/test_ast/test_complexity/test_counts/test_module_counts.py b/tests/test_visitors/test_ast/test_complexity/test_counts/test_module_counts.py index a961684ed..0661b6c7a 100644 --- a/tests/test_visitors/test_ast/test_complexity/test_counts/test_module_counts.py +++ b/tests/test_visitors/test_ast/test_complexity/test_counts/test_module_counts.py @@ -99,6 +99,18 @@ async def test(self): ... def other(self): ... """ +# regression1779 +module_with_overloads = """ +@overload +def first(): ... + +@typing.overload +def first(): ... + +# Only this def counts: +def first(): ... +""" + # Empty: empty_module = '' @@ -116,6 +128,7 @@ def other(self): ... module_with_staticmethods, module_with_classmethods, module_with_single_class, + module_with_overloads, ]) def test_module_counts_normal( assert_errors, @@ -166,6 +179,7 @@ def test_module_counts_violation( module_with_single_function, module_with_single_async_function, module_with_single_class, + module_with_overloads, ]) def test_module_counts_single_member( assert_errors, diff --git a/wemake_python_styleguide/logic/tree/decorators.py b/wemake_python_styleguide/logic/tree/decorators.py new file mode 100644 index 000000000..078307f7f --- /dev/null +++ b/wemake_python_styleguide/logic/tree/decorators.py @@ -0,0 +1,27 @@ +import ast + +from wemake_python_styleguide.types import AnyFunctionDef + + +def has_overload_decorator(function: AnyFunctionDef) -> bool: + """ + Detects if a function has ``@overload`` or ``@typing.overload`` decorators. + + It is useful, because ``@overload`` function defs + have slightly different rules: for example, they do not count as real defs + in complexity rules. + """ + for decorator in function.decorator_list: + is_partial_name = ( + isinstance(decorator, ast.Name) and + decorator.id == 'overload' + ) + is_full_name = ( + isinstance(decorator, ast.Attribute) and + decorator.attr == 'overload' and + isinstance(decorator.value, ast.Name) and + decorator.value.id == 'typing' + ) + if is_partial_name or is_full_name: + return True + return False diff --git a/wemake_python_styleguide/logic/tree/functions.py b/wemake_python_styleguide/logic/tree/functions.py index 62bded516..b53718465 100644 --- a/wemake_python_styleguide/logic/tree/functions.py +++ b/wemake_python_styleguide/logic/tree/functions.py @@ -1,5 +1,7 @@ from ast import Call, Return, Yield, YieldFrom, arg, walk -from typing import Container, Iterable, List, Optional, Tuple, Type, Union +from typing import Container, Iterable, List, Tuple, Type, Union + +from typing_extensions import Final from wemake_python_styleguide.compat.functions import get_posonlyargs from wemake_python_styleguide.logic import source @@ -26,6 +28,13 @@ Type[YieldFrom], ] +#: Method types +_METHOD_TYPES: Final = frozenset(( + 'method', + 'classmethod', + 'staticmethod', +)) + def given_function_called( node: Call, @@ -47,7 +56,7 @@ def given_function_called( return '' -def is_method(function_type: Optional[str]) -> bool: +def is_method(function_type: str) -> bool: """ Returns whether a given function type belongs to a class. @@ -70,7 +79,7 @@ def is_method(function_type: Optional[str]) -> bool: False """ - return function_type in {'method', 'classmethod', 'staticmethod'} + return function_type in _METHOD_TYPES def get_all_arguments(node: AnyFunctionDefAndLambda) -> List[arg]: diff --git a/wemake_python_styleguide/visitors/ast/complexity/classes.py b/wemake_python_styleguide/visitors/ast/complexity/classes.py index b356a24ea..657c84281 100644 --- a/wemake_python_styleguide/visitors/ast/complexity/classes.py +++ b/wemake_python_styleguide/visitors/ast/complexity/classes.py @@ -6,7 +6,7 @@ from wemake_python_styleguide.logic.naming import access from wemake_python_styleguide.logic.nodes import get_parent -from wemake_python_styleguide.logic.tree import classes +from wemake_python_styleguide.logic.tree import classes, decorators from wemake_python_styleguide.types import AnyFunctionDef from wemake_python_styleguide.violations.complexity import ( TooManyBaseClassesViolation, @@ -38,10 +38,6 @@ def t(self): File "", line 3 SyntaxError: cannot use named assignment with attribute - Raises: - TooManyBaseClassesViolation - TooManyPublicAttributesViolation - """ self._check_base_classes(node) self._check_public_attributes(node) @@ -92,17 +88,14 @@ def __init__(self, *args, **kwargs) -> None: self._methods: DefaultDict[ast.ClassDef, int] = defaultdict(int) def visit_any_function(self, node: AnyFunctionDef) -> None: - """ - Counts the number of methods in a single class. - - Raises: - TooManyMethodsViolation - - """ + """Counts the number of methods in a single class.""" self._check_method(node) self.generic_visit(node) def _check_method(self, node: AnyFunctionDef) -> None: + if decorators.has_overload_decorator(node): + return # we don't count `@overload` methods + parent = get_parent(node) if isinstance(parent, ast.ClassDef): self._methods[parent] += 1 diff --git a/wemake_python_styleguide/visitors/ast/complexity/counts.py b/wemake_python_styleguide/visitors/ast/complexity/counts.py index c8aa7d969..3214f4dae 100644 --- a/wemake_python_styleguide/visitors/ast/complexity/counts.py +++ b/wemake_python_styleguide/visitors/ast/complexity/counts.py @@ -5,8 +5,9 @@ from typing_extensions import final from wemake_python_styleguide import constants +from wemake_python_styleguide.compat.aliases import FunctionNodes from wemake_python_styleguide.logic.nodes import get_parent -from wemake_python_styleguide.logic.tree.functions import is_method +from wemake_python_styleguide.logic.tree import decorators, functions from wemake_python_styleguide.types import AnyFunctionDef from wemake_python_styleguide.violations import complexity from wemake_python_styleguide.visitors.base import BaseNodeVisitor @@ -32,22 +33,21 @@ def __init__(self, *args, **kwargs) -> None: self._public_items_count = 0 def visit_module_members(self, node: _ModuleMembers) -> None: - """ - Counts the number of _ModuleMembers in a single module. - - Raises: - TooManyModuleMembersViolation - - """ + """Counts the number of _ModuleMembers in a single module.""" self._check_decorators_count(node) self._check_members_count(node) self.generic_visit(node) def _check_members_count(self, node: _ModuleMembers) -> None: """This method increases the number of module members.""" - is_real_method = is_method(getattr(node, 'function_type', None)) + if functions.is_method(getattr(node, 'function_type', '')): + return + + if isinstance(node, FunctionNodes): + if decorators.has_overload_decorator(node): + return # We don't count `@overload` defs as real defs - if isinstance(get_parent(node), ast.Module) and not is_real_method: + if isinstance(get_parent(node), ast.Module): self._public_items_count += 1 def _check_decorators_count(self, node: _ModuleMembers) -> None: @@ -76,24 +76,12 @@ class ConditionsVisitor(BaseNodeVisitor): """Checks booleans for condition counts.""" def visit_BoolOp(self, node: ast.BoolOp) -> None: - """ - Counts the number of conditions. - - Raises: - TooManyConditionsViolation - - """ + """Counts the number of conditions.""" self._check_conditions(node) self.generic_visit(node) def visit_Compare(self, node: ast.Compare) -> None: - """ - Counts the number of compare parts. - - Raises: - TooLongCompareViolation - - """ + """Counts the number of compare parts.""" self._check_compares(node) self.generic_visit(node) @@ -148,13 +136,7 @@ def __init__(self, *args, **kwargs) -> None: ) def visit_If(self, node: ast.If) -> None: - """ - Checks condition not to reimplement switch. - - Raises: - TooManyElifsViolation - - """ + """Checks condition not to reimplement switch.""" self._check_elifs(node) self.generic_visit(node) @@ -196,14 +178,7 @@ class TryExceptVisitor(BaseNodeVisitor): """Visits all try/except nodes to ensure that they are not too complex.""" def visit_Try(self, node: ast.Try) -> None: - """ - Ensures that try/except is correct. - - Raises: - TooManyExceptCasesViolation - TooLongTryBodyViolation - - """ + """Ensures that try/except is correct.""" self._check_except_count(node) self._check_try_body_length(node) self.generic_visit(node) @@ -234,13 +209,7 @@ class YieldTupleVisitor(BaseNodeVisitor): """Finds too long ``tuples`` in ``yield`` expressions.""" def visit_Yield(self, node: ast.Yield) -> None: - """ - Helper to get all ``yield`` nodes in a function at once. - - Raises: - TooLongYieldTupleViolation - - """ + """Helper to get all ``yield`` nodes in a function at once.""" self._check_yield_values(node) self.generic_visit(node) @@ -261,21 +230,21 @@ class TupleUnpackVisitor(BaseNodeVisitor): """Finds statements with too many variables receiving an unpacked tuple.""" def visit_Assign(self, node: ast.Assign) -> None: - """ - Finds statements using too many variables to receive an unpacked tuple. + """Finds statements using too many variables to unpack a tuple.""" + self._check_tuple_unpack(node) + self.generic_visit(node) - Raises: - TooLongTupleUnpackViolation + def _check_tuple_unpack(self, node: ast.Assign) -> None: + if not isinstance(node.targets[0], ast.Tuple): + return - """ - if isinstance(node.targets[0], ast.Tuple): - if len(node.targets[0].elts) > self.options.max_tuple_unpack_length: - self.add_violation( - complexity.TooLongTupleUnpackViolation( - node, - text=str(len(node.targets[0].elts)), - baseline=self.options.max_tuple_unpack_length, - ), - ) + if len(node.targets[0].elts) <= self.options.max_tuple_unpack_length: + return - self.generic_visit(node) + self.add_violation( + complexity.TooLongTupleUnpackViolation( + node, + text=str(len(node.targets[0].elts)), + baseline=self.options.max_tuple_unpack_length, + ), + )