diff --git a/lizard.py b/lizard.py index 63f0153..4ebb4f5 100755 --- a/lizard.py +++ b/lizard.py @@ -301,11 +301,17 @@ def unqualified_name(self): " %(name)s@%(start_line)s-%(end_line)s@%(filename)s" % self.__dict__) - parameter_count = property(lambda self: len(self.full_parameters)) + parameter_count = property(lambda self: len(self.parameters)) @property def parameters(self): - matches = [re.search(r'(\w+)(\s=.*)?$', f) + # Exclude empty tokens as parameters. These can occur in languages + # allowing a trailing comma on the last parameter in an function + # argument list. + # Regex matches the parameter name, then optionally: + # - a default value given after an '=' sign + # - a type annotation given after a ':' + matches = [re.search(r'(\w+)(\s=.*)?(\s:.*)?$', f) for f in self.full_parameters] return [m.group(1) for m in matches if m] diff --git a/lizard_languages/python.py b/lizard_languages/python.py index 4a23f9e..3e935cc 100644 --- a/lizard_languages/python.py +++ b/lizard_languages/python.py @@ -84,6 +84,8 @@ def _function(self, token): def _dec(self, token): if token == ')': self._state = self._state_colon + elif token == '[': + self._state = self._state_parameterized_type_annotation else: self.context.parameter(token) return @@ -100,3 +102,8 @@ def _state_first_line(self, token): if token.startswith('"""') or token.startswith("'''"): self.context.add_nloc(-token.count('\n') - 1) self._state_global(token) + + def _state_parameterized_type_annotation(self, token): + self.context.add_to_long_function_name(" " + token) + if token == ']': + self._state = self._dec diff --git a/test/test_languages/testPython.py b/test/test_languages/testPython.py index 107df2f..a052024 100644 --- a/test/test_languages/testPython.py +++ b/test/test_languages/testPython.py @@ -149,6 +149,62 @@ def function_with_2_parameters_and_default_value(a, b=None): functions = get_python_function_list(inspect.getsource(namespace_df)) self.assertEqual(2, functions[0].parameter_count) self.assertEqual(['a', 'b'], functions[0].parameters) + self.assertEqual("function_with_2_parameters_and_default_value( a , b = None )", + functions[0].long_name) + + def test_parameter_count_with_type_annotations(self): + functions = get_python_function_list(''' + def function_with_3_parameters(a: str, b: int, c: float): + pass + ''') + self.assertEqual(1, len(functions)) + self.assertEqual(3, functions[0].parameter_count) + self.assertEqual(['a', 'b', 'c'], functions[0].parameters) + self.assertEqual("function_with_3_parameters( a : str , b : int , c : float )", + functions[0].long_name) + + def test_parameter_count_with_type_annotation_and_default(self): + functions = get_python_function_list(''' + def function_with_3_parameters(a: int = 1): + pass + ''') + self.assertEqual(1, len(functions)) + self.assertEqual(1, functions[0].parameter_count) + self.assertEqual(['a'], functions[0].parameters) + self.assertEqual("function_with_3_parameters( a : int = 1 )", + functions[0].long_name) + + def test_parameter_count_with_parameterized_type_annotations(self): + functions = get_python_function_list(''' + def function_with_parameterized_parameter(a: dict[str, tuple[int, float]]): + pass + def function_with_3_parameterized_parameters(a: dict[str, int], + b: list[float], + c: tuple[int, float, str] + ): + pass + + ''') + self.assertEqual(2, len(functions)) + self.assertEqual(1, functions[0].parameter_count) + self.assertEqual(['a'], functions[0].parameters) + self.assertEqual("function_with_parameterized_parameter( a : dict [ str , tuple [ int , float ] ] )", + functions[0].long_name) + self.assertEqual(3, functions[1].parameter_count) + self.assertEqual(['a', 'b', 'c'], functions[1].parameters) + self.assertEqual("function_with_3_parameterized_parameters( a : dict [ str , int ] , b : list [ float ] , c : tuple [ int , float , str ] )", + functions[1].long_name) + + def test_parameter_count_with_trailing_comma(self): + functions = get_python_function_list(''' + def foo(arg1, + arg2, + ): + # comment + return True + ''') + self.assertEqual(2, functions[0].parameter_count) + self.assertEqual(['arg1', 'arg2'], functions[0].parameters) def test_function_end(self): class namespace3: