summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--ChangeLog3
-rw-r--r--astroid/as_string.py5
-rw-r--r--astroid/node_classes.py5
-rw-r--r--astroid/protocols.py4
-rw-r--r--astroid/rebuilder.py3
-rw-r--r--astroid/tests/unittest_python3.py63
6 files changed, 81 insertions, 2 deletions
diff --git a/ChangeLog b/ChangeLog
index 568d2beb..f895040d 100644
--- a/ChangeLog
+++ b/ChangeLog
@@ -177,6 +177,9 @@ Change log for the astroid package (used to be astng)
* Starred expressions are now inferred correctly for tuple,
list, set, and dictionary literals.
+ * Support for asynchronous comprehensions introduced in Python 3.6.
+
+ Fixes #399. See PEP530 for details.
2015-11-29 -- 1.4.1
diff --git a/astroid/as_string.py b/astroid/as_string.py
index fc4c8e4c..85afdc30 100644
--- a/astroid/as_string.py
+++ b/astroid/as_string.py
@@ -503,6 +503,11 @@ class AsStringVisitor3(AsStringVisitor):
def visit_formattedvalue(self, node):
return '{%s}' % node.value.accept(self)
+ def visit_comprehension(self, node):
+ """return an astroid.Comprehension node as string"""
+ return '%s%s' % ('async ' if node.is_async else '',
+ super(AsStringVisitor3, self).visit_comprehension(node))
+
def _import_string(names):
"""return a list of (name, asname) formatted as a string"""
diff --git a/astroid/node_classes.py b/astroid/node_classes.py
index b3a42ad2..cfd60042 100644
--- a/astroid/node_classes.py
+++ b/astroid/node_classes.py
@@ -1210,19 +1210,22 @@ class Compare(NodeNG):
class Comprehension(NodeNG):
"""class representing a Comprehension node"""
_astroid_fields = ('target', 'iter', 'ifs')
+ _other_fields = ('is_async',)
target = None
iter = None
ifs = None
+ is_async = None
def __init__(self, parent=None):
super(Comprehension, self).__init__()
self.parent = parent
# pylint: disable=redefined-builtin; same name as builtin ast module.
- def postinit(self, target=None, iter=None, ifs=None):
+ def postinit(self, target=None, iter=None, ifs=None, is_async=None):
self.target = target
self.iter = iter
self.ifs = ifs
+ self.is_async = is_async
optional_assign = True
def assign_type(self):
diff --git a/astroid/protocols.py b/astroid/protocols.py
index a6e6a97c..d1abdea0 100644
--- a/astroid/protocols.py
+++ b/astroid/protocols.py
@@ -242,6 +242,10 @@ def _resolve_looppart(parts, asspath, context):
@decorators.raise_if_nothing_inferred
def for_assigned_stmts(self, node=None, context=None, asspath=None):
+ if isinstance(self, nodes.AsyncFor) or getattr(self, 'is_async', False):
+ # Skip inferring of async code for now
+ raise StopIteration(dict(node=self, unknown=node,
+ assign_path=asspath, context=context))
if asspath is None:
for lst in self.iter.infer(context):
if isinstance(lst, (nodes.Tuple, nodes.List)):
diff --git a/astroid/rebuilder.py b/astroid/rebuilder.py
index f66809bc..115a8d45 100644
--- a/astroid/rebuilder.py
+++ b/astroid/rebuilder.py
@@ -358,7 +358,8 @@ class TreeRebuilder(object):
newnode.postinit(self.visit(node.target, newnode),
self.visit(node.iter, newnode),
[self.visit(child, newnode)
- for child in node.ifs])
+ for child in node.ifs],
+ getattr(node, 'is_async', None))
return newnode
def visit_decorators(self, node, parent):
diff --git a/astroid/tests/unittest_python3.py b/astroid/tests/unittest_python3.py
index 671ea011..ea817272 100644
--- a/astroid/tests/unittest_python3.py
+++ b/astroid/tests/unittest_python3.py
@@ -267,7 +267,70 @@ class Python3TC(unittest.TestCase):
self.assertIsInstance(inferred, nodes.Const)
self.assertEqual(inferred.value, expected)
+ @require_version('3.6')
+ def test_async_comprehensions(self):
+ async_comprehensions = [
+ extract_node("async def f(): return __([i async for i in aiter() if i % 2])"),
+ extract_node("async def f(): return __({i async for i in aiter() if i % 2})"),
+ extract_node("async def f(): return __((i async for i in aiter() if i % 2))"),
+ extract_node("async def f(): return __({i: i async for i in aiter() if i % 2})")
+ ]
+ non_async_comprehensions = [
+ extract_node("async def f(): return __({i: i for i in iter() if i % 2})")
+ ]
+
+ for comp in async_comprehensions:
+ self.assertTrue(comp.generators[0].is_async)
+ for comp in non_async_comprehensions:
+ self.assertFalse(comp.generators[0].is_async)
+
+ @require_version('3.7')
+ def test_async_comprehensions_outside_coroutine(self):
+ # When async and await will become keywords, async comprehensions
+ # will be allowed outside of coroutines body
+ comprehensions = [
+ "[i async for i in aiter() if condition(i)]",
+ "[await fun() for fun in funcs]",
+ "{await fun() for fun in funcs}",
+ "{fun: await fun() for fun in funcs}",
+ "[await fun() for fun in funcs if await smth]",
+ "{await fun() for fun in funcs if await smth}",
+ "{fun: await fun() for fun in funcs if await smth}",
+ "[await fun() async for fun in funcs]",
+ "{await fun() async for fun in funcs}",
+ "{fun: await fun() async for fun in funcs}",
+ "[await fun() async for fun in funcs if await smth]",
+ "{await fun() async for fun in funcs if await smth}",
+ "{fun: await fun() async for fun in funcs if await smth}",
+ ]
+ for comp in comprehensions:
+ node = extract_node(comp)
+ self.assertTrue(node.generators[0].is_async)
+
+ @require_version('3.6')
+ def test_async_comprehensions_as_string(self):
+ func_bodies = [
+ "return [i async for i in aiter() if condition(i)]",
+ "return [await fun() for fun in funcs]",
+ "return {await fun() for fun in funcs}",
+ "return {fun: await fun() for fun in funcs}",
+ "return [await fun() for fun in funcs if await smth]",
+ "return {await fun() for fun in funcs if await smth}",
+ "return {fun: await fun() for fun in funcs if await smth}",
+ "return [await fun() async for fun in funcs]",
+ "return {await fun() async for fun in funcs}",
+ "return {fun: await fun() async for fun in funcs}",
+ "return [await fun() async for fun in funcs if await smth]",
+ "return {await fun() async for fun in funcs if await smth}",
+ "return {fun: await fun() async for fun in funcs if await smth}",
+ ]
+ for func_body in func_bodies:
+ code = dedent('''
+ async def f():
+ {}'''.format(func_body))
+ func = extract_node(code)
+ self.assertEqual(func.as_string().strip(), code.strip())
if __name__ == '__main__':