Skip to content

Commit

Permalink
Closes #1779 (#1916)
Browse files Browse the repository at this point in the history
  • Loading branch information
sobolevn authored Feb 27, 2021
1 parent a605c95 commit b554dd4
Show file tree
Hide file tree
Showing 7 changed files with 121 additions and 76 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ Semantic versioning in our case means:
- Fixes `CognitiveModuleComplexityViolation` to not trigger
for a single-item modules
- Fixes that `ConstantConditionViolation` was not reported for a BoolOp
- Functions and methods marked as `@overload` or `@typing.overload`
do not count in complexity rules


## 0.15.1
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,17 @@ async def method(cls): ...
async def method2(cls): ...
"""

# regression1779

class_with_overloades = """
class First(object):
@overload
def my_method(self): ...
@typing.overload
def my_method(self): ...
"""


@pytest.mark.parametrize('code', [
module_without_methods,
Expand All @@ -102,6 +113,7 @@ async def method2(cls): ...
class_with_async_and_usual_class_methods,
class_with_staticmethods,
class_with_async_staticmethods,
class_with_overloades,
])
def test_method_counts_normal(
assert_errors,
Expand Down Expand Up @@ -144,3 +156,22 @@ def test_method_counts_violation(

assert_errors(visitor, [TooManyMethodsViolation])
assert_error_text(visitor, '2', option_values.max_methods)


@pytest.mark.parametrize('code', [
class_with_overloades,
])
def test_method_counts_exceptions(
assert_errors,
parse_ast_tree,
code,
options,
):
"""Testing that violations are raised not when using special cases."""
tree = parse_ast_tree(code)

option_values = options(max_methods=0)
visitor = MethodMembersVisitor(option_values, tree=tree)
visitor.run()

assert_errors(visitor, [])
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,18 @@ async def test(self): ...
def other(self): ...
"""

# regression1779
module_with_overloads = """
@overload
def first(): ...
@typing.overload
def first(): ...
# Only this def counts:
def first(): ...
"""

# Empty:

empty_module = ''
Expand All @@ -116,6 +128,7 @@ def other(self): ...
module_with_staticmethods,
module_with_classmethods,
module_with_single_class,
module_with_overloads,
])
def test_module_counts_normal(
assert_errors,
Expand Down Expand Up @@ -166,6 +179,7 @@ def test_module_counts_violation(
module_with_single_function,
module_with_single_async_function,
module_with_single_class,
module_with_overloads,
])
def test_module_counts_single_member(
assert_errors,
Expand Down
27 changes: 27 additions & 0 deletions wemake_python_styleguide/logic/tree/decorators.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import ast

from wemake_python_styleguide.types import AnyFunctionDef


def has_overload_decorator(function: AnyFunctionDef) -> bool:
"""
Detects if a function has ``@overload`` or ``@typing.overload`` decorators.
It is useful, because ``@overload`` function defs
have slightly different rules: for example, they do not count as real defs
in complexity rules.
"""
for decorator in function.decorator_list:
is_partial_name = (
isinstance(decorator, ast.Name) and
decorator.id == 'overload'
)
is_full_name = (
isinstance(decorator, ast.Attribute) and
decorator.attr == 'overload' and
isinstance(decorator.value, ast.Name) and
decorator.value.id == 'typing'
)
if is_partial_name or is_full_name:
return True
return False
15 changes: 12 additions & 3 deletions wemake_python_styleguide/logic/tree/functions.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from ast import Call, Return, Yield, YieldFrom, arg, walk
from typing import Container, Iterable, List, Optional, Tuple, Type, Union
from typing import Container, Iterable, List, Tuple, Type, Union

from typing_extensions import Final

from wemake_python_styleguide.compat.functions import get_posonlyargs
from wemake_python_styleguide.logic import source
Expand All @@ -26,6 +28,13 @@
Type[YieldFrom],
]

#: Method types
_METHOD_TYPES: Final = frozenset((
'method',
'classmethod',
'staticmethod',
))


def given_function_called(
node: Call,
Expand All @@ -47,7 +56,7 @@ def given_function_called(
return ''


def is_method(function_type: Optional[str]) -> bool:
def is_method(function_type: str) -> bool:
"""
Returns whether a given function type belongs to a class.
Expand All @@ -70,7 +79,7 @@ def is_method(function_type: Optional[str]) -> bool:
False
"""
return function_type in {'method', 'classmethod', 'staticmethod'}
return function_type in _METHOD_TYPES


def get_all_arguments(node: AnyFunctionDefAndLambda) -> List[arg]:
Expand Down
17 changes: 5 additions & 12 deletions wemake_python_styleguide/visitors/ast/complexity/classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from wemake_python_styleguide.logic.naming import access
from wemake_python_styleguide.logic.nodes import get_parent
from wemake_python_styleguide.logic.tree import classes
from wemake_python_styleguide.logic.tree import classes, decorators
from wemake_python_styleguide.types import AnyFunctionDef
from wemake_python_styleguide.violations.complexity import (
TooManyBaseClassesViolation,
Expand Down Expand Up @@ -38,10 +38,6 @@ def t(self):
File "<stdin>", line 3
SyntaxError: cannot use named assignment with attribute
Raises:
TooManyBaseClassesViolation
TooManyPublicAttributesViolation
"""
self._check_base_classes(node)
self._check_public_attributes(node)
Expand Down Expand Up @@ -92,17 +88,14 @@ def __init__(self, *args, **kwargs) -> None:
self._methods: DefaultDict[ast.ClassDef, int] = defaultdict(int)

def visit_any_function(self, node: AnyFunctionDef) -> None:
"""
Counts the number of methods in a single class.
Raises:
TooManyMethodsViolation
"""
"""Counts the number of methods in a single class."""
self._check_method(node)
self.generic_visit(node)

def _check_method(self, node: AnyFunctionDef) -> None:
if decorators.has_overload_decorator(node):
return # we don't count `@overload` methods

parent = get_parent(node)
if isinstance(parent, ast.ClassDef):
self._methods[parent] += 1
Expand Down
91 changes: 30 additions & 61 deletions wemake_python_styleguide/visitors/ast/complexity/counts.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@
from typing_extensions import final

from wemake_python_styleguide import constants
from wemake_python_styleguide.compat.aliases import FunctionNodes
from wemake_python_styleguide.logic.nodes import get_parent
from wemake_python_styleguide.logic.tree.functions import is_method
from wemake_python_styleguide.logic.tree import decorators, functions
from wemake_python_styleguide.types import AnyFunctionDef
from wemake_python_styleguide.violations import complexity
from wemake_python_styleguide.visitors.base import BaseNodeVisitor
Expand All @@ -32,22 +33,21 @@ def __init__(self, *args, **kwargs) -> None:
self._public_items_count = 0

def visit_module_members(self, node: _ModuleMembers) -> None:
"""
Counts the number of _ModuleMembers in a single module.
Raises:
TooManyModuleMembersViolation
"""
"""Counts the number of _ModuleMembers in a single module."""
self._check_decorators_count(node)
self._check_members_count(node)
self.generic_visit(node)

def _check_members_count(self, node: _ModuleMembers) -> None:
"""This method increases the number of module members."""
is_real_method = is_method(getattr(node, 'function_type', None))
if functions.is_method(getattr(node, 'function_type', '')):
return

if isinstance(node, FunctionNodes):
if decorators.has_overload_decorator(node):
return # We don't count `@overload` defs as real defs

if isinstance(get_parent(node), ast.Module) and not is_real_method:
if isinstance(get_parent(node), ast.Module):
self._public_items_count += 1

def _check_decorators_count(self, node: _ModuleMembers) -> None:
Expand Down Expand Up @@ -76,24 +76,12 @@ class ConditionsVisitor(BaseNodeVisitor):
"""Checks booleans for condition counts."""

def visit_BoolOp(self, node: ast.BoolOp) -> None:
"""
Counts the number of conditions.
Raises:
TooManyConditionsViolation
"""
"""Counts the number of conditions."""
self._check_conditions(node)
self.generic_visit(node)

def visit_Compare(self, node: ast.Compare) -> None:
"""
Counts the number of compare parts.
Raises:
TooLongCompareViolation
"""
"""Counts the number of compare parts."""
self._check_compares(node)
self.generic_visit(node)

Expand Down Expand Up @@ -148,13 +136,7 @@ def __init__(self, *args, **kwargs) -> None:
)

def visit_If(self, node: ast.If) -> None:
"""
Checks condition not to reimplement switch.
Raises:
TooManyElifsViolation
"""
"""Checks condition not to reimplement switch."""
self._check_elifs(node)
self.generic_visit(node)

Expand Down Expand Up @@ -196,14 +178,7 @@ class TryExceptVisitor(BaseNodeVisitor):
"""Visits all try/except nodes to ensure that they are not too complex."""

def visit_Try(self, node: ast.Try) -> None:
"""
Ensures that try/except is correct.
Raises:
TooManyExceptCasesViolation
TooLongTryBodyViolation
"""
"""Ensures that try/except is correct."""
self._check_except_count(node)
self._check_try_body_length(node)
self.generic_visit(node)
Expand Down Expand Up @@ -234,13 +209,7 @@ class YieldTupleVisitor(BaseNodeVisitor):
"""Finds too long ``tuples`` in ``yield`` expressions."""

def visit_Yield(self, node: ast.Yield) -> None:
"""
Helper to get all ``yield`` nodes in a function at once.
Raises:
TooLongYieldTupleViolation
"""
"""Helper to get all ``yield`` nodes in a function at once."""
self._check_yield_values(node)
self.generic_visit(node)

Expand All @@ -261,21 +230,21 @@ class TupleUnpackVisitor(BaseNodeVisitor):
"""Finds statements with too many variables receiving an unpacked tuple."""

def visit_Assign(self, node: ast.Assign) -> None:
"""
Finds statements using too many variables to receive an unpacked tuple.
"""Finds statements using too many variables to unpack a tuple."""
self._check_tuple_unpack(node)
self.generic_visit(node)

Raises:
TooLongTupleUnpackViolation
def _check_tuple_unpack(self, node: ast.Assign) -> None:
if not isinstance(node.targets[0], ast.Tuple):
return

"""
if isinstance(node.targets[0], ast.Tuple):
if len(node.targets[0].elts) > self.options.max_tuple_unpack_length:
self.add_violation(
complexity.TooLongTupleUnpackViolation(
node,
text=str(len(node.targets[0].elts)),
baseline=self.options.max_tuple_unpack_length,
),
)
if len(node.targets[0].elts) <= self.options.max_tuple_unpack_length:
return

self.generic_visit(node)
self.add_violation(
complexity.TooLongTupleUnpackViolation(
node,
text=str(len(node.targets[0].elts)),
baseline=self.options.max_tuple_unpack_length,
),
)

0 comments on commit b554dd4

Please sign in to comment.