summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--astroid/_ast.py17
-rw-r--r--astroid/rebuilder.py37
-rw-r--r--astroid/scoped_nodes.py31
-rw-r--r--astroid/tests/unittest_nodes.py49
4 files changed, 127 insertions, 7 deletions
diff --git a/astroid/_ast.py b/astroid/_ast.py
index 8220e108..58a58c77 100644
--- a/astroid/_ast.py
+++ b/astroid/_ast.py
@@ -1,4 +1,6 @@
import ast
+from collections import namedtuple
+from typing import Optional
_ast_py2 = _ast_py3 = None
try:
@@ -8,6 +10,9 @@ except ImportError:
pass
+FunctionType = namedtuple('FunctionType', ['argtypes', 'returns'])
+
+
def _get_parser_module(parse_python_two: bool = False):
if parse_python_two:
parser_module = _ast_py2
@@ -19,3 +24,15 @@ def _get_parser_module(parse_python_two: bool = False):
def _parse(string: str,
parse_python_two: bool = False):
return _get_parser_module(parse_python_two=parse_python_two).parse(string)
+
+
+def parse_function_type_comment(type_comment: str) -> Optional[FunctionType]:
+ """Given a correct type comment, obtain a FunctionType object"""
+ if _ast_py3 is None:
+ return None
+
+ func_type = _ast_py3.parse(type_comment, "<type_comment>", "func_type")
+ return FunctionType(
+ argtypes=func_type.argtypes,
+ returns=func_type.returns,
+ )
diff --git a/astroid/rebuilder.py b/astroid/rebuilder.py
index 223ddcdd..911d5601 100644
--- a/astroid/rebuilder.py
+++ b/astroid/rebuilder.py
@@ -13,7 +13,7 @@ order to get a single Astroid representation
import sys
import astroid
-from astroid._ast import _parse, _get_parser_module
+from astroid._ast import _parse, _get_parser_module, parse_function_type_comment
from astroid import nodes
@@ -259,6 +259,24 @@ class TreeRebuilder(object):
return type_object.value
+ def check_function_type_comment(self, node):
+ type_comment = getattr(node, 'type_comment', None)
+ if not type_comment:
+ return None
+
+ try:
+ type_comment_ast = parse_function_type_comment(type_comment)
+ except SyntaxError:
+ # Invalid type comment, just skip it.
+ return None
+
+ returns = None
+ argtypes = [self.visit(elem, node) for elem in (type_comment_ast.argtypes or [])]
+ if type_comment_ast.returns:
+ returns = self.visit(type_comment_ast.returns, node)
+
+ return returns, argtypes
+
def visit_assign(self, node, parent):
"""visit a Assign node by returning a fresh instance of it"""
type_annotation = self.check_type_comment(node)
@@ -528,10 +546,19 @@ class TreeRebuilder(object):
returns = self.visit(node.returns, newnode)
else:
returns = None
- newnode.postinit(self.visit(node.args, newnode),
- [self.visit(child, newnode)
- for child in node.body],
- decorators, returns)
+
+ type_comment_args = type_comment_returns = None
+ type_comment_annotation = self.check_function_type_comment(node)
+ if type_comment_annotation:
+ type_comment_returns, type_comment_args = type_comment_annotation
+ newnode.postinit(
+ args=self.visit(node.args, newnode),
+ body=[self.visit(child, newnode) for child in node.body],
+ decorators=decorators,
+ returns=returns,
+ type_comment_returns=type_comment_returns,
+ type_comment_args=type_comment_args,
+ )
self._global_names.pop()
return newnode
diff --git a/astroid/scoped_nodes.py b/astroid/scoped_nodes.py
index de3cc5c3..e63402e4 100644
--- a/astroid/scoped_nodes.py
+++ b/astroid/scoped_nodes.py
@@ -1243,9 +1243,26 @@ class FunctionDef(mixins.MultiLineBlockMixin, node_classes.Statement, Lambda):
:type: bool
"""
+ type_annotation = None
+ """If present, this will contain the type annotation passed by a type comment
+
+ :type: NodeNG or None
+ """
+ type_comment_args = None
+ """
+ If present, this will contain the type annotation for arguments
+ passed by a type comment
+ """
+ type_comment_returns = None
+ """If present, this will contain the return type annotation, passed by a type comment"""
# attributes below are set by the builder module or by raw factories
_other_fields = ('name', 'doc')
- _other_other_fields = ('locals', '_type')
+ _other_other_fields = (
+ 'locals',
+ '_type',
+ 'type_comment_returns',
+ 'type_comment_args',
+ )
_type = None
def __init__(self, name=None, doc=None, lineno=None,
@@ -1286,7 +1303,11 @@ class FunctionDef(mixins.MultiLineBlockMixin, node_classes.Statement, Lambda):
frame.set_local(name, self)
# pylint: disable=arguments-differ; different than Lambdas
- def postinit(self, args, body, decorators=None, returns=None):
+ def postinit(self, args, body,
+ decorators=None,
+ returns=None,
+ type_comment_returns=None,
+ type_comment_args=None):
"""Do some setup after initialisation.
:param args: The arguments that the function takes.
@@ -1298,11 +1319,17 @@ class FunctionDef(mixins.MultiLineBlockMixin, node_classes.Statement, Lambda):
:param decorators: The decorators that are applied to this
method or function.
:type decorators: Decorators or None
+ :params type_comment_returns:
+ The return type annotation passed via a type comment.
+ :params type_comment_args:
+ The args type annotation passed via a type comment.
"""
self.args = args
self.body = body
self.decorators = decorators
self.returns = returns
+ self.type_comment_returns = type_comment_returns
+ self.type_comment_args = type_comment_args
if isinstance(self.parent.frame(), ClassDef):
self.set_local('__class__', self.parent.frame())
diff --git a/astroid/tests/unittest_nodes.py b/astroid/tests/unittest_nodes.py
index 4f46ca29..a90c4cea 100644
--- a/astroid/tests/unittest_nodes.py
+++ b/astroid/tests/unittest_nodes.py
@@ -874,5 +874,54 @@ def test_type_comments_invalid_expression():
assert node.type_annotation is None
+@pytest.mark.skipif(not HAS_TYPED_AST, reason="requires typed_ast")
+def test_type_comments_invalid_function_comments():
+ module = builder.parse('''
+ def func():
+ # type: something completely invalid
+ pass
+ def func1():
+ # typeee: 2*+4
+ pass
+ def func2():
+ # type: List[int
+ pass
+ ''')
+ for node in module.body:
+ assert node.type_comment_returns is None
+ assert node.type_comment_args is None
+
+
+@pytest.mark.skipif(not HAS_TYPED_AST, reason="requires typed_ast")
+def test_type_comments_function():
+ module = builder.parse('''
+ def func():
+ # type: (int) -> str
+ pass
+ def func1():
+ # type: (int, int, int) -> (str, str)
+ pass
+ def func2():
+ # type: (int, int, str, List[int]) -> List[int]
+ pass
+ ''')
+ expected_annotations = [
+ (["int"], astroid.Name, "str"),
+ (["int", "int", "int"], astroid.Tuple, "(str, str)"),
+ (["int", "int", "str", "List[int]"], astroid.Subscript, "List[int]"),
+ ]
+ for node, (
+ expected_args,
+ expected_returns_type,
+ expected_returns_string
+ ) in zip(module.body, expected_annotations):
+ assert node.type_comment_returns is not None
+ assert node.type_comment_args is not None
+ for expected_arg, actual_arg in zip(expected_args, node.type_comment_args):
+ assert actual_arg.as_string() == expected_arg
+ assert isinstance(node.type_comment_returns, expected_returns_type)
+ assert node.type_comment_returns.as_string() == expected_returns_string
+
+
if __name__ == '__main__':
unittest.main()