summaryrefslogtreecommitdiff
path: root/Lib/test/test_coroutines.py
diff options
context:
space:
mode:
Diffstat (limited to 'Lib/test/test_coroutines.py')
-rw-r--r--Lib/test/test_coroutines.py489
1 files changed, 398 insertions, 91 deletions
diff --git a/Lib/test/test_coroutines.py b/Lib/test/test_coroutines.py
index 4a327b5ba9..78439a2aca 100644
--- a/Lib/test/test_coroutines.py
+++ b/Lib/test/test_coroutines.py
@@ -69,55 +69,130 @@ def silence_coro_gc():
class AsyncBadSyntaxTest(unittest.TestCase):
def test_badsyntax_1(self):
- with self.assertRaisesRegex(SyntaxError, "'await' outside"):
- import test.badsyntax_async1
+ samples = [
+ """def foo():
+ await something()
+ """,
- def test_badsyntax_2(self):
- with self.assertRaisesRegex(SyntaxError, "'await' outside"):
- import test.badsyntax_async2
+ """await something()""",
- def test_badsyntax_3(self):
- with self.assertRaisesRegex(SyntaxError, 'invalid syntax'):
- import test.badsyntax_async3
+ """async def foo():
+ yield from []
+ """,
- def test_badsyntax_4(self):
- with self.assertRaisesRegex(SyntaxError, 'invalid syntax'):
- import test.badsyntax_async4
+ """async def foo():
+ await await fut
+ """,
- def test_badsyntax_5(self):
- with self.assertRaisesRegex(SyntaxError, 'invalid syntax'):
- import test.badsyntax_async5
+ """async def foo(a=await something()):
+ pass
+ """,
- def test_badsyntax_6(self):
- with self.assertRaisesRegex(
- SyntaxError, "'yield' inside async function"):
+ """async def foo(a:await something()):
+ pass
+ """,
- import test.badsyntax_async6
+ """async def foo():
+ def bar():
+ [i async for i in els]
+ """,
- def test_badsyntax_7(self):
- with self.assertRaisesRegex(
- SyntaxError, "'yield from' inside async function"):
+ """async def foo():
+ def bar():
+ [await i for i in els]
+ """,
- import test.badsyntax_async7
+ """async def foo():
+ def bar():
+ [i for i in els
+ async for b in els]
+ """,
- def test_badsyntax_8(self):
- with self.assertRaisesRegex(SyntaxError, 'invalid syntax'):
- import test.badsyntax_async8
+ """async def foo():
+ def bar():
+ [i for i in els
+ for c in b
+ async for b in els]
+ """,
- def test_badsyntax_9(self):
- ns = {}
- for comp in {'(await a for a in b)',
- '[await a for a in b]',
- '{await a for a in b}',
- '{await a: c for a in b}'}:
+ """async def foo():
+ def bar():
+ [i for i in els
+ async for b in els
+ for c in b]
+ """,
- with self.assertRaisesRegex(SyntaxError, 'await.*in comprehen'):
- exec('async def f():\n\t{}'.format(comp), ns, ns)
+ """async def foo():
+ def bar():
+ [i for i in els
+ for b in await els]
+ """,
- def test_badsyntax_10(self):
- # Tests for issue 24619
+ """async def foo():
+ def bar():
+ [i for i in els
+ for b in els
+ if await b]
+ """,
+
+ """async def foo():
+ def bar():
+ [i for i in await els]
+ """,
+
+ """async def foo():
+ def bar():
+ [i for i in els if await i]
+ """,
+
+ """def bar():
+ [i async for i in els]
+ """,
+
+ """def bar():
+ [await i for i in els]
+ """,
+
+ """def bar():
+ [i for i in els
+ async for b in els]
+ """,
+
+ """def bar():
+ [i for i in els
+ for c in b
+ async for b in els]
+ """,
+
+ """def bar():
+ [i for i in els
+ async for b in els
+ for c in b]
+ """,
+
+ """def bar():
+ [i for i in els
+ for b in await els]
+ """,
+
+ """def bar():
+ [i for i in els
+ for b in els
+ if await b]
+ """,
+
+ """def bar():
+ [i for i in await els]
+ """,
+
+ """def bar():
+ [i for i in els if await i]
+ """,
+
+ """async def foo():
+ await
+ """,
- samples = [
"""async def foo():
def bar(): pass
await = 1
@@ -283,57 +358,110 @@ class AsyncBadSyntaxTest(unittest.TestCase):
with self.subTest(code=code), self.assertRaises(SyntaxError):
compile(code, "<test>", "exec")
- def test_goodsyntax_1(self):
- # Tests for issue 24619
+ def test_badsyntax_2(self):
+ samples = [
+ """def foo():
+ await = 1
+ """,
+
+ """class Bar:
+ def async(): pass
+ """,
- def foo(await):
- async def foo(): pass
- async def foo():
+ """class Bar:
+ async = 1
+ """,
+
+ """class async:
pass
- return await + 1
- self.assertEqual(foo(10), 11)
+ """,
- def foo(await):
- async def foo(): pass
- async def foo(): pass
- return await + 2
- self.assertEqual(foo(20), 22)
+ """class await:
+ pass
+ """,
- def foo(await):
+ """import math as await""",
- async def foo(): pass
+ """def async():
+ pass""",
- async def foo(): pass
+ """def foo(*, await=1):
+ pass"""
- return await + 2
- self.assertEqual(foo(20), 22)
+ """async = 1""",
- def foo(await):
- """spam"""
- async def foo(): \
- pass
- # 123
- async def foo(): pass
- # 456
- return await + 2
- self.assertEqual(foo(20), 22)
-
- def foo(await):
- def foo(): pass
- def foo(): pass
- async def bar(): return await_
- await_ = await
- try:
- bar().send(None)
- except StopIteration as ex:
- return ex.args[0]
- self.assertEqual(foo(42), 42)
+ """print(await=1)"""
+ ]
- async def f():
- async def g(): pass
- await z
- await = 1
- self.assertTrue(inspect.iscoroutinefunction(f))
+ for code in samples:
+ with self.subTest(code=code), self.assertWarnsRegex(
+ DeprecationWarning,
+ "'await' will become reserved keywords"):
+ compile(code, "<test>", "exec")
+
+ def test_badsyntax_3(self):
+ with self.assertRaises(DeprecationWarning):
+ with warnings.catch_warnings():
+ warnings.simplefilter("error")
+ compile("async = 1", "<test>", "exec")
+
+ def test_goodsyntax_1(self):
+ # Tests for issue 24619
+
+ samples = [
+ '''def foo(await):
+ async def foo(): pass
+ async def foo():
+ pass
+ return await + 1
+ ''',
+
+ '''def foo(await):
+ async def foo(): pass
+ async def foo(): pass
+ return await + 1
+ ''',
+
+ '''def foo(await):
+
+ async def foo(): pass
+
+ async def foo(): pass
+
+ return await + 1
+ ''',
+
+ '''def foo(await):
+ """spam"""
+ async def foo(): \
+ pass
+ # 123
+ async def foo(): pass
+ # 456
+ return await + 1
+ ''',
+
+ '''def foo(await):
+ def foo(): pass
+ def foo(): pass
+ async def bar(): return await_
+ await_ = await
+ try:
+ bar().send(None)
+ except StopIteration as ex:
+ return ex.args[0] + 1
+ '''
+ ]
+
+ for code in samples:
+ with self.subTest(code=code):
+ loc = {}
+
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore")
+ exec(code, loc, loc)
+
+ self.assertEqual(loc['foo'](10), 11)
class TokenizerRegrTest(unittest.TestCase):
@@ -906,7 +1034,7 @@ class CoroutineTest(unittest.TestCase):
return await Awaitable()
with self.assertRaisesRegex(
- TypeError, "__await__\(\) returned a coroutine"):
+ TypeError, r"__await__\(\) returned a coroutine"):
run_async(foo())
@@ -1270,7 +1398,7 @@ class CoroutineTest(unittest.TestCase):
buffer = []
async def test1():
- with self.assertWarnsRegex(PendingDeprecationWarning, "legacy"):
+ with self.assertWarnsRegex(DeprecationWarning, "legacy"):
async for i1, i2 in AsyncIter():
buffer.append(i1 + i2)
@@ -1284,7 +1412,7 @@ class CoroutineTest(unittest.TestCase):
buffer = []
async def test2():
nonlocal buffer
- with self.assertWarnsRegex(PendingDeprecationWarning, "legacy"):
+ with self.assertWarnsRegex(DeprecationWarning, "legacy"):
async for i in AsyncIter():
buffer.append(i[0])
if i[0] == 20:
@@ -1303,7 +1431,7 @@ class CoroutineTest(unittest.TestCase):
buffer = []
async def test3():
nonlocal buffer
- with self.assertWarnsRegex(PendingDeprecationWarning, "legacy"):
+ with self.assertWarnsRegex(DeprecationWarning, "legacy"):
async for i in AsyncIter():
if i[0] > 20:
continue
@@ -1348,7 +1476,7 @@ class CoroutineTest(unittest.TestCase):
with self.assertRaisesRegex(
TypeError,
- "async for' received an invalid object.*__aiter.*\: I"):
+ r"async for' received an invalid object.*__aiter.*\: I"):
run_async(foo())
@@ -1386,7 +1514,7 @@ class CoroutineTest(unittest.TestCase):
return 123
async def foo():
- with self.assertWarnsRegex(PendingDeprecationWarning, "legacy"):
+ with self.assertWarnsRegex(DeprecationWarning, "legacy"):
async for i in I():
print('never going to happen')
@@ -1495,7 +1623,7 @@ class CoroutineTest(unittest.TestCase):
1/0
async def foo():
nonlocal CNT
- with self.assertWarnsRegex(PendingDeprecationWarning, "legacy"):
+ with self.assertWarnsRegex(DeprecationWarning, "legacy"):
async for i in AI():
CNT += 1
CNT += 10
@@ -1522,7 +1650,7 @@ class CoroutineTest(unittest.TestCase):
self.assertEqual(CNT, 0)
def test_for_9(self):
- # Test that PendingDeprecationWarning can safely be converted into
+ # Test that DeprecationWarning can safely be converted into
# an exception (__aiter__ should not have a chance to raise
# a ZeroDivisionError.)
class AI:
@@ -1532,13 +1660,13 @@ class CoroutineTest(unittest.TestCase):
async for i in AI():
pass
- with self.assertRaises(PendingDeprecationWarning):
+ with self.assertRaises(DeprecationWarning):
with warnings.catch_warnings():
warnings.simplefilter("error")
run_async(foo())
def test_for_10(self):
- # Test that PendingDeprecationWarning can safely be converted into
+ # Test that DeprecationWarning can safely be converted into
# an exception.
class AI:
async def __aiter__(self):
@@ -1547,7 +1675,7 @@ class CoroutineTest(unittest.TestCase):
async for i in AI():
pass
- with self.assertRaises(PendingDeprecationWarning):
+ with self.assertRaises(DeprecationWarning):
with warnings.catch_warnings():
warnings.simplefilter("error")
run_async(foo())
@@ -1598,6 +1726,185 @@ class CoroutineTest(unittest.TestCase):
foo().send(None)
self.assertEqual(result, [42])
+ def test_comp_1(self):
+ async def f(i):
+ return i
+
+ async def run_list():
+ return [await c for c in [f(1), f(41)]]
+
+ async def run_set():
+ return {await c for c in [f(1), f(41)]}
+
+ async def run_dict1():
+ return {await c: 'a' for c in [f(1), f(41)]}
+
+ async def run_dict2():
+ return {i: await c for i, c in enumerate([f(1), f(41)])}
+
+ self.assertEqual(run_async(run_list()), ([], [1, 41]))
+ self.assertEqual(run_async(run_set()), ([], {1, 41}))
+ self.assertEqual(run_async(run_dict1()), ([], {1: 'a', 41: 'a'}))
+ self.assertEqual(run_async(run_dict2()), ([], {0: 1, 1: 41}))
+
+ def test_comp_2(self):
+ async def f(i):
+ return i
+
+ async def run_list():
+ return [s for c in [f(''), f('abc'), f(''), f(['de', 'fg'])]
+ for s in await c]
+
+ self.assertEqual(
+ run_async(run_list()),
+ ([], ['a', 'b', 'c', 'de', 'fg']))
+
+ async def run_set():
+ return {d
+ for c in [f([f([10, 30]),
+ f([20])])]
+ for s in await c
+ for d in await s}
+
+ self.assertEqual(
+ run_async(run_set()),
+ ([], {10, 20, 30}))
+
+ async def run_set2():
+ return {await s
+ for c in [f([f(10), f(20)])]
+ for s in await c}
+
+ self.assertEqual(
+ run_async(run_set2()),
+ ([], {10, 20}))
+
+ def test_comp_3(self):
+ async def f(it):
+ for i in it:
+ yield i
+
+ async def run_list():
+ return [i + 1 async for i in f([10, 20])]
+ self.assertEqual(
+ run_async(run_list()),
+ ([], [11, 21]))
+
+ async def run_set():
+ return {i + 1 async for i in f([10, 20])}
+ self.assertEqual(
+ run_async(run_set()),
+ ([], {11, 21}))
+
+ async def run_dict():
+ return {i + 1: i + 2 async for i in f([10, 20])}
+ self.assertEqual(
+ run_async(run_dict()),
+ ([], {11: 12, 21: 22}))
+
+ async def run_gen():
+ gen = (i + 1 async for i in f([10, 20]))
+ return [g + 100 async for g in gen]
+ self.assertEqual(
+ run_async(run_gen()),
+ ([], [111, 121]))
+
+ def test_comp_4(self):
+ async def f(it):
+ for i in it:
+ yield i
+
+ async def run_list():
+ return [i + 1 async for i in f([10, 20]) if i > 10]
+ self.assertEqual(
+ run_async(run_list()),
+ ([], [21]))
+
+ async def run_set():
+ return {i + 1 async for i in f([10, 20]) if i > 10}
+ self.assertEqual(
+ run_async(run_set()),
+ ([], {21}))
+
+ async def run_dict():
+ return {i + 1: i + 2 async for i in f([10, 20]) if i > 10}
+ self.assertEqual(
+ run_async(run_dict()),
+ ([], {21: 22}))
+
+ async def run_gen():
+ gen = (i + 1 async for i in f([10, 20]) if i > 10)
+ return [g + 100 async for g in gen]
+ self.assertEqual(
+ run_async(run_gen()),
+ ([], [121]))
+
+ def test_comp_5(self):
+ async def f(it):
+ for i in it:
+ yield i
+
+ async def run_list():
+ return [i + 1 for pair in ([10, 20], [30, 40]) if pair[0] > 10
+ async for i in f(pair) if i > 30]
+ self.assertEqual(
+ run_async(run_list()),
+ ([], [41]))
+
+ def test_comp_6(self):
+ async def f(it):
+ for i in it:
+ yield i
+
+ async def run_list():
+ return [i + 1 async for seq in f([(10, 20), (30,)])
+ for i in seq]
+
+ self.assertEqual(
+ run_async(run_list()),
+ ([], [11, 21, 31]))
+
+ def test_comp_7(self):
+ async def f():
+ yield 1
+ yield 2
+ raise Exception('aaa')
+
+ async def run_list():
+ return [i async for i in f()]
+
+ with self.assertRaisesRegex(Exception, 'aaa'):
+ run_async(run_list())
+
+ def test_comp_8(self):
+ async def f():
+ return [i for i in [1, 2, 3]]
+
+ self.assertEqual(
+ run_async(f()),
+ ([], [1, 2, 3]))
+
+ def test_comp_9(self):
+ async def gen():
+ yield 1
+ yield 2
+ async def f():
+ l = [i async for i in gen()]
+ return [i for i in l]
+
+ self.assertEqual(
+ run_async(f()),
+ ([], [1, 2]))
+
+ def test_comp_10(self):
+ async def f():
+ xx = {i for i in [1, 2, 3]}
+ return {x: x for x in xx}
+
+ self.assertEqual(
+ run_async(f()),
+ ([], {1: 1, 2: 2, 3: 3}))
+
def test_copy(self):
async def func(): pass
coro = func()
@@ -1728,8 +2035,8 @@ class SysSetCoroWrapperTest(unittest.TestCase):
try:
with silence_coro_gc(), self.assertRaisesRegex(
RuntimeError,
- "coroutine wrapper.*\.wrapper at 0x.*attempted to "
- "recursively wrap .* wrap .*"):
+ r"coroutine wrapper.*\.wrapper at 0x.*attempted to "
+ r"recursively wrap .* wrap .*"):
foo()
finally: