summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorClaudiu Popa <cpopa@cloudbasesolutions.com>2015-07-01 03:28:04 +0300
committerClaudiu Popa <cpopa@cloudbasesolutions.com>2015-07-01 03:28:04 +0300
commit47cc2e4a103afb185190995b6518b4f169baea8f (patch)
tree92f6511458d84ac1f199aea03ebd8b6687ba1242
parentb1a5e534e68324dda2e2fd396e8d91e4daf0ed59 (diff)
downloadastroid-47cc2e4a103afb185190995b6518b4f169baea8f.tar.gz
Add annotation support for function.as_string(). Closes issue #37.
-rw-r--r--ChangeLog2
-rw-r--r--astroid/as_string.py18
-rw-r--r--astroid/node_classes.py10
-rw-r--r--astroid/tests/unittest_python3.py14
4 files changed, 35 insertions, 9 deletions
diff --git a/ChangeLog b/ChangeLog
index 8c34d52..cfef731 100644
--- a/ChangeLog
+++ b/ChangeLog
@@ -210,6 +210,8 @@ Change log for the astroid package (used to be astng)
* Star unpacking in assignments returns properly a list,
not the individual components. Closes issue #138.
+ * Add annotation support for function.as_string(). Closes issue #37.
+
2015-03-14 -- 1.3.6
diff --git a/astroid/as_string.py b/astroid/as_string.py
index b9ec447..01eb718 100644
--- a/astroid/as_string.py
+++ b/astroid/as_string.py
@@ -25,6 +25,8 @@
import sys
+import six
+
INDENT = ' ' # 4 spaces ; keep indentation variable
@@ -269,10 +271,18 @@ class AsStringVisitor(object):
decorate = node.decorators and node.decorators.accept(self) or ''
docs = node.doc and '\n%s"""%s"""' % (INDENT, node.doc) or ''
return_annotation = ''
- if node.returns:
- return_annotation = ' -> ' + node.returns.name
- return '\n%sdef %s(%s):%s%s\n%s' % (decorate, node.name, node.args.accept(self),
- return_annotation, docs, self._stmt_list(node.body))
+ if six.PY3 and node.returns:
+ return_annotation = '->' + node.returns.as_string()
+ trailer = return_annotation + ":"
+ else:
+ trailer = ":"
+ def_format = "\n{decorators}def {name}({args}){trailer}{docs}\n{body}"
+ return def_format.format(decorators=decorate,
+ name=node.name,
+ args=node.args.accept(self),
+ trailer=trailer,
+ docs=docs,
+ body=self._stmt_list(node.body))
def visit_genexpr(self, node):
"""return an astroid.GenExpr node as string"""
diff --git a/astroid/node_classes.py b/astroid/node_classes.py
index e5a7167..bf3a9e4 100644
--- a/astroid/node_classes.py
+++ b/astroid/node_classes.py
@@ -317,7 +317,8 @@ class Arguments(NodeNG, AssignTypeMixin):
result = []
if self.args:
result.append(
- _format_args(self.args, self.annotations, self.defaults)
+ _format_args(self.args, getattr(self, 'annotations', None),
+ self.defaults)
)
if self.vararg:
result.append('*%s' % self.vararg)
@@ -385,13 +386,14 @@ def _format_args(args, annotations=None, defaults=None):
annotations = []
if defaults is not None:
default_offset = len(args) - len(defaults)
- for i, (arg, annotation) in enumerate(zip(args, annotations)):
+ packed = six.moves.zip_longest(args, annotations)
+ for i, (arg, annotation) in enumerate(packed):
if isinstance(arg, Tuple):
values.append('(%s)' % _format_args(arg.elts))
else:
argname = arg.name
- if annotation is not None:
- argname += ': ' + annotation.name
+ if annotation is not None:
+ argname += ':' + annotation.as_string()
values.append(argname)
if defaults is not None and i >= default_offset:
diff --git a/astroid/tests/unittest_python3.py b/astroid/tests/unittest_python3.py
index 4bb2fc8..187b212 100644
--- a/astroid/tests/unittest_python3.py
+++ b/astroid/tests/unittest_python3.py
@@ -21,7 +21,7 @@ import unittest
from astroid.node_classes import Assign, Discard, YieldFrom, Name, Const
from astroid.builder import AstroidBuilder
from astroid.scoped_nodes import Class, Function
-from astroid.test_utils import require_version
+from astroid.test_utils import require_version, extract_node
class Python3TC(unittest.TestCase):
@@ -213,6 +213,18 @@ class Python3TC(unittest.TestCase):
self.assertEqual(func.args.annotations[1].name, 'str')
self.assertIsNone(func.returns)
+ @require_version('3.0')
+ def test_annotation_as_string(self):
+ code1 = dedent('''
+ def test(a, b:int=4, c=2, f:'lala'=4)->2:
+ pass''')
+ code2 = dedent('''
+ def test(a:typing.Generic[T], c:typing.Any=24)->typing.Iterable:
+ pass''')
+ for code in (code1, code2):
+ func = extract_node(code)
+ self.assertEqual(func.as_string(), code)
+
if __name__ == '__main__':
unittest.main()