diff options
author | Claudiu Popa <cpopa@cloudbasesolutions.com> | 2015-07-01 03:28:04 +0300 |
---|---|---|
committer | Claudiu Popa <cpopa@cloudbasesolutions.com> | 2015-07-01 03:28:04 +0300 |
commit | 47cc2e4a103afb185190995b6518b4f169baea8f (patch) | |
tree | 92f6511458d84ac1f199aea03ebd8b6687ba1242 | |
parent | b1a5e534e68324dda2e2fd396e8d91e4daf0ed59 (diff) | |
download | astroid-47cc2e4a103afb185190995b6518b4f169baea8f.tar.gz |
Add annotation support for function.as_string(). Closes issue #37.
-rw-r--r-- | ChangeLog | 2 | ||||
-rw-r--r-- | astroid/as_string.py | 18 | ||||
-rw-r--r-- | astroid/node_classes.py | 10 | ||||
-rw-r--r-- | astroid/tests/unittest_python3.py | 14 |
4 files changed, 35 insertions, 9 deletions
@@ -210,6 +210,8 @@ Change log for the astroid package (used to be astng) * Star unpacking in assignments returns properly a list, not the individual components. Closes issue #138. + * Add annotation support for function.as_string(). Closes issue #37. + 2015-03-14 -- 1.3.6 diff --git a/astroid/as_string.py b/astroid/as_string.py index b9ec447..01eb718 100644 --- a/astroid/as_string.py +++ b/astroid/as_string.py @@ -25,6 +25,8 @@ import sys +import six + INDENT = ' ' # 4 spaces ; keep indentation variable @@ -269,10 +271,18 @@ class AsStringVisitor(object): decorate = node.decorators and node.decorators.accept(self) or '' docs = node.doc and '\n%s"""%s"""' % (INDENT, node.doc) or '' return_annotation = '' - if node.returns: - return_annotation = ' -> ' + node.returns.name - return '\n%sdef %s(%s):%s%s\n%s' % (decorate, node.name, node.args.accept(self), - return_annotation, docs, self._stmt_list(node.body)) + if six.PY3 and node.returns: + return_annotation = '->' + node.returns.as_string() + trailer = return_annotation + ":" + else: + trailer = ":" + def_format = "\n{decorators}def {name}({args}){trailer}{docs}\n{body}" + return def_format.format(decorators=decorate, + name=node.name, + args=node.args.accept(self), + trailer=trailer, + docs=docs, + body=self._stmt_list(node.body)) def visit_genexpr(self, node): """return an astroid.GenExpr node as string""" diff --git a/astroid/node_classes.py b/astroid/node_classes.py index e5a7167..bf3a9e4 100644 --- a/astroid/node_classes.py +++ b/astroid/node_classes.py @@ -317,7 +317,8 @@ class Arguments(NodeNG, AssignTypeMixin): result = [] if self.args: result.append( - _format_args(self.args, self.annotations, self.defaults) + _format_args(self.args, getattr(self, 'annotations', None), + self.defaults) ) if self.vararg: result.append('*%s' % self.vararg) @@ -385,13 +386,14 @@ def _format_args(args, annotations=None, defaults=None): annotations = [] if defaults is not None: default_offset = len(args) - len(defaults) - for i, (arg, annotation) in enumerate(zip(args, annotations)): + packed = six.moves.zip_longest(args, annotations) + for i, (arg, annotation) in enumerate(packed): if isinstance(arg, Tuple): values.append('(%s)' % _format_args(arg.elts)) else: argname = arg.name - if annotation is not None: - argname += ': ' + annotation.name + if annotation is not None: + argname += ':' + annotation.as_string() values.append(argname) if defaults is not None and i >= default_offset: diff --git a/astroid/tests/unittest_python3.py b/astroid/tests/unittest_python3.py index 4bb2fc8..187b212 100644 --- a/astroid/tests/unittest_python3.py +++ b/astroid/tests/unittest_python3.py @@ -21,7 +21,7 @@ import unittest from astroid.node_classes import Assign, Discard, YieldFrom, Name, Const from astroid.builder import AstroidBuilder from astroid.scoped_nodes import Class, Function -from astroid.test_utils import require_version +from astroid.test_utils import require_version, extract_node class Python3TC(unittest.TestCase): @@ -213,6 +213,18 @@ class Python3TC(unittest.TestCase): self.assertEqual(func.args.annotations[1].name, 'str') self.assertIsNone(func.returns) + @require_version('3.0') + def test_annotation_as_string(self): + code1 = dedent(''' + def test(a, b:int=4, c=2, f:'lala'=4)->2: + pass''') + code2 = dedent(''' + def test(a:typing.Generic[T], c:typing.Any=24)->typing.Iterable: + pass''') + for code in (code1, code2): + func = extract_node(code) + self.assertEqual(func.as_string(), code) + if __name__ == '__main__': unittest.main() |