diff options
-rw-r--r-- | astroid/_ast.py | 17 | ||||
-rw-r--r-- | astroid/rebuilder.py | 37 | ||||
-rw-r--r-- | astroid/scoped_nodes.py | 31 | ||||
-rw-r--r-- | astroid/tests/unittest_nodes.py | 49 |
4 files changed, 127 insertions, 7 deletions
diff --git a/astroid/_ast.py b/astroid/_ast.py index 8220e108..58a58c77 100644 --- a/astroid/_ast.py +++ b/astroid/_ast.py @@ -1,4 +1,6 @@ import ast +from collections import namedtuple +from typing import Optional _ast_py2 = _ast_py3 = None try: @@ -8,6 +10,9 @@ except ImportError: pass +FunctionType = namedtuple('FunctionType', ['argtypes', 'returns']) + + def _get_parser_module(parse_python_two: bool = False): if parse_python_two: parser_module = _ast_py2 @@ -19,3 +24,15 @@ def _get_parser_module(parse_python_two: bool = False): def _parse(string: str, parse_python_two: bool = False): return _get_parser_module(parse_python_two=parse_python_two).parse(string) + + +def parse_function_type_comment(type_comment: str) -> Optional[FunctionType]: + """Given a correct type comment, obtain a FunctionType object""" + if _ast_py3 is None: + return None + + func_type = _ast_py3.parse(type_comment, "<type_comment>", "func_type") + return FunctionType( + argtypes=func_type.argtypes, + returns=func_type.returns, + ) diff --git a/astroid/rebuilder.py b/astroid/rebuilder.py index 223ddcdd..911d5601 100644 --- a/astroid/rebuilder.py +++ b/astroid/rebuilder.py @@ -13,7 +13,7 @@ order to get a single Astroid representation import sys import astroid -from astroid._ast import _parse, _get_parser_module +from astroid._ast import _parse, _get_parser_module, parse_function_type_comment from astroid import nodes @@ -259,6 +259,24 @@ class TreeRebuilder(object): return type_object.value + def check_function_type_comment(self, node): + type_comment = getattr(node, 'type_comment', None) + if not type_comment: + return None + + try: + type_comment_ast = parse_function_type_comment(type_comment) + except SyntaxError: + # Invalid type comment, just skip it. + return None + + returns = None + argtypes = [self.visit(elem, node) for elem in (type_comment_ast.argtypes or [])] + if type_comment_ast.returns: + returns = self.visit(type_comment_ast.returns, node) + + return returns, argtypes + def visit_assign(self, node, parent): """visit a Assign node by returning a fresh instance of it""" type_annotation = self.check_type_comment(node) @@ -528,10 +546,19 @@ class TreeRebuilder(object): returns = self.visit(node.returns, newnode) else: returns = None - newnode.postinit(self.visit(node.args, newnode), - [self.visit(child, newnode) - for child in node.body], - decorators, returns) + + type_comment_args = type_comment_returns = None + type_comment_annotation = self.check_function_type_comment(node) + if type_comment_annotation: + type_comment_returns, type_comment_args = type_comment_annotation + newnode.postinit( + args=self.visit(node.args, newnode), + body=[self.visit(child, newnode) for child in node.body], + decorators=decorators, + returns=returns, + type_comment_returns=type_comment_returns, + type_comment_args=type_comment_args, + ) self._global_names.pop() return newnode diff --git a/astroid/scoped_nodes.py b/astroid/scoped_nodes.py index de3cc5c3..e63402e4 100644 --- a/astroid/scoped_nodes.py +++ b/astroid/scoped_nodes.py @@ -1243,9 +1243,26 @@ class FunctionDef(mixins.MultiLineBlockMixin, node_classes.Statement, Lambda): :type: bool """ + type_annotation = None + """If present, this will contain the type annotation passed by a type comment + + :type: NodeNG or None + """ + type_comment_args = None + """ + If present, this will contain the type annotation for arguments + passed by a type comment + """ + type_comment_returns = None + """If present, this will contain the return type annotation, passed by a type comment""" # attributes below are set by the builder module or by raw factories _other_fields = ('name', 'doc') - _other_other_fields = ('locals', '_type') + _other_other_fields = ( + 'locals', + '_type', + 'type_comment_returns', + 'type_comment_args', + ) _type = None def __init__(self, name=None, doc=None, lineno=None, @@ -1286,7 +1303,11 @@ class FunctionDef(mixins.MultiLineBlockMixin, node_classes.Statement, Lambda): frame.set_local(name, self) # pylint: disable=arguments-differ; different than Lambdas - def postinit(self, args, body, decorators=None, returns=None): + def postinit(self, args, body, + decorators=None, + returns=None, + type_comment_returns=None, + type_comment_args=None): """Do some setup after initialisation. :param args: The arguments that the function takes. @@ -1298,11 +1319,17 @@ class FunctionDef(mixins.MultiLineBlockMixin, node_classes.Statement, Lambda): :param decorators: The decorators that are applied to this method or function. :type decorators: Decorators or None + :params type_comment_returns: + The return type annotation passed via a type comment. + :params type_comment_args: + The args type annotation passed via a type comment. """ self.args = args self.body = body self.decorators = decorators self.returns = returns + self.type_comment_returns = type_comment_returns + self.type_comment_args = type_comment_args if isinstance(self.parent.frame(), ClassDef): self.set_local('__class__', self.parent.frame()) diff --git a/astroid/tests/unittest_nodes.py b/astroid/tests/unittest_nodes.py index 4f46ca29..a90c4cea 100644 --- a/astroid/tests/unittest_nodes.py +++ b/astroid/tests/unittest_nodes.py @@ -874,5 +874,54 @@ def test_type_comments_invalid_expression(): assert node.type_annotation is None +@pytest.mark.skipif(not HAS_TYPED_AST, reason="requires typed_ast") +def test_type_comments_invalid_function_comments(): + module = builder.parse(''' + def func(): + # type: something completely invalid + pass + def func1(): + # typeee: 2*+4 + pass + def func2(): + # type: List[int + pass + ''') + for node in module.body: + assert node.type_comment_returns is None + assert node.type_comment_args is None + + +@pytest.mark.skipif(not HAS_TYPED_AST, reason="requires typed_ast") +def test_type_comments_function(): + module = builder.parse(''' + def func(): + # type: (int) -> str + pass + def func1(): + # type: (int, int, int) -> (str, str) + pass + def func2(): + # type: (int, int, str, List[int]) -> List[int] + pass + ''') + expected_annotations = [ + (["int"], astroid.Name, "str"), + (["int", "int", "int"], astroid.Tuple, "(str, str)"), + (["int", "int", "str", "List[int]"], astroid.Subscript, "List[int]"), + ] + for node, ( + expected_args, + expected_returns_type, + expected_returns_string + ) in zip(module.body, expected_annotations): + assert node.type_comment_returns is not None + assert node.type_comment_args is not None + for expected_arg, actual_arg in zip(expected_args, node.type_comment_args): + assert actual_arg.as_string() == expected_arg + assert isinstance(node.type_comment_returns, expected_returns_type) + assert node.type_comment_returns.as_string() == expected_returns_string + + if __name__ == '__main__': unittest.main() |